Skip to content

Commit 394504f

Browse files
author
donglaiw
committed
Integrate tensorstore into the preload data pipeline; otherwise it would not apply the same padding approach.
1 parent b08e8cc commit 394504f

File tree

2 files changed

+67
-55
lines changed

2 files changed

+67
-55
lines changed

connectomics/data/dataset/build.py

Lines changed: 52 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def _load_label_condition(name, mode: str, image_only_test: bool):
146146
def _get_input(cfg,
147147
mode='train',
148148
rank=None,
149+
preload_data=None,
149150
dir_name_init: Optional[list] = None,
150151
img_name_init: Optional[list] = None,
151152
min_size: Optional[tuple] = None,
@@ -162,54 +163,61 @@ def _validate_shape(cfg, image, mask, i):
162163
assert image[i].shape == mask[i].shape[-3:]
163164

164165
assert mode in ['train', 'val', 'test']
165-
dir_path = cfg.DATASET.INPUT_PATH
166-
if dir_name_init is not None:
167-
dir_name = dir_name_init
168-
else:
169-
dir_name = _get_file_list(dir_path)
170-
166+
167+
pad_mode = cfg.DATASET.PAD_MODE
171168
if mode == 'val':
172-
img_name = cfg.DATASET.VAL_IMAGE_NAME
173-
label_name = cfg.DATASET.VAL_LABEL_NAME
174-
valid_mask_name = cfg.DATASET.VAL_VALID_MASK_NAME
175169
pad_size = cfg.DATASET.VAL_PAD_SIZE
176170
else:
177-
img_name = cfg.DATASET.IMAGE_NAME
178-
label_name = cfg.DATASET.LABEL_NAME
179-
valid_mask_name = cfg.DATASET.VALID_MASK_NAME
180-
pad_size = cfg.DATASET.PAD_SIZE
181-
assert not all([elem == None for elem in [img_name, label_name]]), \
182-
"At least one of img_name and label_name should not be None!"
183-
171+
pad_size = cfg.DATASET.PAD_SIZE
184172
volume, label, valid_mask = None, None, None
185-
if img_name_init is not None:
186-
img_name = img_name_init
187-
188-
if img_name is not None:
189-
img_name = _get_file_list(img_name, prefix=dir_path)
190-
img_name = _make_path_list(cfg, dir_name, img_name, rank)
191-
volume = [None] * len(img_name)
192-
print(rank, len(img_name), list(map(os.path.basename, img_name)))
193-
194-
if _load_label_condition(label_name, mode, image_only_test):
195-
label_name = _get_file_list(label_name, prefix=dir_path)
196-
label_name = _make_path_list(cfg, dir_name, label_name, rank)
197-
label = [None]*len(label_name)
198-
print(rank, len(label_name), list(map(os.path.basename, label_name)))
199-
200-
if _load_label_condition(valid_mask_name, mode, image_only_test):
201-
valid_mask_name = _get_file_list(valid_mask_name, prefix=dir_path)
202-
valid_mask_name = _make_path_list(cfg, dir_name, valid_mask_name, rank)
203-
valid_mask = [None]*len(valid_mask_name)
204173

205-
pad_mode = cfg.DATASET.PAD_MODE
206-
read_fn = readvol if not cfg.DATASET.LOAD_2D else readimg_as_vol
207-
num_vols = len(img_name) if img_name is not None else len(label_name)
174+
if preload_data is not None:
175+
volume = preload_data
176+
num_vols = len(preload_data)
177+
else:
178+
dir_path = cfg.DATASET.INPUT_PATH
179+
if dir_name_init is not None:
180+
dir_name = dir_name_init
181+
else:
182+
dir_name = _get_file_list(dir_path)
183+
184+
if mode == 'val':
185+
img_name = cfg.DATASET.VAL_IMAGE_NAME
186+
label_name = cfg.DATASET.VAL_LABEL_NAME
187+
valid_mask_name = cfg.DATASET.VAL_VALID_MASK_NAME
188+
else:
189+
img_name = cfg.DATASET.IMAGE_NAME
190+
label_name = cfg.DATASET.LABEL_NAME
191+
valid_mask_name = cfg.DATASET.VALID_MASK_NAME
192+
assert not all([elem == None for elem in [img_name, label_name]]), \
193+
"At least one of img_name and label_name should not be None!"
194+
195+
if img_name_init is not None:
196+
img_name = img_name_init
197+
198+
if img_name is not None:
199+
img_name = _get_file_list(img_name, prefix=dir_path)
200+
img_name = _make_path_list(cfg, dir_name, img_name, rank)
201+
volume = [None] * len(img_name)
202+
print(rank, len(img_name), list(map(os.path.basename, img_name)))
203+
204+
if _load_label_condition(label_name, mode, image_only_test):
205+
label_name = _get_file_list(label_name, prefix=dir_path)
206+
label_name = _make_path_list(cfg, dir_name, label_name, rank)
207+
label = [None]*len(label_name)
208+
print(rank, len(label_name), list(map(os.path.basename, label_name)))
209+
210+
if _load_label_condition(valid_mask_name, mode, image_only_test):
211+
valid_mask_name = _get_file_list(valid_mask_name, prefix=dir_path)
212+
valid_mask_name = _make_path_list(cfg, dir_name, valid_mask_name, rank)
213+
valid_mask = [None]*len(valid_mask_name)
214+
read_fn = readvol if not cfg.DATASET.LOAD_2D else readimg_as_vol
215+
num_vols = len(img_name) if img_name is not None else len(label_name)
208216

209217
for i in range(num_vols):
210218
if volume is not None:
211-
212-
volume[i] = read_fn(img_name[i], drop_channel=cfg.DATASET.DROP_CHANNEL)
219+
if preload_data is None:
220+
volume[i] = read_fn(img_name[i], drop_channel=cfg.DATASET.DROP_CHANNEL)
213221
print(f"volume shape (original): {volume[i].shape}")
214222
if cfg.DATASET.NORMALIZE_RANGE:
215223
volume[i] = normalize_range(volume[i])
@@ -257,8 +265,7 @@ def get_dataset(cfg,
257265
dataset_options={},
258266
dir_name_init: Optional[list] = None,
259267
img_name_init: Optional[list] = None,
260-
tensorstore_data = None,
261-
tensorstore_coord: Optional[list] = None):
268+
preload_data = None):
262269
r"""Prepare dataset for training and inference.
263270
"""
264271
assert mode in ['train', 'val', 'test']
@@ -340,14 +347,12 @@ def _make_json_path(path, name):
340347
**shared_kwargs)
341348

342349
else: # build VolumeDataset or VolumeDatasetMultiSeg
343-
if tensorstore_data is None:
350+
if preload_data is None:
344351
volume, label, valid_mask = _get_input(
345352
cfg, mode, rank, dir_name_init, img_name_init, min_size=sample_volume_size)
346353
else:
347-
volume = [tensorstore_data[coord[0]:coord[1],coord[2]:coord[3],coord[4]:coord[5]].read().result().transpose() \
348-
for coord in tensorstore_coord]
349-
label = None
350-
valid_mask = None
354+
volume, label, valid_mask = _get_input(
355+
cfg, mode, rank, preload_data=preload_data, min_size=sample_volume_size)
351356

352357
if cfg.MODEL.TARGET_OPT_MULTISEG_SPLIT is not None:
353358
shared_kwargs['multiseg_split'] = cfg.MODEL.TARGET_OPT_MULTISEG_SPLIT

connectomics/engine/trainer.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -271,24 +271,26 @@ def test(self):
271271
print('Final prediction shapes are:')
272272
for k in range(len(result)):
273273
print(result[k].shape)
274+
if not os.path.exists(self.output_dir):
275+
os.makedirs(self.output_dir)
274276
save_path = os.path.join(self.output_dir, self.test_filename)
275277
writeh5(save_path, result, ['vol%d' % (x) for x in range(len(result))])
276278
print('Prediction saved as: ', save_path)
277279

278280
def test_singly(self):
279281
dir_name = None
280-
if self.cfg.INFERENCE.TENSORSTORE_PATH is None:
281-
dir_name = _get_file_list(self.cfg.DATASET.INPUT_PATH)
282-
assert len(dir_name) == 1 # avoid ambiguity when DO_SINGLY is True
283-
img_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=dir_name[0])
284-
else:
282+
if self.cfg.INFERENCE.TENSORSTORE_PATH is not None:
285283
import tensorstore as ts
286284
context = ts.Context({'cache_pool': {'total_bytes_limit': 1000000000}})
287285
ts_dict = read_pkl(self.cfg.INFERENCE.TENSORSTORE_PATH)
288286
ts_data = ts.open(ts_dict, read=True, context=context).result()[ts.d['channel'][0]]
289287
# chunk coordinate
290288
img_name = np.loadtxt(self.cfg.DATASET.IMAGE_NAME).astype(int)
291-
289+
else:
290+
dir_name = _get_file_list(self.cfg.DATASET.INPUT_PATH)
291+
assert len(dir_name) == 1 # avoid ambiguity when DO_SINGLY is True
292+
img_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=dir_name[0])
293+
292294
num_file = len(img_name)
293295

294296
if os.path.isfile(os.path.join(self.output_dir, self.cfg.INFERENCE.OUTPUT_NAME)):
@@ -312,16 +314,21 @@ def test_singly(self):
312314
for i in range(self.cfg.INFERENCE.DO_SINGLY_START_INDEX, num_file, self.cfg.INFERENCE.DO_SINGLY_STEP):
313315
self.test_filename = output_name[i]
314316
if not os.path.exists(self.test_filename):
315-
if self.cfg.INFERENCE.TENSORSTORE_PATH is None:
317+
if dir_name is not None:
316318
# directly load from dir_name_init and img_name_init
317319
dataset = get_dataset(
318320
self.cfg, self.augmentor, self.mode, self.rank,
319321
dir_name_init=dir_name, img_name_init=[img_name[i]])
320322
else:
323+
if self.cfg.INFERENCE.TENSORSTORE_PATH is not None:
324+
coord = img_name[i]
325+
preload_data = [ts_data[coord[0]:coord[1],coord[2]:coord[3],coord[4]:coord[5]].read().result().transpose()]
326+
321327
# preload from tensorstore
322328
dataset = get_dataset(
323329
self.cfg, self.augmentor, self.mode, self.rank,
324-
tensorstore_data=ts_data, tensorstore_coord=[img_name[i]])
330+
preload_data=preload_data)
331+
325332
self.dataloader = build_dataloader(
326333
self.cfg, self.augmentor, self.mode, dataset, self.rank)
327334
self.dataloader = iter(self.dataloader)

0 commit comments

Comments (0)