Skip to content

Commit 0d171c6

Browse files
committed
Update cvt.py
1 parent 95b6a52 commit 0d171c6

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

timm/models/cvt.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def fw_attn(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> torch.T
229229
def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
230230
B, C, H, W = x.shape
231231

232-
x = x.flatten(2).transpose(1, 2) + self.drop_path1(self.ls1(self.fw(attn(self.norm1(x)))))
232+
x = x.flatten(2).transpose(1, 2) + self.drop_path1(self.ls1(self.fw_attn(self.norm1(x))))
233233
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
234234

235235
if self.use_cls_token:
@@ -395,4 +395,6 @@ def __init__(
395395
use_cls_token = use_cls_token[stage_idx],
396396
)
397397
in_chs = dim
398-
stages.append(stage)
398+
stages.append(stage)
399+
self.stages = nn.ModuleList(stages)
400+

0 commit comments

Comments
 (0)