@@ -146,6 +146,7 @@ def _load_label_condition(name, mode: str, image_only_test: bool):
146146def _get_input (cfg ,
147147 mode = 'train' ,
148148 rank = None ,
149+ preload_data = None ,
149150 dir_name_init : Optional [list ] = None ,
150151 img_name_init : Optional [list ] = None ,
151152 min_size : Optional [tuple ] = None ,
@@ -162,54 +163,61 @@ def _validate_shape(cfg, image, mask, i):
162163 assert image [i ].shape == mask [i ].shape [- 3 :]
163164
164165 assert mode in ['train' , 'val' , 'test' ]
165- dir_path = cfg .DATASET .INPUT_PATH
166- if dir_name_init is not None :
167- dir_name = dir_name_init
168- else :
169- dir_name = _get_file_list (dir_path )
170-
166+
167+ pad_mode = cfg .DATASET .PAD_MODE
171168 if mode == 'val' :
172- img_name = cfg .DATASET .VAL_IMAGE_NAME
173- label_name = cfg .DATASET .VAL_LABEL_NAME
174- valid_mask_name = cfg .DATASET .VAL_VALID_MASK_NAME
175169 pad_size = cfg .DATASET .VAL_PAD_SIZE
176170 else :
177- img_name = cfg .DATASET .IMAGE_NAME
178- label_name = cfg .DATASET .LABEL_NAME
179- valid_mask_name = cfg .DATASET .VALID_MASK_NAME
180- pad_size = cfg .DATASET .PAD_SIZE
181- assert not all ([elem == None for elem in [img_name , label_name ]]), \
182- "At least one of img_name and label_name should not be None!"
183-
171+ pad_size = cfg .DATASET .PAD_SIZE
184172 volume , label , valid_mask = None , None , None
185- if img_name_init is not None :
186- img_name = img_name_init
187-
188- if img_name is not None :
189- img_name = _get_file_list (img_name , prefix = dir_path )
190- img_name = _make_path_list (cfg , dir_name , img_name , rank )
191- volume = [None ] * len (img_name )
192- print (rank , len (img_name ), list (map (os .path .basename , img_name )))
193-
194- if _load_label_condition (label_name , mode , image_only_test ):
195- label_name = _get_file_list (label_name , prefix = dir_path )
196- label_name = _make_path_list (cfg , dir_name , label_name , rank )
197- label = [None ]* len (label_name )
198- print (rank , len (label_name ), list (map (os .path .basename , label_name )))
199-
200- if _load_label_condition (valid_mask_name , mode , image_only_test ):
201- valid_mask_name = _get_file_list (valid_mask_name , prefix = dir_path )
202- valid_mask_name = _make_path_list (cfg , dir_name , valid_mask_name , rank )
203- valid_mask = [None ]* len (valid_mask_name )
204173
205- pad_mode = cfg .DATASET .PAD_MODE
206- read_fn = readvol if not cfg .DATASET .LOAD_2D else readimg_as_vol
207- num_vols = len (img_name ) if img_name is not None else len (label_name )
174+ if preload_data is not None :
175+ volume = preload_data
176+ num_vols = len (preload_data )
177+ else :
178+ dir_path = cfg .DATASET .INPUT_PATH
179+ if dir_name_init is not None :
180+ dir_name = dir_name_init
181+ else :
182+ dir_name = _get_file_list (dir_path )
183+
184+ if mode == 'val' :
185+ img_name = cfg .DATASET .VAL_IMAGE_NAME
186+ label_name = cfg .DATASET .VAL_LABEL_NAME
187+ valid_mask_name = cfg .DATASET .VAL_VALID_MASK_NAME
188+ else :
189+ img_name = cfg .DATASET .IMAGE_NAME
190+ label_name = cfg .DATASET .LABEL_NAME
191+ valid_mask_name = cfg .DATASET .VALID_MASK_NAME
192+ assert not all ([elem == None for elem in [img_name , label_name ]]), \
193+ "At least one of img_name and label_name should not be None!"
194+
195+ if img_name_init is not None :
196+ img_name = img_name_init
197+
198+ if img_name is not None :
199+ img_name = _get_file_list (img_name , prefix = dir_path )
200+ img_name = _make_path_list (cfg , dir_name , img_name , rank )
201+ volume = [None ] * len (img_name )
202+ print (rank , len (img_name ), list (map (os .path .basename , img_name )))
203+
204+ if _load_label_condition (label_name , mode , image_only_test ):
205+ label_name = _get_file_list (label_name , prefix = dir_path )
206+ label_name = _make_path_list (cfg , dir_name , label_name , rank )
207+ label = [None ]* len (label_name )
208+ print (rank , len (label_name ), list (map (os .path .basename , label_name )))
209+
210+ if _load_label_condition (valid_mask_name , mode , image_only_test ):
211+ valid_mask_name = _get_file_list (valid_mask_name , prefix = dir_path )
212+ valid_mask_name = _make_path_list (cfg , dir_name , valid_mask_name , rank )
213+ valid_mask = [None ]* len (valid_mask_name )
214+ read_fn = readvol if not cfg .DATASET .LOAD_2D else readimg_as_vol
215+ num_vols = len (img_name ) if img_name is not None else len (label_name )
208216
209217 for i in range (num_vols ):
210218 if volume is not None :
211-
212- volume [i ] = read_fn (img_name [i ], drop_channel = cfg .DATASET .DROP_CHANNEL )
219+ if preload_data is None :
220+ volume [i ] = read_fn (img_name [i ], drop_channel = cfg .DATASET .DROP_CHANNEL )
213221 print (f"volume shape (original): { volume [i ].shape } " )
214222 if cfg .DATASET .NORMALIZE_RANGE :
215223 volume [i ] = normalize_range (volume [i ])
@@ -257,8 +265,7 @@ def get_dataset(cfg,
257265 dataset_options = {},
258266 dir_name_init : Optional [list ] = None ,
259267 img_name_init : Optional [list ] = None ,
260- tensorstore_data = None ,
261- tensorstore_coord : Optional [list ] = None ):
268+ preload_data = None ):
262269 r"""Prepare dataset for training and inference.
263270 """
264271 assert mode in ['train' , 'val' , 'test' ]
@@ -340,14 +347,12 @@ def _make_json_path(path, name):
340347 ** shared_kwargs )
341348
342349 else : # build VolumeDataset or VolumeDatasetMultiSeg
343- if tensorstore_data is None :
350+ if preload_data is None :
344351 volume , label , valid_mask = _get_input (
345352 cfg , mode , rank , dir_name_init , img_name_init , min_size = sample_volume_size )
346353 else :
347- volume = [tensorstore_data [coord [0 ]:coord [1 ],coord [2 ]:coord [3 ],coord [4 ]:coord [5 ]].read ().result ().transpose () \
348- for coord in tensorstore_coord ]
349- label = None
350- valid_mask = None
354+ volume , label , valid_mask = _get_input (
355+ cfg , mode , rank , preload_data = preload_data , min_size = sample_volume_size )
351356
352357 if cfg .MODEL .TARGET_OPT_MULTISEG_SPLIT is not None :
353358 shared_kwargs ['multiseg_split' ] = cfg .MODEL .TARGET_OPT_MULTISEG_SPLIT
0 commit comments