Skip to content

Commit 11cc4a7

Browse files
committed
Update cvt.py
1 parent 0d171c6 commit 11cc4a7

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

timm/models/cvt.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def __init__(
6161
act_layer=act_layer
6262
)
6363

64+
# TODO fuse kv conv?
65+
# TODO if act_layer is id and not cls_token (gap model?), is later projection in attn necessary?
66+
6467
self.conv_k = ConvNormAct(
6568
dim,
6669
dim,
@@ -235,6 +238,8 @@ def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[t
235238
if self.use_cls_token:
236239
cls_token, x = torch.split(x, [1, H*W], 1)
237240

241+
x = x.transpose(1, 2).reshape(B, C, H, W)
242+
238243
return x, cls_token
239244

240245
class CvTStage(nn.Module):
@@ -359,6 +364,9 @@ def __init__(
359364
self.num_features = dims[-1]
360365
self.drop_rate = drop_rate
361366

367+
# FIXME only on last stage, no need for tuple
368+
self.use_cls_token = use_cls_token[-1]
369+
362370
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
363371

364372
in_chs = in_chans
@@ -397,4 +405,19 @@ def __init__(
397405
in_chs = dim
398406
stages.append(stage)
399407
self.stages = nn.ModuleList(stages)
408+
409+
self.head_norm = norm_layer(dims[-1])
410+
self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()
411+
412+
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the input through every stage, then classify.

    The feature fed to the head is the final stage's class token when
    ``self.use_cls_token`` is set, otherwise the spatially global-average-
    pooled feature map (mean over H and W of the NCHW tensor).
    """
    for blk in self.stages:
        x, cls_token = blk(x)
    # Pick the pooled representation: cls token vs. global average pool.
    feats = cls_token if self.use_cls_token else x.mean(dim=(2, 3))
    return self.head(self.head_norm(feats))
422+
400423

0 commit comments

Comments
 (0)