Skip to content

Commit 3fb386a

Browse files
author
donglaiw
committed
add valid mask sampling
1 parent 12f43ea commit 3fb386a

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

connectomics/data/dataset/dataset_volume.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,12 @@ def __init__(self,
137137
# handle partially labeled volume
138138
self.valid_mask = valid_mask
139139
self.valid_ratio = valid_ratio
140+
# precompute valid region
141+
self.valid_pos = [None] * len(self.valid_mask)
142+
if self.valid_mask is not None:
143+
for i, x in enumerate(self.valid_mask):
144+
if x is not None:
145+
self.valid_pos[i] = get_valid_pos(x, sample_volume_size, valid_ratio)
140146

141147
if self.mode in ['val', 'test']: # for validation and test
142148
self.sample_size_test = [
@@ -231,13 +237,17 @@ def _get_pos_train(self, vol_size):
231237
# np.random: same seed
232238
pos = [0, 0, 0, 0]
233239
# pick a dataset
234-
did = self._index_to_dataset(random.randint(0, self.sample_num_a-1))
240+
did = self._index_to_dataset(random.randint(0, self.sample_num_a))
235241
pos[0] = did
236242
# pick a position
237-
tmp_size = count_volume(
238-
self.volume_size[did], vol_size, self.sample_stride)
239-
tmp_pos = [random.randint(0, tmp_size[x]-1) * self.sample_stride[x]
240-
for x in range(len(tmp_size))]
243+
# all regions are valid
244+
if self.valid_pos[did] is None:
245+
tmp_size = count_volume(
246+
self.volume_size[did], vol_size, self.sample_stride)
247+
tmp_pos = [random.randint(0, tmp_size[x]) * self.sample_stride[x]
248+
for x in range(len(tmp_size))]
249+
else:
250+
tmp_pos = self.valid_pos[did][random.randint(0, self.valid_pos[did].shape[0])]
241251

242252
pos[1:] = tmp_pos
243253
return pos

0 commit comments

Comments
 (0)