Commit 0d763b3 (parent: 9efa683)
Author: donglaiw

    add support for tensorstore

4 files changed, 74 insertions(+), 23 deletions(-)

connectomics/config/defaults.py

Lines changed: 1 addition & 0 deletions

@@ -474,6 +474,7 @@
 _C.INFERENCE.INPUT_SIZE = None
 _C.INFERENCE.OUTPUT_SIZE = None
 
+_C.INFERENCE.TENSORSTORE_PATH = None
 _C.INFERENCE.INPUT_PATH = None
 _C.INFERENCE.IMAGE_NAME = None
 _C.INFERENCE.OUTPUT_PATH = ""
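Note: the new key defaults to None, so the existing file-based inference path is untouched unless a pickled tensorstore spec is supplied. A minimal yacs sketch of flipping it on (the spec path is a placeholder, not from this commit):

from yacs.config import CfgNode as CN

# Mirror of the new default: None keeps file-based loading.
cfg = CN()
cfg.INFERENCE = CN()
cfg.INFERENCE.TENSORSTORE_PATH = None

# Enable the tensorstore branch by pointing at a pickled spec;
# yacs permits overriding a None default with a string.
cfg.merge_from_list(['INFERENCE.TENSORSTORE_PATH', 'outputs/ts_spec.pkl'])
print(cfg.INFERENCE.TENSORSTORE_PATH)  # outputs/ts_spec.pkl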

connectomics/data/dataset/build.py

Lines changed: 12 additions & 3 deletions

@@ -208,6 +208,7 @@ def _validate_shape(cfg, image, mask, i):
 
     for i in range(num_vols):
         if volume is not None:
+
             volume[i] = read_fn(img_name[i], drop_channel=cfg.DATASET.DROP_CHANNEL)
             print(f"volume shape (original): {volume[i].shape}")
             if cfg.DATASET.NORMALIZE_RANGE:
@@ -255,7 +256,9 @@ def get_dataset(cfg,
                 dataset_class=VolumeDataset,
                 dataset_options={},
                 dir_name_init: Optional[list] = None,
-                img_name_init: Optional[list] = None):
+                img_name_init: Optional[list] = None,
+                tensorstore_data=None,
+                tensorstore_coord: Optional[list] = None):
     r"""Prepare dataset for training and inference.
     """
     assert mode in ['train', 'val', 'test']

@@ -337,8 +340,14 @@ def _make_json_path(path, name):
                 **shared_kwargs)
 
     else: # build VolumeDataset or VolumeDatasetMultiSeg
-        volume, label, valid_mask = _get_input(
-            cfg, mode, rank, dir_name_init, img_name_init, min_size=sample_volume_size)
+        if tensorstore_data is None:
+            volume, label, valid_mask = _get_input(
+                cfg, mode, rank, dir_name_init, img_name_init, min_size=sample_volume_size)
+        else:
+            volume = [tensorstore_data[coord[0]:coord[1], coord[2]:coord[3], coord[4]:coord[5]].read().result().transpose()
+                      for coord in tensorstore_coord]
+            label = None
+            valid_mask = None
 
     if cfg.MODEL.TARGET_OPT_MULTISEG_SPLIT is not None:
         shared_kwargs['multiseg_split'] = cfg.MODEL.TARGET_OPT_MULTISEG_SPLIT
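For context on the new branch: indexing a tensorstore handle is lazy, .read().result() materializes the slab as a NumPy array, and .transpose() reverses the axis order. A self-contained sketch of the same slicing pattern against an in-memory zarr store (the spec, shape, and coordinate values below are illustrative assumptions, not part of this commit):

import tensorstore as ts

# Hypothetical in-memory store standing in for the pickled spec.
ts_data = ts.open({
    'driver': 'zarr',
    'kvstore': {'driver': 'memory'},
    'metadata': {'shape': [512, 512, 64], 'chunks': [128, 128, 16],
                 'dtype': '|u1'},
}, create=True).result()

# One coordinate row, read here as (x0, x1, y0, y1, z0, z1) slices;
# .transpose() reverses the axes so the chunk comes out (z, y, x).
coord = [0, 128, 0, 128, 0, 16]
chunk = ts_data[coord[0]:coord[1], coord[2]:coord[3],
                coord[4]:coord[5]].read().result().transpose()
print(chunk.shape)  # (16, 128, 128)

The transpose suggests the store is laid out (x, y, z) while the downstream datasets expect (z, y, x) volumes, though the commit does not state this explicitly.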

connectomics/data/utils/data_io.py

Lines changed: 19 additions & 0 deletions

@@ -11,6 +11,7 @@
 import glob
 import numpy as np
 import imageio
+import pickle
 from scipy.ndimage import zoom
 
 

@@ -110,6 +111,24 @@ def readimgs(filename):
 
     return data
 
+def read_pkl(filename):
+    """Read a pickle file and return the object(s) stored in it.
+
+    :param filename: name of the file to read, including the extension
+        (e.g., ".pkl" for a pickle file).
+    :return: a list of the objects read from the pickle file, or the
+        single object itself if the file contains only one.
+    """
+    data = []
+    with open(filename, "rb") as fid:
+        while True:
+            try:
+                data.append(pickle.load(fid))
+            except EOFError:
+                break
+    if len(data) == 1:
+        return data[0]
+    return data
 
 def writeh5(filename, dtarray, dataset='main'):
     fid = h5py.File(filename, 'w')
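read_pkl handles files holding several pickled objects written back to back (repeated pickle.dump calls on one handle) and unwraps single-object files. A short usage sketch (the file name and contents are placeholders):

import pickle

from connectomics.data.utils import read_pkl

with open('demo.pkl', 'wb') as fid:
    pickle.dump({'driver': 'zarr'}, fid)       # e.g., a tensorstore spec
    pickle.dump([0, 16, 0, 128, 0, 128], fid)  # e.g., a coordinate row

print(read_pkl('demo.pkl'))  # [{'driver': 'zarr'}, [0, 16, 0, 128, 0, 128]]
# A file written with a single pickle.dump comes back as the object itself.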

connectomics/engine/trainer.py

Lines changed: 42 additions & 20 deletions

@@ -5,6 +5,7 @@
 import os
 import time
 import math
+import pickle
 import GPUtil
 import numpy as np
 from yacs.config import CfgNode

@@ -19,7 +20,7 @@
 from ..data.augmentation import build_train_augmentor, TestAugmentor
 from ..data.dataset import build_dataloader, get_dataset
 from ..data.dataset.build import _get_file_list
-from ..data.utils import build_blending_matrix, writeh5
+from ..data.utils import build_blending_matrix, writeh5, read_pkl
 from ..data.utils import get_padsize, array_unpad
 
 

@@ -272,32 +273,53 @@ def test(self):
         writeh5(save_path, result, ['vol%d' % (x) for x in range(len(result))])
         print('Prediction saved as: ', save_path)
 
-    def test_singly(self):
-        dir_name = _get_file_list(self.cfg.DATASET.INPUT_PATH)
-        assert len(dir_name) == 1 # avoid ambiguity when DO_SINGLY is True
-        img_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=dir_name[0])
-        num_file = len(img_name)
-
-        if os.path.isfile(self.cfg.INFERENCE.OUTPUT_NAME):
-            output_name = _get_file_list(self.cfg.DATASET.OUTPUT_NAME, prefix=self.output_dir)
+    def test_singly(self):
+        dir_name = None
+        if self.cfg.INFERENCE.TENSORSTORE_PATH is None:
+            dir_name = _get_file_list(self.cfg.DATASET.INPUT_PATH)
+            assert len(dir_name) == 1 # avoid ambiguity when DO_SINGLY is True
+            img_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=dir_name[0])
         else:
-            # same filename but different location
-            if self.output_dir != dir_name[0]:
-                output_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=self.output_dir)
+            import tensorstore as ts
+            context = ts.Context({'cache_pool': {'total_bytes_limit': 1000000000}})
+            ts_dict = read_pkl(self.cfg.INFERENCE.TENSORSTORE_PATH)
+            ts_data = ts.open(ts_dict, read=True, context=context).result()[ts.d['channel'][0]]
+            # chunk coordinate
+            img_name = np.loadtxt(self.cfg.DATASET.IMAGE_NAME).astype(int)
+
+        num_file = len(img_name)
+
+        if os.path.isfile(os.path.join(self.output_dir, self.cfg.INFERENCE.OUTPUT_NAME)):
+            # load output names
+            output_name = _get_file_list(self.cfg.INFERENCE.OUTPUT_NAME, prefix=self.output_dir)
+        else:
+            if dir_name is None or self.output_dir != dir_name[0]:
+                # same filenames but different location
+                if '{' in self.cfg.INFERENCE.OUTPUT_NAME:
+                    # template function
+                    output_name = [None] * num_file
+                    for i in range(num_file):
+                        arr = img_name[i]
+                        output_name[i] = os.path.join(self.output_dir, eval(self.cfg.INFERENCE.OUTPUT_NAME)+'.h5')
+                else:
+                    output_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=self.output_dir)
             else:
+                # same file location
                 output_name = [x+'_result.h5' for x in img_name]
 
-        # save input image names for future reference
-        fw = open(os.path.join(self.output_dir, "images.txt"), "w")
-        fw.write('\n'.join(img_name))
-        fw.close()
-
         for i in range(self.cfg.INFERENCE.DO_SINGLY_START_INDEX, num_file, self.cfg.INFERENCE.DO_SINGLY_STEP):
             self.test_filename = output_name[i]
             if not os.path.exists(self.test_filename):
-                dataset = get_dataset(
-                    self.cfg, self.augmentor, self.mode, self.rank,
-                    dir_name_init=dir_name, img_name_init=[img_name[i]])
+                if self.cfg.INFERENCE.TENSORSTORE_PATH is None:
+                    # directly load from dir_name_init and img_name_init
+                    dataset = get_dataset(
+                        self.cfg, self.augmentor, self.mode, self.rank,
+                        dir_name_init=dir_name, img_name_init=[img_name[i]])
+                else:
+                    # preload from tensorstore
+                    dataset = get_dataset(
+                        self.cfg, self.augmentor, self.mode, self.rank,
+                        tensorstore_data=ts_data, tensorstore_coord=[img_name[i]])
                 self.dataloader = build_dataloader(
                     self.cfg, self.augmentor, self.mode, dataset, self.rank)
                 self.dataloader = iter(self.dataloader)
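With INFERENCE.TENSORSTORE_PATH set, test_singly expects two inputs: a pickled spec that ts.open(..., read=True) accepts (the ts.d['channel'][0] indexing implies a store with a channel dimension, such as a neuroglancer_precomputed volume), and a plain-text table at DATASET.IMAGE_NAME that np.loadtxt can parse, six integers per row. A hedged sketch of preparing both; the driver, bucket, and coordinate values are assumptions, not from this commit:

import pickle
import numpy as np

# Hypothetical spec; any dict accepted by ts.open(spec, read=True) works.
spec = {
    'driver': 'neuroglancer_precomputed',
    'kvstore': {'driver': 'gcs', 'bucket': 'my-bucket'},  # placeholder
    'path': 'image',
}
with open('ts_spec.pkl', 'wb') as fid:
    pickle.dump(spec, fid)

# One row per chunk; row c is read as data[c[0]:c[1], c[2]:c[3], c[4]:c[5]]
# and transposed after reading.
coords = np.array([[0, 128, 0, 128, 0, 16],
                   [128, 256, 0, 128, 0, 16]])
np.savetxt('chunk_coord.txt', coords, fmt='%d')

On output naming: when OUTPUT_NAME contains '{' it is eval'd with arr bound to the current coordinate row, so a user-supplied template such as "f'out-{arr[0]}-{arr[2]}-{arr[4]}'" would yield one .h5 file per chunk; the exact template is left to the config.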
