@@ -146,8 +146,8 @@ def __init__(
146146
147147 self .encoder3 = UnetrBasicBlock (
148148 spatial_dims = spatial_dims ,
149- in_channels = 2 * feature_size ,
150- out_channels = 2 * feature_size ,
149+ in_channels = feature_size * 2 ,
150+ out_channels = feature_size * 2 ,
151151 kernel_size = 3 ,
152152 stride = 1 ,
153153 norm_name = norm_name ,
@@ -156,8 +156,8 @@ def __init__(
156156
157157 self .encoder4 = UnetrBasicBlock (
158158 spatial_dims = spatial_dims ,
159- in_channels = 4 * feature_size ,
160- out_channels = 4 * feature_size ,
159+ in_channels = feature_size * 4 ,
160+ out_channels = feature_size * 4 ,
161161 kernel_size = 3 ,
162162 stride = 1 ,
163163 norm_name = norm_name ,
@@ -166,8 +166,8 @@ def __init__(
166166
167167 self .encoder5 = UnetrBasicBlock (
168168 spatial_dims = spatial_dims ,
169- in_channels = 8 * feature_size ,
170- out_channels = 8 * feature_size ,
169+ in_channels = feature_size * 8 ,
170+ out_channels = feature_size * 8 ,
171171 kernel_size = 3 ,
172172 stride = 1 ,
173173 norm_name = norm_name ,
@@ -176,8 +176,8 @@ def __init__(
176176
177177 self .encoder6 = UnetrBasicBlock (
178178 spatial_dims = spatial_dims ,
179- in_channels = 16 * feature_size ,
180- out_channels = 16 * feature_size ,
179+ in_channels = feature_size * 16 ,
180+ out_channels = feature_size * 16 ,
181181 kernel_size = 3 ,
182182 stride = 1 ,
183183 norm_name = norm_name ,
@@ -186,8 +186,8 @@ def __init__(
186186
187187 self .decoder5 = UnetrUpBlock (
188188 spatial_dims = spatial_dims ,
189- in_channels = 16 * feature_size ,
190- out_channels = 8 * feature_size ,
189+ in_channels = feature_size * 16 ,
190+ out_channels = feature_size * 8 ,
191191 kernel_size = 3 ,
192192 upsample_kernel_size = up_kernel_size [4 ],
193193 norm_name = norm_name ,
@@ -495,7 +495,7 @@ def forward(self, x, mask):
495495 else :
496496 attn = self .softmax (attn )
497497
498- attn = self .attn_drop (attn )
498+ attn = self .attn_drop (attn ). to ( v . dtype )
499499 x = (attn @ v ).transpose (1 , 2 ).reshape (b , n , c )
500500 x = self .proj (x )
501501 x = self .proj_drop (x )
@@ -1005,7 +1005,7 @@ def __init__(
10051005 self .layers4 .append (layer )
10061006 if self .use_v2 :
10071007 layerc = UnetrBasicBlock (
1008- spatial_dims = 3 ,
1008+ spatial_dims = spatial_dims ,
10091009 in_channels = embed_dim * 2 ** i_layer ,
10101010 out_channels = embed_dim * 2 ** i_layer ,
10111011 kernel_size = 3 ,
0 commit comments