
Commit 9b020d4

Update cvt.py
1 parent 4c2827f commit 9b020d4

File tree

1 file changed: +264 -78 lines changed


timm/models/cvt.py

Lines changed: 264 additions & 78 deletions
@@ -1,19 +1,22 @@
+from typing import Final, Optional, Tuple
+
 import torch
 import torch.nn as nn
-from torch import Tensor
+import torch.nn.functional as F
+
+from timm.layers import ConvNormAct, DropPath, LayerNorm2d, LayerScale, Mlp, QuickGELU, trunc_normal_, use_fused_attn

-from timm.layers import LayerNorm2d, Mlp, ConvNormAct

 class ConvEmbed(nn.Module):
     def __init__(
             self,
-            in_chs=3,
-            out_chs=64,
-            kernel_size=7,
-            stride=4,
-            padding=2,
-            norm_layer=LayerNorm2d,
-    ):
+            in_chs: int = 3,
+            out_chs: int = 64,
+            kernel_size: int = 7,
+            stride: int = 4,
+            padding: int = 2,
+            norm_layer: nn.Module = LayerNorm2d,
+    ) -> None:
         super().__init__()

         self.conv = nn.Conv2d(
@@ -26,111 +29,294 @@ def __init__(

         self.norm = norm_layer(out_chs) if norm_layer else nn.Identity()

-    def forward(self, x: Tensor):  # [B, C, H, W] -> [B, C, H, W]
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # [B, C, H, W] -> [B, C, H, W]
         x = self.conv(x)
         x = self.norm(x)
         return x

-
-
-class Attention(nn.Module):
+class ConvProj(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            num_heads,
-            kernel_size=3,
-            stride_q=1,
-            stride_kv=1,
-            padding_q=1,
-            padding_kv=1,
-            qkv_bias=False,
-            conv_bias=False,
-            attn_drop=0.,
-            proj_drop=0.,
-            conv_norm_layer=nn.BatchNorm2d,
-            conv_act_layer=nn.Identity(),
-
-            cls_token=True
-    ):
-        assert out_chs % num_heads == 0, 'dim should be divisible by num_heads'
-        self.out_chs = out_chs
-        self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-        self.scale = out_chs ** -0.5
-
+            dim: int,
+            kernel_size: int = 3,
+            stride_q: int = 1,
+            stride_kv: int = 2,
+            padding: int = 1,
+            bias: bool = False,
+            norm_layer: nn.Module = nn.BatchNorm2d,
+            act_layer: nn.Module = nn.Identity(),
+    ) -> None:
+        super().__init__()
+        self.dim = dim
+
         self.conv_q = ConvNormAct(
-            in_chs,
-            out_chs,
+            dim,
+            dim,
             kernel_size,
             stride=stride_q,
-            padding=padding_q,
+            padding=padding,
             groups=in_chs,
-            bias=conv_bias,
-            norm_layer=conv_norm_layer,
-            act_layer=conv_act_layer
+            bias=bias,
+            norm_layer=norm_layer,
+            act_layer=act_layer
         )

         self.conv_k = ConvNormAct(
-            in_chs,
-            out_chs * 2,
+            dim,
+            dim,
             kernel_size,
             stride=stride_kv,
-            padding=padding_kv,
+            padding=padding,
             groups=in_chs,
             bias=conv_bias,
-            norm_layer=conv_norm_layer,
-            act_layer=conv_act_layer
+            norm_layer=norm_layer,
+            act_layer=act_layer
         )

         self.conv_v = ConvNormAct(
-            in_chs,
-            out_chs * 2,
+            dim,
+            dim,
             kernel_size,
             stride=stride_kv,
-            padding=padding_kv,
+            padding=padding,
             groups=in_chs,
             bias=conv_bias,
-            norm_layer=conv_norm_layer,
-            act_layer=conv_act_layer
+            norm_layer=norm_layer,
+            act_layer=act_layer
         )
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        B, C, H, W = x.shape
+        # [B, C, H, W] -> [B, H*W, C]
+        q = self.conv_q(x).flatten(2).transpose(1, 2)
+        k = self.conv_k(x).flatten(2).transpose(1, 2)
+        v = self.conv_v(x).flatten(2).transpose(1, 2)
+        return q, k, v
+
+class Attention(nn.Module):
+    fused_attn: Final[bool]
+
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = True,
+            qk_norm: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            norm_layer: nn.Module = nn.LayerNorm,
+    ) -> None:
+        super().__init__()
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()

-        # FIXME better way to do this? iirc 1 is better than 3
-        self.proj_q = nn.Linear(in_chs, out_chs, bias=qkv_bias)
-        self.proj_k = nn.Linear(in_chs, out_chs, bias=qkv_bias)
-        self.proj_v = nn.Linear(in_chs, out_chs, bias=qkv_bias)
+        self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
+        self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
+        self.proj_v = nn.Linear(dim, dim, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(out_chs, out_chs)
         self.proj_drop = nn.Dropout(proj_drop)

-    def forward(self, x: Tensor):
-        # [B, C_in, H, W] -> [B, H*W, C_out]
-        q = self.conv_q(x).flatten(2).transpose(1, 2)
-        k = self.conv_k(x).flatten(2).transpose(1, 2)
-        v = self.conv_v(x).flatten(2).transpose(1, 2)
+    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        B, N, C = q.shape

-        # need to handle cls token here
-
-        # [B, H*W, C_out] -> [B, H*W, n_h, d_h] -> [B, n_h, H*W, d_h]
+        # [B, H*W, C] -> [B, H*W, n_h, d_h] -> [B, n_h, H*W, d_h]
         q = self.proj_q(q).reshape(B, q.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
         k = self.proj_k(k).reshape(B, k.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
         v = self.proj_v(v).reshape(B, v.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-
-        # FIXME F.sdpa
-        q = q * self.scale
-        attn = q @ k.transpose(-2, -1)
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-        x = attn @ v
-
-        x = x.transpose(1, 2).reshape(B, N, self.out_chs)
+        q, k = self.q_norm(q), self.k_norm(k)
+        # [B, n_h, H*W, d_h], [B, n_h, H*W/4, d_h], [B, n_h, H*W/4, d_h]
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p if self.training else 0.,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
-
         return x

-class QuickGELU(nn.Module):
-    def forward(self, x: Tensor):
-        return x * torch.sigmoid(1.702 * x)
+class CvTBlock(nn.Module):
+    def __init__(
+            self,
+            dim: int,
+            kernel_size: int = 3,
+            stride_q: int = 1,
+            stride_kv: int = 2,
+            padding: int = 1,
+            conv_bias: bool = False,
+            conv_norm_layer: nn.Module = nn.BatchNorm2d,
+            conv_act_layer: nn.Module = nn.Identity(),
+            num_heads: int = 8,
+            qkv_bias: bool = True,
+            qk_norm: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            input_norm_layer=LayerNorm2d,
+            norm_layer: nn.Module = nn.LayerNorm,
+            init_values: Optional[float] = None,
+            drop_path: float = 0.,
+            mlp_layer: nn.Module = Mlp,
+            mlp_ratio: float = 4.,
+            mlp_act_layer: nn.Module = QuickGELU,
+            use_cls_token: bool = False,
+    ) -> None:
+        super().__init__()
+        self.use_cls_token = use_cls_token
+
+        self.norm1 = norm_layer(dim)
+        self.conv_proj = ConvProj(
+            dim=dim,
+            kernel_size=kernel_size,
+            stride_q=stride_q,
+            stride_kv=stride_kv,
+            padding=padding,
+            bias=conv_bias,
+            norm_layer=conv_norm_layer,
+            act_layer=conv_act_layer,
+        )
+        self.attn = Attention(
+            dim=dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            attn_drop=attn_drop,
+            proj_drop=proj_drop,
+            norm_layer=norm_layer,
+        )
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        self.norm2 = norm_layer(dim)
+        self.mlp = mlp_layer(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=mlp_act_layer,
+            drop=proj_drop,
+        )
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def add_cls_token(
+            self,
+            q: torch.Tensor,
+            k: torch.Tensor,
+            v: torch.Tensor,
+            cls_token: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if self.use_cls_token:
+            q = torch.cat((cls_token, q), dim=1)
+            k = torch.cat((cls_token, k), dim=1)
+            v = torch.cat((cls_token, v), dim=1)
+        return q, k, v
+
+    def fw_attn(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> torch.Tensor:
+        return self.attn(*self.add_cls_token(*self.conv_proj(x), cls_token))
+
+    def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        B, C, H, W = x.shape
+
+        x = x.flatten(2).transpose(1, 2) + self.drop_path1(self.ls1(self.fw_attn(self.norm1(x), cls_token)))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+
+        if self.use_cls_token:
+            cls_token, x = torch.split(x, [1, H * W], 1)
+
+        return x, cls_token
+
+class CvTStage(nn.Module):
+    def __init__(
+            self,
+            in_chs: int,
+            dim: int,
+            depth: int,
+            embed_kernel_size: int = 7,
+            embed_stride: int = 4,
+            embed_padding: int = 2,
+            kernel_size: int = 3,
+            stride_q: int = 1,
+            stride_kv: int = 2,
+            padding: int = 1,
+            conv_bias: bool = False,
+            conv_norm_layer: nn.Module = nn.BatchNorm2d,
+            conv_act_layer: nn.Module = nn.Identity(),
+            num_heads: int = 8,
+            qkv_bias: bool = True,
+            qk_norm: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            input_norm_layer=LayerNorm2d,
+            norm_layer: nn.Module = nn.LayerNorm,
+            init_values: Optional[float] = None,
+            drop_path: float = 0.,
+            mlp_layer: nn.Module = Mlp,
+            mlp_ratio: float = 4.,
+            mlp_act_layer: nn.Module = QuickGELU,
+            use_cls_token: bool = False,
+    ) -> None:
+        super().__init__()
+        self.conv_embed = ConvEmbed(
+            in_chs=in_chs,
+            out_chs=dim,
+            kernel_size=embed_kernel_size,
+            stride=embed_stride,
+            padding=embed_padding,
+            norm_layer=input_norm_layer,
+        )
+        self.embed_drop = nn.Dropout(proj_drop)
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim)) if use_cls_token else None
+
+        blocks = []
+        for i in range(depth):
+            blocks.append(
+                CvTBlock(
+                    dim=dim,
+                    kernel_size=kernel_size,
+                    stride_q=stride_q,
+                    stride_kv=stride_kv,
+                    padding=padding,
+                    conv_bias=conv_bias,
+                    conv_norm_layer=conv_norm_layer,
+                    conv_act_layer=conv_act_layer,
+                    num_heads=num_heads,
+                    qkv_bias=qkv_bias,
+                    qk_norm=qk_norm,
+                    attn_drop=attn_drop,
+                    proj_drop=proj_drop,
+                    input_norm_layer=input_norm_layer,
+                    norm_layer=norm_layer,
+                    init_values=init_values,
+                    drop_path=drop_path,
+                    mlp_layer=mlp_layer,
+                    mlp_ratio=mlp_ratio,
+                    mlp_act_layer=mlp_act_layer,
+                    use_cls_token=use_cls_token,
+                )
+            )
+        self.blocks = nn.ModuleList(blocks)
+
+        if self.cls_token is not None:
+            trunc_normal_(self.cls_token, std=.02)

-
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.conv_embed(x)
+        x = self.embed_drop(x)
+
+        cls_token = self.cls_token
+        for block in self.blocks:
+            x, cls_token = block(x, cls_token)
+
+        return x, cls_token
+
+
+class CvT(nn.Module):

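Below is a minimal, self-contained sketch of the convolutional-projection attention pattern this diff builds: depthwise-conv q/k/v projections, with k/v spatially subsampled via stride_kv, followed by multi-head attention. It stands in plain torch.nn modules for timm's ConvNormAct and the Attention class above; the module name ConvProjAttentionSketch, the dims, strides, and test shapes are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvProjAttentionSketch(nn.Module):
    # Illustrative only: plain Conv2d + BatchNorm stand in for timm's ConvNormAct,
    # and stride_kv=2 subsamples keys/values as in the diff above.
    def __init__(self, dim: int = 64, num_heads: int = 4, stride_kv: int = 2) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        def dw(stride: int) -> nn.Sequential:
            # depthwise 3x3 conv + norm, no activation
            return nn.Sequential(
                nn.Conv2d(dim, dim, 3, stride=stride, padding=1, groups=dim, bias=False),
                nn.BatchNorm2d(dim),
            )

        self.conv_q, self.conv_k, self.conv_v = dw(1), dw(stride_kv), dw(stride_kv)
        self.proj_q = nn.Linear(dim, dim)
        self.proj_k = nn.Linear(dim, dim)
        self.proj_v = nn.Linear(dim, dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        # [B, C, H, W] -> [B, N, C]; k/v are spatially subsampled relative to q
        q = self.conv_q(x).flatten(2).transpose(1, 2)
        k = self.conv_k(x).flatten(2).transpose(1, 2)
        v = self.conv_v(x).flatten(2).transpose(1, 2)
        # [B, N, C] -> [B, n_heads, N, head_dim]
        q = self.proj_q(q).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.proj_k(k).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.proj_v(v).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v)   # [B, n_heads, H*W, head_dim]
        out = out.transpose(1, 2).reshape(B, H * W, C)  # back to [B, N, C]
        return self.proj(out).transpose(1, 2).reshape(B, C, H, W)


if __name__ == "__main__":
    m = ConvProjAttentionSketch(dim=64, num_heads=4)
    y = m(torch.randn(2, 64, 14, 14))
    print(y.shape)  # torch.Size([2, 64, 14, 14])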
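And a short sketch, with assumed shapes, of the cls_token plumbing that CvTBlock and CvTStage set up above: the class token is prepended to the flattened patch sequence before attention and split back off afterwards so the spatial tokens can return to [B, C, H, W].

import torch

# Illustrative shapes only; the learned token would be an nn.Parameter in the model.
B, C, H, W = 2, 64, 14, 14
x = torch.randn(B, C, H, W)
cls_token = torch.zeros(1, 1, C).expand(B, -1, -1)   # broadcast token to the batch: [B, 1, C]

tokens = x.flatten(2).transpose(1, 2)                # [B, H*W, C]
tokens = torch.cat((cls_token, tokens), dim=1)       # [B, 1 + H*W, C]

# ... attention / MLP blocks operate on `tokens` here ...

cls_token, spatial = torch.split(tokens, [1, H * W], dim=1)
x = spatial.transpose(1, 2).reshape(B, C, H, W)
print(cls_token.shape, x.shape)                      # [2, 1, 64] [2, 64, 14, 14]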