|
5 | 5 | import os |
6 | 6 | import time |
7 | 7 | import math |
| 8 | +import pickle |
8 | 9 | import GPUtil |
9 | 10 | import numpy as np |
10 | 11 | from yacs.config import CfgNode |
|
19 | 20 | from ..data.augmentation import build_train_augmentor, TestAugmentor |
20 | 21 | from ..data.dataset import build_dataloader, get_dataset |
21 | 22 | from ..data.dataset.build import _get_file_list |
22 | | -from ..data.utils import build_blending_matrix, writeh5 |
| 23 | +from ..data.utils import build_blending_matrix, writeh5, read_pkl |
23 | 24 | from ..data.utils import get_padsize, array_unpad |
24 | 25 |
|
25 | 26 |
|
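The diff below relies on the newly imported `read_pkl` helper to load a pickled TensorStore spec. A minimal sketch of such a helper, assuming it unpickles a single object from disk (the actual implementation in `..data.utils` may differ):

```python
import pickle

def read_pkl(filename):
    # Load one pickled object, e.g. a TensorStore spec dict.
    with open(filename, 'rb') as f:
        return pickle.load(f)
```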
@@ -272,32 +273,53 @@ def test(self): |
272 | 273 | writeh5(save_path, result, ['vol%d' % (x) for x in range(len(result))]) |
273 | 274 | print('Prediction saved as: ', save_path) |
274 | 275 |
|
275 | | - def test_singly(self): |
276 | | - dir_name = _get_file_list(self.cfg.DATASET.INPUT_PATH) |
277 | | - assert len(dir_name) == 1 # avoid ambiguity when DO_SINGLY is True |
278 | | - img_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=dir_name[0]) |
279 | | - num_file = len(img_name) |
280 | | - |
281 | | - if os.path.isfile(self.cfg.INFERENCE.OUTPUT_NAME): |
282 | | - output_name = _get_file_list(self.cfg.DATASET.OUTPUT_NAME, prefix=self.output_dir) |
| 276 | + def test_singly(self): |
| 277 | + dir_name = None |
| 278 | + if self.cfg.INFERENCE.TENSORSTORE_PATH is None: |
| 279 | + dir_name = _get_file_list(self.cfg.DATASET.INPUT_PATH) |
| 280 | + assert len(dir_name) == 1 # avoid ambiguity when DO_SINGLY is True |
| 281 | + img_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=dir_name[0]) |
283 | 282 | else: |
284 | | - # same filename but different location |
285 | | - if self.output_dir != dir_name[0]: |
286 | | - output_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=self.output_dir) |
| 283 | + import tensorstore as ts |
| 284 | + context = ts.Context({'cache_pool': {'total_bytes_limit': 1000000000}}) |
| 285 | + ts_dict = read_pkl(self.cfg.INFERENCE.TENSORSTORE_PATH) |
| 286 | + ts_data = ts.open(ts_dict, read=True, context=context).result()[ts.d['channel'][0]] # open the volume, then select channel 0 |
| 287 | + # chunk coordinates, one row per chunk (see the example after the diff) |
| 288 | + img_name = np.loadtxt(self.cfg.DATASET.IMAGE_NAME).astype(int) |
| 289 | + |
| 290 | + num_file = len(img_name) |
| 291 | + |
| 292 | + if os.path.isfile(os.path.join(self.output_dir, self.cfg.INFERENCE.OUTPUT_NAME)): |
| 293 | + # load output names |
| 294 | + output_name = _get_file_list(self.cfg.INFERENCE.OUTPUT_NAME, prefix=self.output_dir) |
| 295 | + else: |
| 296 | + if dir_name is None or self.output_dir != dir_name[0]: |
| 297 | + # same filenames but different location |
| 298 | + if '{' in self.cfg.INFERENCE.OUTPUT_NAME: |
| 299 | + # OUTPUT_NAME is a Python expression template evaluated with the chunk coordinates arr (see the example after the diff) |
| 300 | + output_name = [None] * num_file |
| 301 | + for i in range(num_file): |
| 302 | + arr = img_name[i] # chunk coordinates, referenced inside the evaluated template |
| 303 | + output_name[i] = os.path.join(self.output_dir, eval(self.cfg.INFERENCE.OUTPUT_NAME) + '.h5') |
| 304 | + else: |
| 305 | + output_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=self.output_dir) |
287 | 306 | else: |
| 307 | + # same file location |
288 | 308 | output_name = [x+'_result.h5' for x in img_name] |
289 | 309 |
|
290 | | - # save input image names for future reference |
291 | | - fw = open(os.path.join(self.output_dir, "images.txt"), "w") |
292 | | - fw.write('\n'.join(img_name)) |
293 | | - fw.close() |
294 | | - |
295 | 310 | for i in range(self.cfg.INFERENCE.DO_SINGLY_START_INDEX, num_file, self.cfg.INFERENCE.DO_SINGLY_STEP): |
296 | 311 | self.test_filename = output_name[i] |
297 | 312 | if not os.path.exists(self.test_filename): |
298 | | - dataset = get_dataset( |
299 | | - self.cfg, self.augmentor, self.mode, self.rank, |
300 | | - dir_name_init=dir_name, img_name_init=[img_name[i]]) |
| 313 | + if self.cfg.INFERENCE.TENSORSTORE_PATH is None: |
| 314 | + # directly load from dir_name_init and img_name_init |
| 315 | + dataset = get_dataset( |
| 316 | + self.cfg, self.augmentor, self.mode, self.rank, |
| 317 | + dir_name_init=dir_name, img_name_init=[img_name[i]]) |
| 318 | + else: |
| 319 | + # read the chunk directly from the opened TensorStore volume |
| 320 | + dataset = get_dataset( |
| 321 | + self.cfg, self.augmentor, self.mode, self.rank, |
| 322 | + tensorstore_data=ts_data, tensorstore_coord=[img_name[i]]) |
301 | 323 | self.dataloader = build_dataloader( |
302 | 324 | self.cfg, self.augmentor, self.mode, dataset, self.rank) |
303 | 325 | self.dataloader = iter(self.dataloader) |
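For reference, a sketch of how the inputs consumed by the TensorStore branch above could be produced. The driver, paths, file names, and the six-column coordinate layout are illustrative assumptions, not values fixed by this change:

```python
import pickle
import numpy as np

# TensorStore JSON spec for the input volume (driver and path are placeholders).
spec = {
    'driver': 'neuroglancer_precomputed',
    'kvstore': {'driver': 'file', 'path': '/path/to/precomputed/'},
}
with open('ts_spec.pkl', 'wb') as f:
    pickle.dump(spec, f)  # point INFERENCE.TENSORSTORE_PATH at this file

# One chunk per row; the column layout must match what the dataset expects
# for tensorstore_coord (z0 z1 y0 y1 x0 x1 is assumed here).
coords = np.array([[0, 128, 0, 1024, 0, 1024],
                   [0, 128, 0, 1024, 1024, 2048]])
np.savetxt('chunk_coords.txt', coords, fmt='%d')  # referenced by DATASET.IMAGE_NAME
```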
|
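The template branch (`'{' in OUTPUT_NAME`) evaluates the config string as a Python expression with the current chunk coordinates bound to `arr`. A hypothetical template value, assuming the f-string format shown here rather than any format fixed by the repository:

```python
# Hypothetical config value, evaluated once per chunk with arr = img_name[i].
OUTPUT_NAME = "f'vol_{arr[0]}-{arr[1]}_{arr[2]}-{arr[3]}'"

arr = [0, 128, 0, 1024, 0, 1024]  # example chunk coordinates
print(eval(OUTPUT_NAME) + '.h5')  # -> vol_0-128_0-1024.h5
```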