Skip to content

Commit 61f1d07

Browse files
authored
Merge pull request #139 from zengyuy/master
changed UNETR and SwinUNETR configs; adjust SwinUNETR code
2 parents 1908d23 + bc894f9 commit 61f1d07

File tree

4 files changed

+17
-14
lines changed

4 files changed

+17
-14
lines changed

configs/SNEMI/SNEMI-Affinity-SwinUNETR.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ MODEL:
1414
USE_CHECKPOINT: False
1515
SPATIAL_DIMS: 3
1616
DOWNSAMPLE: 'merging'
17-
USE_V2: False
17+
USE_V2: True
1818
DATASET:
1919
OUTPUT_PATH: outputs/SNEMI_SwinUNETR
2020
SOLVER:

configs/SNEMI/SNEMI-Affinity-UNETR.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ MODEL:
88
UNETR_NUM_HEADS: 12
99
POS_EMBED: 'conv'
1010
NORM_NAME: 'instance'
11-
CONV_BLOCK: True
11+
CONV_BLOCK: False
1212
RES_BLOCK: True
1313
UNETR_DROPOUT_RATE: 0.0
1414
DATASET:

connectomics/config/defaults.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@
227227
_C.DATASET.REDUCE_LABEL = True
228228

229229
# Padding size for the input volumes
230+
# Due to the center crop in the data augmentor, regions close to the volume
231+
# border will never be sampled. Therefore we pad the input volume. For
232+
# large-scale dataset there is no need for padding.
230233
_C.DATASET.PAD_SIZE = [2, 64, 64]
231234
_C.DATASET.PAD_MODE = 'reflect' # reflect, constant, symmetric
232235

connectomics/model/arch/swinunetr.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ def __init__(
146146

147147
self.encoder3 = UnetrBasicBlock(
148148
spatial_dims=spatial_dims,
149-
in_channels=2 * feature_size,
150-
out_channels=2 * feature_size,
149+
in_channels=feature_size * 2,
150+
out_channels=feature_size * 2,
151151
kernel_size=3,
152152
stride=1,
153153
norm_name=norm_name,
@@ -156,8 +156,8 @@ def __init__(
156156

157157
self.encoder4 = UnetrBasicBlock(
158158
spatial_dims=spatial_dims,
159-
in_channels=4 * feature_size,
160-
out_channels=4 * feature_size,
159+
in_channels=feature_size * 4,
160+
out_channels=feature_size * 4,
161161
kernel_size=3,
162162
stride=1,
163163
norm_name=norm_name,
@@ -166,8 +166,8 @@ def __init__(
166166

167167
self.encoder5 = UnetrBasicBlock(
168168
spatial_dims=spatial_dims,
169-
in_channels=8 * feature_size,
170-
out_channels=8 * feature_size,
169+
in_channels=feature_size * 8,
170+
out_channels=feature_size * 8,
171171
kernel_size=3,
172172
stride=1,
173173
norm_name=norm_name,
@@ -176,8 +176,8 @@ def __init__(
176176

177177
self.encoder6 = UnetrBasicBlock(
178178
spatial_dims=spatial_dims,
179-
in_channels=16 * feature_size,
180-
out_channels=16 * feature_size,
179+
in_channels=feature_size * 16,
180+
out_channels=feature_size * 16,
181181
kernel_size=3,
182182
stride=1,
183183
norm_name=norm_name,
@@ -186,8 +186,8 @@ def __init__(
186186

187187
self.decoder5 = UnetrUpBlock(
188188
spatial_dims=spatial_dims,
189-
in_channels=16 * feature_size,
190-
out_channels=8 * feature_size,
189+
in_channels=feature_size * 16,
190+
out_channels=feature_size * 8,
191191
kernel_size=3,
192192
upsample_kernel_size=up_kernel_size[4],
193193
norm_name=norm_name,
@@ -495,7 +495,7 @@ def forward(self, x, mask):
495495
else:
496496
attn = self.softmax(attn)
497497

498-
attn = self.attn_drop(attn)
498+
attn = self.attn_drop(attn).to(v.dtype)
499499
x = (attn @ v).transpose(1, 2).reshape(b, n, c)
500500
x = self.proj(x)
501501
x = self.proj_drop(x)
@@ -1005,7 +1005,7 @@ def __init__(
10051005
self.layers4.append(layer)
10061006
if self.use_v2:
10071007
layerc = UnetrBasicBlock(
1008-
spatial_dims=3,
1008+
spatial_dims=spatial_dims,
10091009
in_channels=embed_dim * 2**i_layer,
10101010
out_channels=embed_dim * 2**i_layer,
10111011
kernel_size=3,

0 commit comments

Comments
 (0)