@@ -46,6 +46,7 @@ def __init__(
         norm_layer: nn.Module = nn.BatchNorm2d,
         act_layer: nn.Module = nn.Identity(),
     ) -> None:
+        super().__init__()
         self.dim = dim
 
         self.conv_q = ConvNormAct(
@@ -98,7 +99,7 @@ class Attention(nn.Module):
     def __init__(
         self,
         dim: int,
-        num_heads: int = 8,
+        num_heads: int = 1,
         qkv_bias: bool = True,
         qk_norm: bool = False,
         attn_drop: float = 0.,
@@ -159,7 +160,7 @@ def __init__(
         conv_bias: bool = False,
         conv_norm_layer: nn.Module = nn.BatchNorm2d,
         conv_act_layer: nn.Module = nn.Identity(),
-        num_heads: int = 8,
+        num_heads: int = 1,
         qkv_bias: bool = True,
         qk_norm: bool = False,
         attn_drop: float = 0.,
@@ -173,6 +174,7 @@ def __init__(
         mlp_act_layer: nn.Module = QuickGELU,
         use_cls_token: bool = False,
     ) -> None:
+        super().__init__()
         self.use_cls_token = use_cls_token
 
         self.norm1 = norm_layer(dim)
@@ -242,28 +244,30 @@ def __init__(
         depth: int,
         embed_kernel_size: int = 7,
         embed_stride: int = 4,
-        embed_padding: int 2,
+        embed_padding: int = 2,
         kernel_size: int = 3,
         stride_q: int = 1,
         stride_kv: int = 2,
         padding: int = 1,
         conv_bias: bool = False,
         conv_norm_layer: nn.Module = nn.BatchNorm2d,
         conv_act_layer: nn.Module = nn.Identity(),
-        num_heads: int = 8,
+        num_heads: int = 1,
         qkv_bias: bool = True,
         qk_norm: bool = False,
         attn_drop: float = 0.,
         proj_drop: float = 0.,
         input_norm_layer=LayerNorm2d,
         norm_layer: nn.Module = nn.LayerNorm,
         init_values: Optional[float] = None,
-        drop_path: float = 0.,
+        drop_path_rates: List[float] = [0.],
         mlp_layer: nn.Module = Mlp,
         mlp_ratio: float = 4.,
         mlp_act_layer: nn.Module = QuickGELU,
         use_cls_token: bool = False,
     ) -> None:
+        super().__init__()
+
         self.conv_embed = ConvEmbed(
             in_chs=in_chs,
             out_chs=dim,
@@ -278,31 +282,30 @@ def __init__(
 
         blocks = []
         for i in range(depth):
-            blocks.append(
-                CvTBlock(
-                    dim=dim,
-                    kernel_size=kernel_size,
-                    stride_q=stride_q,
-                    stride_kv=stride_kv,
-                    padding=padding,
-                    conv_bias=conv_bias,
-                    conv_norm_layer=conv_norm_layer,
-                    conv_act_layer=conv_act_layer,
-                    num_heads=num_heads,
-                    qkv_bias=qkv_bias,
-                    qk_norm=qk_norm,
-                    attn_drop=attn_drop,
-                    proj_drop=proj_drop,
-                    input_norm_layer=input_norm_layer,
-                    norm_layer=norm_layer,
-                    init_values=init_values,
-                    drop_path=drop_path,
-                    mlp_layer=mlp_layer,
-                    mlp_ratio=mlp_ratio,
-                    mlp_act_layer=mlp_act_layer,
-                    use_cls_token=use_cls_token,
-                )
+            block = CvTBlock(
+                dim=dim,
+                kernel_size=kernel_size,
+                stride_q=stride_q,
+                stride_kv=stride_kv,
+                padding=padding,
+                conv_bias=conv_bias,
+                conv_norm_layer=conv_norm_layer,
+                conv_act_layer=conv_act_layer,
+                num_heads=num_heads,
+                qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
+                attn_drop=attn_drop,
+                proj_drop=proj_drop,
+                input_norm_layer=input_norm_layer,
+                norm_layer=norm_layer,
+                init_values=init_values,
+                drop_path=drop_path_rates[i],
+                mlp_layer=mlp_layer,
+                mlp_ratio=mlp_ratio,
+                mlp_act_layer=mlp_act_layer,
+                use_cls_token=use_cls_token,
             )
+            blocks.append(block)
         self.blocks = nn.ModuleList(blocks)
 
         if self.cls_token is not None:
@@ -313,10 +316,83 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.embed_drop(x)
 
         cls_token = self.cls_token
-        for block in self.blocks:
+        for block in self.blocks:  # nn.Sequential's untyped intermediate results could technically be exploited here if each block took in a single tensor
             x, cls_token = block(x, cls_token)
 
         return x, cls_token
 
 class CvT(nn.Module):
-
+    def __init__(
+        self,
+        in_chans: int = 3,
+        num_classes: int = 1000,
+        dims: Tuple[int, ...] = (64, 192, 384),
+        depths: Tuple[int, ...] = (1, 2, 10),
+        embed_kernel_size: Tuple[int, ...] = (7, 3, 3),
+        embed_stride: Tuple[int, ...] = (4, 2, 2),
+        embed_padding: Tuple[int, ...] = (2, 1, 1),
+        kernel_size: int = 3,
+        stride_q: int = 1,
+        stride_kv: int = 2,
+        padding: int = 1,
+        conv_bias: bool = False,
+        conv_norm_layer: nn.Module = nn.BatchNorm2d,
+        conv_act_layer: nn.Module = nn.Identity(),
+        num_heads: Tuple[int, ...] = (1, 3, 6),
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
+        attn_drop: float = 0.,
+        proj_drop: float = 0.,
+        drop_rate: float = 0.,
+        input_norm_layer=LayerNorm2d,
+        norm_layer: nn.Module = nn.LayerNorm,
+        init_values: Optional[float] = None,
+        drop_path_rate: float = 0.,
+        mlp_layer: nn.Module = Mlp,
+        mlp_ratio: float = 4.,
+        mlp_act_layer: nn.Module = QuickGELU,
+        use_cls_token: Tuple[bool, ...] = (False, False, True),
+    ) -> None:
+        super().__init__()
+        num_stages = len(dims)
+        assert num_stages == len(depths) == len(embed_kernel_size) == len(embed_stride)
+        assert num_stages == len(embed_padding) == len(num_heads) == len(use_cls_token)
+        self.num_classes = num_classes
+        self.num_features = dims[-1]
+        self.drop_rate = drop_rate
+
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+
+        in_chs = in_chans
+
+        stages = []
+        for stage_idx in range(num_stages):
+            dim = dims[stage_idx]
+            stage = CvTStage(
+                in_chs=in_chs,
+                dim=dim,
+                depth=depths[stage_idx],
+                embed_kernel_size=embed_kernel_size[stage_idx],
+                embed_stride=embed_stride[stage_idx],
+                embed_padding=embed_padding[stage_idx],
+                kernel_size=kernel_size,
+                stride_q=stride_q,
+                stride_kv=stride_kv,
+                padding=padding,
+                conv_bias=conv_bias,
+                conv_norm_layer=conv_norm_layer,
+                conv_act_layer=conv_act_layer,
+                num_heads=num_heads[stage_idx],
+                qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
+                attn_drop=attn_drop,
+                proj_drop=proj_drop,
+                input_norm_layer=input_norm_layer,
+                norm_layer=norm_layer,
+                init_values=init_values,
+                drop_path_rates=dpr[stage_idx],
+                mlp_layer=mlp_layer,
+                mlp_ratio=mlp_ratio,
+                mlp_act_layer=mlp_act_layer,
+                use_cls_token=use_cls_token[stage_idx],
+            )
+            in_chs = dim
+            stages.append(stage)
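
A quick note on the stochastic-depth schedule introduced above: `torch.linspace(0, drop_path_rate, sum(depths)).split(depths)` yields one linearly increasing list of drop-path rates per stage, which is then passed to each CvTStage as `drop_path_rates`. A minimal sketch of what it produces, assuming an illustrative `drop_path_rate=0.1` (the PR's default is 0.) with the default `depths=(1, 2, 10)`:

import torch

depths = (1, 2, 10)
drop_path_rate = 0.1  # illustrative value, not the default

# mirrors the dpr computation in CvT.__init__: one list of rates per stage
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]

for stage_idx, rates in enumerate(dpr):
    print(stage_idx, [round(r, 3) for r in rates])
# 0 [0.0]
# 1 [0.008, 0.017]
# 2 [0.025, 0.033, 0.042, 0.05, 0.058, 0.067, 0.075, 0.083, 0.092, 0.1]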
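One more note on the stage loop: each CvTBlock consumes and returns the pair (x, cls_token), which is why CvTStage.forward iterates over an nn.ModuleList rather than relying on nn.Sequential. A minimal, self-contained sketch of that calling pattern, using a hypothetical stand-in block rather than the PR's actual classes:

import torch
import torch.nn as nn
from typing import Optional, Tuple


class DummyBlock(nn.Module):
    """Hypothetical stand-in for CvTBlock: takes and returns (x, cls_token)."""

    def forward(
        self,
        x: torch.Tensor,
        cls_token: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        return x, cls_token


blocks = nn.ModuleList([DummyBlock() for _ in range(3)])

x = torch.randn(2, 64, 56, 56)
cls_token = None  # with the defaults above, only the last stage carries a class token

# nn.Sequential passes a single untyped value between modules; threading the
# (x, cls_token) pair through an explicit loop keeps the two-output interface clear
for block in blocks:
    x, cls_token = block(x, cls_token)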