Commit 8f7627c
wip
1 parent 11cc4a7

2 files changed: +107 −31 lines changed

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@
 from .convnext import *
 from .crossvit import *
 from .cspnet import *
+from .cvt import *
 from .davit import *
 from .deit import *
 from .densenet import *
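
Note: timm's model registry is filled in by import side effects, so this one-line import is what makes the @register_model entry point in cvt.py visible to the factory functions. A quick check (hypothetical session, assuming a checkout with this commit is installed):

import timm

print(timm.list_models('cvt*'))  # expect ['cvt_13']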

timm/models/cvt.py

Lines changed: 106 additions & 31 deletions
@@ -1,11 +1,28 @@
-from typing import Optional, Tuple
+""" CvT: Convolutional Vision Transformer
+
+From-scratch implementation of CvT in PyTorch
+
+'CvT: Introducing Convolutions to Vision Transformers'
+    - https://arxiv.org/abs/2103.15808
+
+Implementation for timm by / Copyright 2024, Fredo Guan
+"""
+
+from functools import partial
+from typing import List, Final, Optional, Tuple
 
 import torch
-import torch.nn
+import torch.nn as nn
 import torch.nn.functional as F
 
-from timm.layers import ConvNormAct, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import ConvNormAct, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn
+from ._builder import build_model_with_cfg
+from ._registry import generate_default_cfgs, register_model
+
+
 
+__all__ = ['CvT']
 
 class ConvEmbed(nn.Module):
     def __init__(
@@ -15,7 +32,7 @@ def __init__(
         kernel_size: int = 7,
         stride: int = 4,
         padding: int = 2,
-        norm_layer: nn.Module = nn.LayerNorm2d,
+        norm_layer: nn.Module = LayerNorm2d,
     ) -> None:
         super().__init__()
 
@@ -44,7 +61,7 @@ def __init__(
         padding: int = 1,
         bias: bool = False,
         norm_layer: nn.Module = nn.BatchNorm2d,
-        act_layer: nn.Module = nn.Identity(),
+        act_layer: nn.Module = nn.Identity,
     ) -> None:
         super().__init__()
         self.dim = dim
@@ -55,7 +72,7 @@ def __init__(
             kernel_size,
             stride=stride_q,
             padding=padding,
-            groups=in_chs,
+            groups=dim,
             bias=bias,
             norm_layer=norm_layer,
             act_layer=act_layer
@@ -70,8 +87,8 @@ def __init__(
             kernel_size,
             stride=stride_kv,
             padding=padding,
-            groups=in_chs,
-            bias=conv_bias,
+            groups=dim,
+            bias=bias,
             norm_layer=norm_layer,
             act_layer=act_layer
         )
@@ -82,8 +99,8 @@ def __init__(
             kernel_size,
             stride=stride_kv,
             padding=padding,
-            groups=in_chs,
-            bias=conv_bias,
+            groups=dim,
+            bias=bias,
             norm_layer=norm_layer,
             act_layer=act_layer
         )
@@ -107,7 +124,7 @@ def __init__(
         qk_norm: bool = False,
         attn_drop: float = 0.,
         proj_drop: float = 0.,
-        norm_layer: nn.Module = nn.LayerNorm,
+        norm_layer: nn.Module = LayerNorm,
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -122,16 +139,16 @@ def __init__(
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(out_chs, out_chs)
+        self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
         B, N, C = q.shape
 
         # [B, H*W, C] -> [B, H*W, n_h, d_h] -> [B, n_h, H*W, d_h]
-        q = self.proj_q(q).reshape(B, q.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-        k = self.proj_k(k).reshape(B, k.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-        v = self.proj_v(v).reshape(B, v.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        q = self.proj_q(q).reshape(B, q.shape[1], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        k = self.proj_k(k).reshape(B, k.shape[1], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        v = self.proj_v(v).reshape(B, v.shape[1], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
         q, k = self.q_norm(q), self.k_norm(k)
         # [B, n_h, H*W, d_h], [B, n_h, H*W/4, d_h], [B, n_h, H*W/4, d_h]
 
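
Note on the q.shape[2] -> q.shape[1] fix above: for a [B, N, C] token tensor, shape[1] is the token count N and shape[2] is the channel dim C; the reshape must keep N while splitting C into heads. A standalone shape sketch with toy sizes (illustrative, not the model's code path):

import torch

B, N, C, n_h = 2, 196, 64, 4  # batch, tokens (H*W), channels, heads
q = torch.randn(B, N, C)
# [B, H*W, C] -> [B, H*W, n_h, d_h] -> [B, n_h, H*W, d_h]
q = q.reshape(B, q.shape[1], n_h, C // n_h).permute(0, 2, 1, 3)
print(q.shape)  # torch.Size([2, 4, 196, 16])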
@@ -162,14 +179,14 @@ def __init__(
         padding: int = 1,
         conv_bias: bool = False,
         conv_norm_layer: nn.Module = nn.BatchNorm2d,
-        conv_act_layer: nn.Module = nn.Identity(),
+        conv_act_layer: nn.Module = nn.Identity,
         num_heads: int = 1,
         qkv_bias: bool = True,
         qk_norm: bool = False,
         attn_drop: float = 0.,
         proj_drop: float = 0.,
-        input_norm_layer = LayerNorm2d,
-        norm_layer: nn.Module = nn.LayerNorm,
+        input_norm_layer: nn.Module = partial(LayerNorm2d, eps=1e-5),
+        norm_layer: nn.Module = partial(LayerNorm, eps=1e-5),
         init_values: Optional[float] = None,
         drop_path: float = 0.,
         mlp_layer: nn.Module = Mlp,
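
The partial(LayerNorm2d, eps=1e-5) defaults above keep the norm arguments as dim -> module callables while pre-binding eps, matching how timm passes norm layers around. A minimal sketch of the pattern:

from functools import partial

import torch
from timm.layers import LayerNorm2d

norm_layer = partial(LayerNorm2d, eps=1e-5)  # still called as norm_layer(dim)
norm = norm_layer(64)
print(norm(torch.randn(1, 64, 14, 14)).shape)  # torch.Size([1, 64, 14, 14])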
@@ -180,7 +197,7 @@ def __init__(
         super().__init__()
         self.use_cls_token = use_cls_token
 
-        self.norm1 = norm_layer(dim)
+        self.norm1 = input_norm_layer(dim)
         self.conv_proj = ConvProj(
             dim = dim,
             kernel_size = kernel_size,
@@ -207,7 +224,7 @@ def __init__(
         self.mlp = mlp_layer(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
-            act_layer=act_layer,
+            act_layer=mlp_act_layer,
             drop=proj_drop,
         )
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
@@ -232,7 +249,8 @@ def fw_attn(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> torch.Tensor:
     def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         B, C, H, W = x.shape
 
-        x = x.flatten(2).transpose(1, 2) + self.drop_path1(self.ls1(self.fw_attn(self.norm1(x))))
+        x = torch.cat((cls_token, x.flatten(2).transpose(1, 2)), dim=1) if cls_token is not None else x.flatten(2).transpose(1, 2) \
+            + self.drop_path1(self.ls1(self.fw_attn(self.norm1(x), cls_token)))
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
 
         if self.use_cls_token:
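
For reference, the new forward prepends the class token to the flattened spatial tokens before attention. A shape-only sketch with toy sizes (illustrative, not the block's exact code):

import torch

B, C, H, W = 2, 384, 14, 14
x = torch.randn(B, C, H, W)
cls_token = torch.zeros(B, 1, C)
tokens = x.flatten(2).transpose(1, 2)           # [B, H*W, C]
tokens = torch.cat((cls_token, tokens), dim=1)  # [B, 1 + H*W, C]
print(tokens.shape)  # torch.Size([2, 197, 384])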
@@ -244,6 +262,7 @@ def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
 
 class CvTStage(nn.Module):
     def __init__(
+        self,
         in_chs: int,
         dim: int,
         depth: int,
@@ -256,14 +275,14 @@ def __init__(
         padding: int = 1,
         conv_bias: bool = False,
         conv_norm_layer: nn.Module = nn.BatchNorm2d,
-        conv_act_layer: nn.Module = nn.Identity(),
+        conv_act_layer: nn.Module = nn.Identity,
         num_heads: int = 1,
         qkv_bias: bool = True,
         qk_norm: bool = False,
         attn_drop: float = 0.,
         proj_drop: float = 0.,
         input_norm_layer = LayerNorm2d,
-        norm_layer: nn.Module = nn.LayerNorm,
+        norm_layer: nn.Module = LayerNorm,
         init_values: Optional[float] = None,
         drop_path_rates: List[float] = [0.],
         mlp_layer: nn.Module = Mlp,
@@ -320,14 +339,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv_embed(x)
         x = self.embed_drop(x)
 
-        cls_token = self.cls_token
+        cls_token = self.cls_token.expand(x.shape[0], -1, -1) if self.cls_token is not None else None
         for block in self.blocks: # technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tensor
             x, cls_token = block(x, cls_token)
 
         return x, cls_token
 
 class CvT(nn.Module):
     def __init__(
+        self,
         in_chans: int = 3,
         num_classes: int = 1000,
         dims: Tuple[int, ...] = (64, 192, 384),
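
The cls_token.expand(x.shape[0], -1, -1) change broadcasts the single learned [1, 1, dim] token across the batch without copying storage, so each sample gets its own token view per forward. Sketch of the expand semantics (toy dim):

import torch
import torch.nn as nn

cls_token = nn.Parameter(torch.zeros(1, 1, 384))
batch = torch.randn(8, 384, 14, 14)
tok = cls_token.expand(batch.shape[0], -1, -1)  # a view, no copy
print(tok.shape)  # torch.Size([8, 1, 384])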
@@ -341,14 +361,14 @@ def __init__(
         padding: int = 1,
         conv_bias: bool = False,
         conv_norm_layer: nn.Module = nn.BatchNorm2d,
-        conv_act_layer: nn.Module = nn.Identity(),
+        conv_act_layer: nn.Module = nn.Identity,
         num_heads: Tuple[int, ...] = (1, 3, 6),
         qkv_bias: bool = True,
         qk_norm: bool = False,
         attn_drop: float = 0.,
         proj_drop: float = 0.,
         input_norm_layer = LayerNorm2d,
-        norm_layer: nn.Module = nn.LayerNorm,
+        norm_layer: nn.Module = LayerNorm,
         init_values: Optional[float] = None,
         drop_path_rate: float = 0.,
         mlp_layer: nn.Module = Mlp,
@@ -362,7 +382,6 @@ def __init__(
         assert num_stages == len(embed_padding) == len(num_heads) == len(use_cls_token)
         self.num_classes = num_classes
         self.num_features = dims[-1]
-        self.drop_rate = drop_rate
 
         # FIXME only on last stage, no need for tuple
         self.use_cls_token = use_cls_token[-1]
@@ -371,6 +390,8 @@ def __init__(
 
         in_chs = in_chans
 
+        # TODO move stem
+
         stages = []
         for stage_idx in range(num_stages):
             dim = dims[stage_idx]
@@ -406,7 +427,7 @@ def __init__(
             stages.append(stage)
         self.stages = nn.ModuleList(stages)
 
-        self.head_norm = norm_layer(dims[-1])
+        self.norm = norm_layer(dims[-1])
         self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -416,8 +437,62 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
         if self.use_cls_token:
-            return self.head(self.head_norm(cls_token))
+            return self.head(self.norm(cls_token.flatten(1)))
         else:
-            return self.head(self.head_norm(x.mean(dim=(2,3))))
+            return self.head(self.norm(x.mean(dim=(2,3))))
 
-
+
+
+def checkpoint_filter_fn(state_dict, model):
+    """ Remap MSFT checkpoints -> timm """
+    if 'head.fc.weight' in state_dict:
+        return state_dict # non-MSFT checkpoint
+
+    if 'state_dict' in state_dict:
+        state_dict = state_dict['state_dict']
+
+    import re
+    out_dict = {}
+    for k, v in state_dict.items():
+        k = re.sub(r'stage([0-9]+)', r'stages.\1', k)
+        k = k.replace('patch_embed', 'conv_embed')
+        k = k.replace('conv_embed.proj', 'conv_embed.conv')
+        k = k.replace('attn.conv_proj', 'conv_proj.conv')
+        out_dict[k] = v
+    return out_dict
+
+
+def _create_cvt(variant, pretrained=False, **kwargs):
+    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 2, 10))))
+    out_indices = kwargs.pop('out_indices', default_out_indices)
+
+    model = build_model_with_cfg(
+        CvT,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **kwargs)
+
+    return model
+
+# TODO update first_conv
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (14, 14),
+        'crop_pct': 0.95, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv', 'classifier': 'head',
+        **kwargs
+    }
+
+default_cfgs = generate_default_cfgs({
+    'cvt_13.msft_in1k': _cfg(url='https://files.catbox.moe/xz97kh.pth'),
+})
+
+
+@register_model
+def cvt_13(pretrained=False, **kwargs) -> CvT:
+    model_args = dict(depths=(1, 2, 10), dims=(64, 192, 384), num_heads=(1, 3, 6))
+    return _create_cvt('cvt_13', pretrained=pretrained, **dict(model_args, **kwargs))
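
To see what checkpoint_filter_fn does, trace one hypothetical MSFT-style key through the remap (the key name is illustrative, not taken from a real checkpoint):

import re

k = 'stage2.patch_embed.proj.weight'           # hypothetical source key
k = re.sub(r'stage([0-9]+)', r'stages.\1', k)  # -> 'stages.2.patch_embed.proj.weight'
k = k.replace('patch_embed', 'conv_embed')     # -> 'stages.2.conv_embed.proj.weight'
k = k.replace('conv_embed.proj', 'conv_embed.conv')
k = k.replace('attn.conv_proj', 'conv_proj.conv')
print(k)  # stages.2.conv_embed.conv.weight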

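A minimal smoke-test sketch for the new variant, assuming a checkout with this commit is installed (the commit is marked wip, so this presumes the remaining fixes land); pretrained=False avoids hitting the third-party weight URL:

import torch
import timm

model = timm.create_model('cvt_13', pretrained=False)
x = torch.randn(1, 3, 224, 224)
logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 1000])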