@@ -1,11 +1,28 @@
-from typing import Optional, Tuple
+""" CvT: Convolutional Vision Transformer
+
+From-scratch implementation of CvT in PyTorch
+
+'CvT: Introducing Convolutions to Vision Transformers'
+    - https://arxiv.org/abs/2103.15808
+
+Implementation for timm by / Copyright 2024, Fredo Guan
+"""
+
+from functools import partial
+from typing import List, Final, Optional, Tuple

 import torch
-import torch.nn
+import torch.nn as nn
 import torch.nn.functional as F

-from timm.layers import ConvNormAct, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import ConvNormAct, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn
+from ._builder import build_model_with_cfg
+from ._registry import generate_default_cfgs, register_model
+
+

+__all__ = ['CvT']

 class ConvEmbed(nn.Module):
     def __init__(
@@ -15,7 +32,7 @@ def __init__(
             kernel_size: int = 7,
             stride: int = 4,
             padding: int = 2,
-            norm_layer: nn.Module = nn.LayerNorm2d,
+            norm_layer: nn.Module = LayerNorm2d,
     ) -> None:
         super().__init__()

@@ -44,7 +61,7 @@ def __init__(
             padding: int = 1,
             bias: bool = False,
             norm_layer: nn.Module = nn.BatchNorm2d,
-            act_layer: nn.Module = nn.Identity(),
+            act_layer: nn.Module = nn.Identity,
     ) -> None:
         super().__init__()
         self.dim = dim
@@ -55,7 +72,7 @@ def __init__(
             kernel_size,
             stride=stride_q,
             padding=padding,
-            groups=in_chs,
+            groups=dim,
             bias=bias,
             norm_layer=norm_layer,
             act_layer=act_layer
@@ -70,8 +87,8 @@ def __init__(
             kernel_size,
             stride=stride_kv,
             padding=padding,
-            groups=in_chs,
-            bias=conv_bias,
+            groups=dim,
+            bias=bias,
             norm_layer=norm_layer,
             act_layer=act_layer
         )
@@ -82,8 +99,8 @@ def __init__(
             kernel_size,
             stride=stride_kv,
             padding=padding,
-            groups=in_chs,
-            bias=conv_bias,
+            groups=dim,
+            bias=bias,
             norm_layer=norm_layer,
             act_layer=act_layer
         )
@@ -107,7 +124,7 @@ def __init__(
             qk_norm: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
-            norm_layer: nn.Module = nn.LayerNorm,
+            norm_layer: nn.Module = LayerNorm,
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -122,16 +139,16 @@ def __init__(
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(out_chs, out_chs)
+        self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
         B, N, C = q.shape

         # [B, H*W, C] -> [B, H*W, n_h, d_h] -> [B, n_h, H*W, d_h]
-        q = self.proj_q(q).reshape(B, q.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-        k = self.proj_k(k).reshape(B, k.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-        v = self.proj_v(v).reshape(B, v.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        q = self.proj_q(q).reshape(B, q.shape[1], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        k = self.proj_k(k).reshape(B, k.shape[1], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        v = self.proj_v(v).reshape(B, v.shape[1], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
         q, k = self.q_norm(q), self.k_norm(k)
         # [B, n_h, H*W, d_h], [B, n_h, H*W/4, d_h], [B, n_h, H*W/4, d_h]

@@ -162,14 +179,14 @@ def __init__(
             padding: int = 1,
             conv_bias: bool = False,
             conv_norm_layer: nn.Module = nn.BatchNorm2d,
-            conv_act_layer: nn.Module = nn.Identity(),
+            conv_act_layer: nn.Module = nn.Identity,
             num_heads: int = 1,
             qkv_bias: bool = True,
             qk_norm: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
-            input_norm_layer=LayerNorm2d,
-            norm_layer: nn.Module = nn.LayerNorm,
+            input_norm_layer: nn.Module = partial(LayerNorm2d, eps=1e-5),
+            norm_layer: nn.Module = partial(LayerNorm, eps=1e-5),
             init_values: Optional[float] = None,
             drop_path: float = 0.,
             mlp_layer: nn.Module = Mlp,
@@ -180,7 +197,7 @@ def __init__(
         super().__init__()
         self.use_cls_token = use_cls_token

-        self.norm1 = norm_layer(dim)
+        self.norm1 = input_norm_layer(dim)
         self.conv_proj = ConvProj(
             dim=dim,
             kernel_size=kernel_size,
@@ -207,7 +224,7 @@ def __init__(
         self.mlp = mlp_layer(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
-            act_layer=act_layer,
+            act_layer=mlp_act_layer,
             drop=proj_drop,
         )
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
@@ -232,7 +249,8 @@ def fw_attn(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> torch.T
     def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         B, C, H, W = x.shape

-        x = x.flatten(2).transpose(1, 2) + self.drop_path1(self.ls1(self.fw_attn(self.norm1(x))))
+        x = torch.cat((cls_token, x.flatten(2).transpose(1, 2)), dim=1) if cls_token is not None else x.flatten(2).transpose(1, 2) \
+            + self.drop_path1(self.ls1(self.fw_attn(self.norm1(x), cls_token)))
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))

         if self.use_cls_token:
@@ -244,6 +262,7 @@ def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[t

 class CvTStage(nn.Module):
     def __init__(
+            self,
             in_chs: int,
             dim: int,
             depth: int,
@@ -256,14 +275,14 @@ def __init__(
             padding: int = 1,
             conv_bias: bool = False,
             conv_norm_layer: nn.Module = nn.BatchNorm2d,
-            conv_act_layer: nn.Module = nn.Identity(),
+            conv_act_layer: nn.Module = nn.Identity,
             num_heads: int = 1,
             qkv_bias: bool = True,
             qk_norm: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             input_norm_layer=LayerNorm2d,
-            norm_layer: nn.Module = nn.LayerNorm,
+            norm_layer: nn.Module = LayerNorm,
             init_values: Optional[float] = None,
             drop_path_rates: List[float] = [0.],
             mlp_layer: nn.Module = Mlp,
@@ -320,14 +339,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv_embed(x)
         x = self.embed_drop(x)

-        cls_token = self.cls_token
+        cls_token = self.cls_token.expand(x.shape[0], -1, -1) if self.cls_token is not None else None
         for block in self.blocks:  # technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tensor
             x, cls_token = block(x, cls_token)

         return x, cls_token

 class CvT(nn.Module):
     def __init__(
+            self,
             in_chans: int = 3,
             num_classes: int = 1000,
             dims: Tuple[int, ...] = (64, 192, 384),
@@ -341,14 +361,14 @@ def __init__(
             padding: int = 1,
             conv_bias: bool = False,
             conv_norm_layer: nn.Module = nn.BatchNorm2d,
-            conv_act_layer: nn.Module = nn.Identity(),
+            conv_act_layer: nn.Module = nn.Identity,
             num_heads: Tuple[int, ...] = (1, 3, 6),
             qkv_bias: bool = True,
             qk_norm: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             input_norm_layer=LayerNorm2d,
-            norm_layer: nn.Module = nn.LayerNorm,
+            norm_layer: nn.Module = LayerNorm,
             init_values: Optional[float] = None,
             drop_path_rate: float = 0.,
             mlp_layer: nn.Module = Mlp,
@@ -362,7 +382,6 @@ def __init__(
         assert num_stages == len(embed_padding) == len(num_heads) == len(use_cls_token)
         self.num_classes = num_classes
         self.num_features = dims[-1]
-        self.drop_rate = drop_rate

         # FIXME only on last stage, no need for tuple
         self.use_cls_token = use_cls_token[-1]
@@ -371,6 +390,8 @@ def __init__(

         in_chs = in_chans

+        # TODO move stem
+
         stages = []
         for stage_idx in range(num_stages):
             dim = dims[stage_idx]
@@ -406,7 +427,7 @@ def __init__(
             stages.append(stage)
         self.stages = nn.ModuleList(stages)

-        self.head_norm = norm_layer(dims[-1])
+        self.norm = norm_layer(dims[-1])
         self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -416,8 +437,62 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


         if self.use_cls_token:
-            return self.head(self.head_norm(cls_token))
+            return self.head(self.norm(cls_token.flatten(1)))
         else:
-            return self.head(self.head_norm(x.mean(dim=(2,3))))
+            return self.head(self.norm(x.mean(dim=(2,3))))

-
+
+
+def checkpoint_filter_fn(state_dict, model):
+    """ Remap MSFT checkpoints -> timm """
+    if 'head.fc.weight' in state_dict:
+        return state_dict  # non-MSFT checkpoint
+
+    if 'state_dict' in state_dict:
+        state_dict = state_dict['state_dict']
+
+    import re
+    out_dict = {}
+    for k, v in state_dict.items():
+        k = re.sub(r'stage([0-9]+)', r'stages.\1', k)
+        k = k.replace('patch_embed', 'conv_embed')
+        k = k.replace('conv_embed.proj', 'conv_embed.conv')
+        k = k.replace('attn.conv_proj', 'conv_proj.conv')
+        out_dict[k] = v
+    return out_dict
+
+
+def _create_cvt(variant, pretrained=False, **kwargs):
+    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 2, 10))))
+    out_indices = kwargs.pop('out_indices', default_out_indices)
+
+    model = build_model_with_cfg(
+        CvT,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **kwargs)
+
+    return model
+
+
+# TODO update first_conv
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (14, 14),
+        'crop_pct': 0.95, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'cvt_13.msft_in1k': _cfg(url='https://files.catbox.moe/xz97kh.pth'),
+})
+
+
+@register_model
+def cvt_13(pretrained=False, **kwargs) -> CvT:
+    model_args = dict(depths=(1, 2, 10), dims=(64, 192, 384), num_heads=(1, 3, 6))
+    return _create_cvt('cvt_13', pretrained=pretrained, **dict(model_args, **kwargs))
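
To sanity-check the new entrypoint, a minimal usage sketch follows. It assumes this branch is installed in place of release timm so that the `cvt_13` registration above is picked up by the model registry; nothing here downloads the pretrained weights, and the expected shapes are read off the default cfg in the diff.

```python
import torch
import timm

# Instantiate the registered CvT-13 variant without pretrained weights.
model = timm.create_model('cvt_13', pretrained=False)
model.eval()

# Input size matches the (3, 224, 224) 'input_size' in the default cfg.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 1000])
```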