Skip to content

Commit e2d4d81

Browse files
author
Donglai Wei
committed
remove get_valid_pos due to memory issue
1 parent f03acf7 commit e2d4d81

File tree

1 file changed

+30
-7
lines changed

1 file changed

+30
-7
lines changed

connectomics/data/dataset/dataset_volume.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional, List
22
import numpy as np
33
import random
4+
import warnings
45

56
import torch
67
import torch.utils.data
@@ -68,6 +69,7 @@ def __init__(self,
6869
do_relabel: bool = True,
6970
# rejection sampling
7071
reject_size_thres: int = 0,
72+
reject_num_trial: int = 50,
7173
reject_diversity: int = 0,
7274
reject_p: float = 0.95,
7375
# normalization
@@ -98,6 +100,7 @@ def __init__(self,
98100
# rejection samping
99101
self.reject_size_thres = reject_size_thres
100102
self.reject_diversity = reject_diversity
103+
self.reject_num_trial = reject_num_trial
101104
self.reject_p = reject_p
102105

103106
# normalization
@@ -113,6 +116,17 @@ def __init__(self,
113116
assert len(set(x[0] for x in volume_size)) == 1, "All volumes should have the same number of channels"
114117
self.volume_size = [x[-3:] for x in volume_size]
115118

119+
volume_selection = [(sample_label_size <= x).all() for x in self.volume_size]
120+
if not all(volume_selection):
121+
print('remove volumes whose sizes are smaller than the model input', volume_selection)
122+
self.volume = [x for i,x in enumerate(self.volume) if volume_selection[i]]
123+
volume_size = [np.array(x.shape) for x in self.volume]
124+
self.volume_size = [x[-3:] for x in volume_size]
125+
if self.label is not None:
126+
self.label = [x for i,x in enumerate(self.label) if volume_selection[i]]
127+
if valid_mask is not None:
128+
valid_mask = [x for i,x in enumerate(valid_mask) if volume_selection[i]]
129+
116130
self.sample_volume_size = np.array(
117131
sample_volume_size).astype(int) # model input size
118132
if self.label is not None:
@@ -122,7 +136,7 @@ def __init__(self,
122136
if self.augmentor is not None:
123137
assert np.array_equal(
124138
self.augmentor.sample_size, self.sample_label_size)
125-
self._assert_valid_shape()
139+
#self._assert_valid_shape()
126140

127141
# compute number of samples for each dataset (multi-volume input)
128142
self.sample_stride = np.array(sample_stride).astype(int)
@@ -138,14 +152,18 @@ def __init__(self,
138152
self.valid_mask = valid_mask
139153
self.valid_ratio = valid_ratio
140154
# precompute valid region
155+
# can be memory intensive
141156
self.valid_pos = [None] * len(self.valid_mask)
157+
"""
142158
if self.valid_mask is not None:
143159
for i, x in enumerate(self.valid_mask):
144160
if x is not None:
145161
self.valid_pos[i] = get_valid_pos(x, sample_volume_size, valid_ratio)
146162
self.sample_num[i] = self.valid_pos[i].shape[0]
163+
print(i, self.sample_num[i])
147164
self.sample_num_a = np.sum(self.sample_num)
148165
self.sample_num_c = np.cumsum([0] + list(self.sample_num))
166+
"""
149167

150168
if self.mode in ['val', 'test']: # for validation and test
151169
self.sample_size_test = [
@@ -240,17 +258,17 @@ def _get_pos_train(self, vol_size):
240258
# np.random: same seed
241259
pos = [0, 0, 0, 0]
242260
# pick a dataset
243-
did = self._index_to_dataset(random.randint(0, self.sample_num_a))
261+
did = self._index_to_dataset(random.randint(0, self.sample_num_a - 1))
244262
pos[0] = did
245263
# pick a position
246264
# all regions are valid
247265
if self.valid_pos[did] is None:
248266
tmp_size = count_volume(
249267
self.volume_size[did], vol_size, self.sample_stride)
250-
tmp_pos = [random.randint(0, tmp_size[x]) * self.sample_stride[x]
268+
tmp_pos = [random.randint(0, tmp_size[x] - 1) * self.sample_stride[x]
251269
for x in range(len(tmp_size))]
252270
else:
253-
tmp_pos = self.valid_pos[did][random.randint(0, self.valid_pos[did].shape[0])]
271+
tmp_pos = self.valid_pos[did][random.randint(0, self.valid_pos[did].shape[0]) - 1]
254272

255273
pos[1:] = tmp_pos
256274
return pos
@@ -282,16 +300,21 @@ def _rejection_sampling(self, vol_size):
282300
out_valid = augmented['valid_mask']
283301

284302
if self._is_valid(out_valid) and self._is_fg(out_label):
303+
#print('yes', sample_count)
285304
return pos, out_volume, out_label, out_valid
286305

287306
sample_count += 1
288-
if sample_count > 100:
307+
if sample_count > self.reject_num_trial:
289308
err_msg = (
290309
"Can not find any valid subvolume after sampling the "
291-
"dataset for more than 100 times. Please adjust the "
310+
f"dataset for more than {self.reject_num_trial} times. Please adjust the "
292311
"valid mask or rejection sampling configurations."
293312
)
294-
raise RuntimeError(err_msg)
313+
#raise RuntimeError(err_msg)
314+
# return anyway with a useless sample
315+
warnings.warn(err_msg)
316+
#print('no..')
317+
return pos, out_volume, out_label, out_valid
295318

296319
def _random_sampling(self, vol_size):
297320
"""Randomly sample a subvolume from all the volumes.

0 commit comments

Comments
 (0)