from typing import Final, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.layers import ConvNormAct, DropPath, LayerNorm2d, LayerScale, Mlp, QuickGELU, trunc_normal_, use_fused_attn


class ConvEmbed(nn.Module):
    def __init__(
            self,
            in_chs: int = 3,
            out_chs: int = 64,
            kernel_size: int = 7,
            stride: int = 4,
            padding: int = 2,
            norm_layer: nn.Module = LayerNorm2d,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(
            in_chs,
            out_chs,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        self.norm = norm_layer(out_chs) if norm_layer else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # [B, C, H, W] -> [B, C, H, W]
        x = self.conv(x)
        x = self.norm(x)
        return x

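# ConvEmbed downsamples with standard conv arithmetic:
#   H_out = (H + 2 * padding - kernel_size) // stride + 1, e.g. 224 -> 56 with the defaults above.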

class ConvProj(nn.Module):
    def __init__(
            self,
            dim: int,
            kernel_size: int = 3,
            stride_q: int = 1,
            stride_kv: int = 2,
            padding: int = 1,
            bias: bool = False,
            norm_layer: nn.Module = nn.BatchNorm2d,
            act_layer: nn.Module = nn.Identity,
    ) -> None:
        super().__init__()
        self.dim = dim

        self.conv_q = ConvNormAct(
            dim,
            dim,
            kernel_size,
            stride=stride_q,
            padding=padding,
            groups=dim,
            bias=bias,
            norm_layer=norm_layer,
            act_layer=act_layer,
        )

        self.conv_k = ConvNormAct(
            dim,
            dim,
            kernel_size,
            stride=stride_kv,
            padding=padding,
            groups=dim,
            bias=bias,
            norm_layer=norm_layer,
            act_layer=act_layer,
        )

        self.conv_v = ConvNormAct(
            dim,
            dim,
            kernel_size,
            stride=stride_kv,
            padding=padding,
            groups=dim,
            bias=bias,
            norm_layer=norm_layer,
            act_layer=act_layer,
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # depthwise conv projections, then flatten: [B, C, H', W'] -> [B, H'*W', C]
        q = self.conv_q(x).flatten(2).transpose(1, 2)
        k = self.conv_k(x).flatten(2).transpose(1, 2)
        v = self.conv_v(x).flatten(2).transpose(1, 2)
        return q, k, v


class Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_v = nn.Linear(dim, dim, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        B, N, C = q.shape

        # [B, N, C] -> [B, N, n_h, d_h] -> [B, n_h, N, d_h]
        q = self.proj_q(q).reshape(B, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.proj_k(k).reshape(B, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.proj_v(v).reshape(B, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        q, k = self.q_norm(q), self.k_norm(k)
        # q: [B, n_h, N, d_h], k/v: [B, n_h, N_kv, d_h] where N_kv < N when stride_kv > 1

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

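# Rough shape sketch for the ConvProj + Attention pair (illustrative only; assumes dim=64,
# num_heads=1, the default stride_q=1 / stride_kv=2, and a 56x56 input map):
#   proj = ConvProj(64)
#   attn = Attention(64, num_heads=1)
#   q, k, v = proj(torch.randn(2, 64, 56, 56))  # q: [2, 3136, 64], k/v: [2, 784, 64]
#   out = attn(q, k, v)                         # out: [2, 3136, 64]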

class CvTBlock(nn.Module):
    def __init__(
            self,
            dim: int,
            kernel_size: int = 3,
            stride_q: int = 1,
            stride_kv: int = 2,
            padding: int = 1,
            conv_bias: bool = False,
            conv_norm_layer: nn.Module = nn.BatchNorm2d,
            conv_act_layer: nn.Module = nn.Identity,
            num_heads: int = 8,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            input_norm_layer: nn.Module = LayerNorm2d,
            norm_layer: nn.Module = nn.LayerNorm,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            mlp_layer: nn.Module = Mlp,
            mlp_ratio: float = 4.,
            mlp_act_layer: nn.Module = QuickGELU,
            use_cls_token: bool = False,
    ) -> None:
        super().__init__()
        self.use_cls_token = use_cls_token

        # normalize the spatial map (channels-first) before the depthwise conv projection
        self.norm1 = input_norm_layer(dim)
        self.conv_proj = ConvProj(
            dim=dim,
            kernel_size=kernel_size,
            stride_q=stride_q,
            stride_kv=stride_kv,
            padding=padding,
            bias=conv_bias,
            norm_layer=conv_norm_layer,
            act_layer=conv_act_layer,
        )
        self.attn = Attention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=mlp_act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def add_cls_token(
            self,
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            cls_token: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if self.use_cls_token:
            q = torch.cat((cls_token, q), dim=1)
            k = torch.cat((cls_token, k), dim=1)
            v = torch.cat((cls_token, v), dim=1)
        return q, k, v

    def fw_attn(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> torch.Tensor:
        return self.attn(*self.add_cls_token(*self.conv_proj(x), cls_token))

    def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        B, C, H, W = x.shape

        # attention branch runs on the normalized spatial map, residual on the token sequence
        attn_out = self.fw_attn(self.norm1(x), cls_token)
        x = x.flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, H*W, C]
        if self.use_cls_token:
            x = torch.cat((cls_token, x), dim=1)
        x = x + self.drop_path1(self.ls1(attn_out))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))

        if self.use_cls_token:
            cls_token, x = torch.split(x, [1, H * W], 1)

        # back to a spatial map so the next block's conv projection sees [B, C, H, W]
        x = x.transpose(1, 2).reshape(B, C, H, W)
        return x, cls_token

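# Note: CvTBlock consumes and returns a [B, C, H, W] feature map plus an optional [B, 1, C]
# cls token, so blocks can be chained directly inside a stage.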

class CvTStage(nn.Module):
    def __init__(
            self,
            in_chs: int,
            dim: int,
            depth: int,
            embed_kernel_size: int = 7,
            embed_stride: int = 4,
            embed_padding: int = 2,
            kernel_size: int = 3,
            stride_q: int = 1,
            stride_kv: int = 2,
            padding: int = 1,
            conv_bias: bool = False,
            conv_norm_layer: nn.Module = nn.BatchNorm2d,
            conv_act_layer: nn.Module = nn.Identity,
            num_heads: int = 8,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            input_norm_layer: nn.Module = LayerNorm2d,
            norm_layer: nn.Module = nn.LayerNorm,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            mlp_layer: nn.Module = Mlp,
            mlp_ratio: float = 4.,
            mlp_act_layer: nn.Module = QuickGELU,
            use_cls_token: bool = False,
    ) -> None:
        super().__init__()
        self.conv_embed = ConvEmbed(
            in_chs=in_chs,
            out_chs=dim,
            kernel_size=embed_kernel_size,
            stride=embed_stride,
            padding=embed_padding,
            norm_layer=input_norm_layer,
        )
        self.embed_drop = nn.Dropout(proj_drop)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim)) if use_cls_token else None

        blocks = []
        for i in range(depth):
            blocks.append(
                CvTBlock(
                    dim=dim,
                    kernel_size=kernel_size,
                    stride_q=stride_q,
                    stride_kv=stride_kv,
                    padding=padding,
                    conv_bias=conv_bias,
                    conv_norm_layer=conv_norm_layer,
                    conv_act_layer=conv_act_layer,
                    num_heads=num_heads,
                    qkv_bias=qkv_bias,
                    qk_norm=qk_norm,
                    attn_drop=attn_drop,
                    proj_drop=proj_drop,
                    input_norm_layer=input_norm_layer,
                    norm_layer=norm_layer,
                    init_values=init_values,
                    drop_path=drop_path,
                    mlp_layer=mlp_layer,
                    mlp_ratio=mlp_ratio,
                    mlp_act_layer=mlp_act_layer,
                    use_cls_token=use_cls_token,
                )
            )
        self.blocks = nn.ModuleList(blocks)

        if self.cls_token is not None:
            trunc_normal_(self.cls_token, std=.02)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        x = self.conv_embed(x)
        x = self.embed_drop(x)

        # expand the learnable cls token to the batch size before threading it through the blocks
        cls_token = self.cls_token.expand(x.shape[0], -1, -1) if self.cls_token is not None else None
        for block in self.blocks:
            x, cls_token = block(x, cls_token)

        return x, cls_token
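
# Stage-level shape walkthrough (illustrative only; assumes a 224x224 input, dim=64,
# depth=1 and use_cls_token=True):
#   stage = CvTStage(in_chs=3, dim=64, depth=1, use_cls_token=True)
#   x, cls_token = stage(torch.randn(2, 3, 224, 224))
#   # conv_embed: [2, 3, 224, 224] -> [2, 64, 56, 56]; blocks keep the map at [2, 64, 56, 56]
#   # cls_token comes back as [2, 1, 64]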

class CvT(nn.Module):
    ...