@@ -137,6 +137,12 @@ def __init__(self,
137137 # handle partially labeled volume
138138 self .valid_mask = valid_mask
139139 self .valid_ratio = valid_ratio
140+ # precompute valid region
141+ self .valid_pos = [None ] * len (self .valid_mask )
142+ if self .valid_mask is not None :
143+ for i , x in enumerate (self .valid_mask ):
144+ if x is not None :
145+ self .valid_pos [i ] = get_valid_pos (x , sample_volume_size , valid_ratio )
140146
141147 if self .mode in ['val' , 'test' ]: # for validation and test
142148 self .sample_size_test = [
@@ -231,13 +237,17 @@ def _get_pos_train(self, vol_size):
231237 # np.random: same seed
232238 pos = [0 , 0 , 0 , 0 ]
233239 # pick a dataset
234- did = self ._index_to_dataset (random .randint (0 , self .sample_num_a - 1 ))
240+ did = self ._index_to_dataset (random .randint (0 , self .sample_num_a ))
235241 pos [0 ] = did
236242 # pick a position
237- tmp_size = count_volume (
238- self .volume_size [did ], vol_size , self .sample_stride )
239- tmp_pos = [random .randint (0 , tmp_size [x ]- 1 ) * self .sample_stride [x ]
240- for x in range (len (tmp_size ))]
243+ # all regions are valid
244+ if self .valid_pos [did ] is None :
245+ tmp_size = count_volume (
246+ self .volume_size [did ], vol_size , self .sample_stride )
247+ tmp_pos = [random .randint (0 , tmp_size [x ]) * self .sample_stride [x ]
248+ for x in range (len (tmp_size ))]
249+ else :
250+ tmp_pos = self .valid_pos [did ][random .randint (0 , self .valid_pos [did ].shape [0 ])]
241251
242252 pos [1 :] = tmp_pos
243253 return pos
0 commit comments