
Commit db586b5

Update cvt.py

1 parent 9b020d4 commit db586b5

1 file changed: +107 -31 lines

timm/models/cvt.py (107 additions, 31 deletions)
@@ -46,6 +46,7 @@ def __init__(
             norm_layer: nn.Module = nn.BatchNorm2d,
             act_layer: nn.Module = nn.Identity(),
     ) -> None:
+        super().__init__()
         self.dim = dim

         self.conv_q = ConvNormAct(
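
The super().__init__() additions in this commit (here and in CvTBlock and CvTStage below) are more than style: nn.Module builds its parameter and submodule registries inside its own __init__, so assigning a submodule such as self.conv_q before that call fails. A minimal sketch of the failure mode, using a hypothetical Broken/Fixed pair purely for illustration:

    import torch.nn as nn

    class Broken(nn.Module):
        def __init__(self):
            # missing super().__init__()
            self.proj = nn.Linear(4, 4)  # AttributeError: cannot assign module before Module.__init__() call

    class Fixed(nn.Module):
        def __init__(self):
            super().__init__()           # registries (_parameters, _modules) now exist
            self.proj = nn.Linear(4, 4)  # registered as a submodule as expected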
@@ -98,7 +99,7 @@ class Attention(nn.Module):
     def __init__(
             self,
             dim: int,
-            num_heads: int = 8,
+            num_heads: int = 1,
             qkv_bias: bool = True,
             qk_norm: bool = False,
             attn_drop: float = 0.,
@@ -159,7 +160,7 @@ def __init__(
             conv_bias: bool = False,
             conv_norm_layer: nn.Module = nn.BatchNorm2d,
             conv_act_layer: nn.Module = nn.Identity(),
-            num_heads: int = 8,
+            num_heads: int = 1,
             qkv_bias: bool = True,
             qk_norm: bool = False,
             attn_drop: float = 0.,
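
Dropping the num_heads default from 8 to 1 (here and in the Attention class above) lines the default up with the first stage of the CvT configuration added later in this commit, where heads are passed per stage as (1, 3, 6) for dims (64, 192, 384). A quick arithmetic check of what those tuples imply, assuming the usual head_dim = dim // num_heads split used by transformer attention layers:

    dims = (64, 192, 384)
    num_heads = (1, 3, 6)
    head_dims = [d // h for d, h in zip(dims, num_heads)]
    print(head_dims)  # [64, 64, 64] -> every stage keeps 64-dim heads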
@@ -173,6 +174,7 @@ def __init__(
             mlp_act_layer: nn.Module = QuickGELU,
             use_cls_token: bool = False,
     ) -> None:
+        super().__init__()
         self.use_cls_token = use_cls_token

         self.norm1 = norm_layer(dim)
@@ -242,28 +244,30 @@ def __init__(
             depth: int,
             embed_kernel_size: int = 7,
             embed_stride: int = 4,
-            embed_padding: int 2,
+            embed_padding: int = 2,
             kernel_size: int = 3,
             stride_q: int = 1,
             stride_kv: int = 2,
             padding: int = 1,
             conv_bias: bool = False,
             conv_norm_layer: nn.Module = nn.BatchNorm2d,
             conv_act_layer: nn.Module = nn.Identity(),
-            num_heads: int = 8,
+            num_heads: int = 1,
             qkv_bias: bool = True,
             qk_norm: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             input_norm_layer = LayerNorm2d,
             norm_layer: nn.Module = nn.LayerNorm,
             init_values: Optional[float] = None,
-            drop_path: float = 0.,
+            drop_path_rates: List[float] = [0.],
             mlp_layer: nn.Module = Mlp,
             mlp_ratio: float = 4.,
             mlp_act_layer: nn.Module = QuickGELU,
             use_cls_token: bool = False,
     ) -> None:
+        super().__init__()
+
         self.conv_embed = ConvEmbed(
             in_chs = in_chs,
             out_chs = dim,
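
Replacing the single drop_path float with drop_path_rates: List[float] lets every block in a stage receive its own stochastic-depth rate, consumed below as drop_path_rates[i]. Note that [0.] is a mutable default argument; it is harmless as long as it is only read, though a tuple default would sidestep the usual pitfall. A minimal sketch of how such per-block rates are typically turned into layers, assuming timm's DropPath (the wiring inside CvTBlock itself is outside this diff):

    import torch.nn as nn
    from timm.layers import DropPath

    drop_path_rates = [0.0, 0.05, 0.1]  # one rate per block in the stage
    drop_paths = nn.ModuleList([
        DropPath(p) if p > 0. else nn.Identity()  # identity when the rate is zero
        for p in drop_path_rates
    ])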
@@ -278,31 +282,30 @@ def __init__(

         blocks = []
         for i in range(depth):
-            blocks.append(
-                CvTBlock(
-                    dim = dim,
-                    kernel_size = kernel_size,
-                    stride_q = stride_q,
-                    stride_kv = stride_kv,
-                    padding = padding,
-                    conv_bias = conv_bias,
-                    conv_norm_layer = conv_norm_layer,
-                    conv_act_layer = conv_act_layer,
-                    num_heads = num_heads,
-                    qkv_bias = qkv_bias,
-                    qk_norm = qk_norm,
-                    attn_drop = attn_drop,
-                    proj_drop = proj_drop,
-                    input_norm_layer = input_norm_layer,
-                    norm_layer = norm_layer,
-                    init_values = init_values,
-                    drop_path = drop_path,
-                    mlp_layer = mlp_layer,
-                    mlp_ratio = mlp_ratio,
-                    mlp_act_layer = mlp_act_layer,
-                    use_cls_token = use_cls_token,
-                )
+            block = CvTBlock(
+                dim = dim,
+                kernel_size = kernel_size,
+                stride_q = stride_q,
+                stride_kv = stride_kv,
+                padding = padding,
+                conv_bias = conv_bias,
+                conv_norm_layer = conv_norm_layer,
+                conv_act_layer = conv_act_layer,
+                num_heads = num_heads,
+                qkv_bias = qkv_bias,
+                qk_norm = qk_norm,
+                attn_drop = attn_drop,
+                proj_drop = proj_drop,
+                input_norm_layer = input_norm_layer,
+                norm_layer = norm_layer,
+                init_values = init_values,
+                drop_path = drop_path_rates[i],
+                mlp_layer = mlp_layer,
+                mlp_ratio = mlp_ratio,
+                mlp_act_layer = mlp_act_layer,
+                use_cls_token = use_cls_token,
             )
+            blocks.append(block)
         self.blocks = nn.ModuleList(blocks)

         if self.cls_token is not None:
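
Keeping the blocks in an nn.ModuleList and looping over them explicitly (rather than wrapping them in nn.Sequential) matters because each CvTBlock consumes and produces an (x, cls_token) pair instead of a single tensor, as the forward loop in the next hunk shows. A small sketch of the difference, with TwoInputBlock standing in for CvTBlock purely for illustration:

    import torch
    import torch.nn as nn

    class TwoInputBlock(nn.Module):
        # stand-in for CvTBlock: takes and returns (x, cls_token)
        def forward(self, x, cls_token):
            return x + 1, cls_token

    blocks = nn.ModuleList([TwoInputBlock() for _ in range(3)])
    x, cls_token = torch.zeros(1), torch.zeros(1)
    for block in blocks:  # explicit loop threads both values through every block
        x, cls_token = block(x, cls_token)
    # nn.Sequential(*blocks)(x, cls_token) would fail here: Sequential forwards a
    # single positional value from one module to the next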
@@ -313,10 +316,83 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.embed_drop(x)

         cls_token = self.cls_token
-        for block in self.blocks:
+        for block in self.blocks:  # technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tensor
             x, cls_token = block(x, cls_token)

         return x, cls_token

 class CvT(nn.Module):
-
+    def __init__(
+            self,
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            dims: Tuple[int, ...] = (64, 192, 384),
+            depths: Tuple[int, ...] = (1, 2, 10),
+            embed_kernel_size: Tuple[int, ...] = (7, 3, 3),
+            embed_stride: Tuple[int, ...] = (4, 2, 2),
+            embed_padding: Tuple[int, ...] = (2, 1, 1),
+            kernel_size: int = 3,
+            stride_q: int = 1,
+            stride_kv: int = 2,
+            padding: int = 1,
+            conv_bias: bool = False,
+            conv_norm_layer: nn.Module = nn.BatchNorm2d,
+            conv_act_layer: nn.Module = nn.Identity(),
+            num_heads: Tuple[int, ...] = (1, 3, 6),
+            qkv_bias: bool = True,
+            qk_norm: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            input_norm_layer = LayerNorm2d,
+            norm_layer: nn.Module = nn.LayerNorm,
+            init_values: Optional[float] = None,
+            drop_path_rate: float = 0.,
+            mlp_layer: nn.Module = Mlp,
+            mlp_ratio: float = 4.,
+            mlp_act_layer: nn.Module = QuickGELU,
+            use_cls_token: Tuple[bool, ...] = (False, False, True),
+    ) -> None:
+        super().__init__()
+        num_stages = len(dims)
+        assert num_stages == len(depths) == len(embed_kernel_size) == len(embed_stride)
+        assert num_stages == len(embed_padding) == len(num_heads) == len(use_cls_token)
+        self.num_classes = num_classes
+        self.num_features = dims[-1]
+        self.drop_rate = drop_rate
+
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+
+        in_chs = in_chans
+
+        stages = []
+        for stage_idx in range(num_stages):
+            dim = dims[stage_idx]
+            stage = CvTStage(
+                in_chs = in_chs,
+                dim = dim,
+                depth = depths[stage_idx],
+                embed_kernel_size = embed_kernel_size[stage_idx],
+                embed_stride = embed_stride[stage_idx],
+                embed_padding = embed_padding[stage_idx],
+                kernel_size = kernel_size,
+                stride_q = stride_q,
+                stride_kv = stride_kv,
+                padding = padding,
+                conv_bias = conv_bias,
+                conv_norm_layer = conv_norm_layer,
+                conv_act_layer = conv_act_layer,
+                num_heads = num_heads[stage_idx],
+                qkv_bias = qkv_bias,
+                qk_norm = qk_norm,
+                attn_drop = attn_drop,
+                proj_drop = proj_drop,
+                input_norm_layer = input_norm_layer,
+                norm_layer = norm_layer,
+                init_values = init_values,
+                drop_path_rates = dpr[stage_idx],
+                mlp_layer = mlp_layer,
+                mlp_ratio = mlp_ratio,
+                mlp_act_layer = mlp_act_layer,
+                use_cls_token = use_cls_token[stage_idx],
+            )
+            in_chs = dim
+            stages.append(stage)
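
The dpr line is what feeds the new drop_path_rates argument: the global drop_path_rate is spread linearly over all blocks in the network and then split back into one list per stage. A worked example with the default depths, assuming drop_path_rate = 0.1:

    import torch

    depths = (1, 2, 10)
    drop_path_rate = 0.1
    dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
    # dpr[0] -> [0.0]                      stage 0: its single block gets no drop path
    # dpr[1] -> [~0.0083, ~0.0167]         stage 1: two shallow rates
    # dpr[2] -> [~0.025, ..., 0.1]         stage 2: ten rates climbing to the full 0.1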
