@@ -61,6 +61,9 @@ def __init__(
             act_layer=act_layer
         )
 
+        # TODO fuse kv conv?
+        # TODO if act_layer is id and not cls_token (gap model?), is later projection in attn necessary?
+
         self.conv_k = ConvNormAct(
             dim,
             dim,
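One way to read the first TODO: project k and v with a single convolution and split the result channel-wise, instead of running conv_k and conv_v separately. A speculative sketch with plain nn.Conv2d standing in for ConvNormAct (kernel size, padding, norm and activation are placeholders, not the module's actual settings):

    import torch
    import torch.nn as nn

    dim = 64
    # one conv emits 2*dim channels; the first half is treated as k, the second as v
    conv_kv = nn.Conv2d(dim, 2 * dim, kernel_size=3, padding=1)

    x = torch.randn(2, dim, 14, 14)
    k, v = conv_kv(x).chunk(2, dim=1)   # each (2, dim, 14, 14)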
@@ -235,6 +238,8 @@ def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[t
         if self.use_cls_token:
             cls_token, x = torch.split(x, [1, H * W], 1)
 
+        x = x.transpose(1, 2).reshape(B, C, H, W)
+
         return x, cls_token
 
 class CvTStage(nn.Module):
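The new line converts the token sequence back to an NCHW feature map before the stage returns, so the next stage's convolutional embedding (and the final global-average-pool path) can consume it directly. A minimal shape check, assuming the blocks emit tokens as (B, H*W, C) with the cls token already split off:

    import torch

    B, C, H, W = 2, 64, 14, 14
    tokens = torch.randn(B, H * W, C)                    # (B, N, C) token sequence
    fmap = tokens.transpose(1, 2).reshape(B, C, H, W)    # back to (B, C, H, W)
    assert fmap.shape == (B, C, H, W)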
@@ -359,6 +364,9 @@ def __init__(
         self.num_features = dims[-1]
         self.drop_rate = drop_rate
 
+        # FIXME only on last stage, no need for tuple
+        self.use_cls_token = use_cls_token[-1]
+
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
 
         in_chs = in_chans
@@ -397,4 +405,19 @@ def __init__(
             in_chs = dim
             stages.append(stage)
         self.stages = nn.ModuleList(stages)
+
+        self.head_norm = norm_layer(dims[-1])
+        self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+        for stage in self.stages:
+            x, cls_token = stage(x)
+
+
+        if self.use_cls_token:
+            return self.head(self.head_norm(cls_token))
+        else:
+            return self.head(self.head_norm(x.mean(dim=(2, 3))))
+
 
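The added head mirrors the two pooling modes: when the last stage keeps a class token the head runs on it, otherwise it runs on a global average pool of the final feature map. A small sketch with hypothetical sizes, using nn.LayerNorm in place of whatever norm_layer the model was built with:

    import torch
    import torch.nn as nn

    B, C, num_classes = 2, 384, 1000
    head_norm = nn.LayerNorm(C)
    head = nn.Linear(C, num_classes)

    x = torch.randn(B, C, 7, 7)        # final stage feature map
    cls_token = torch.randn(B, 1, C)   # final stage class token

    out_cls = head(head_norm(cls_token))            # (B, 1, num_classes), via the class token
    out_gap = head(head_norm(x.mean(dim=(2, 3))))   # (B, num_classes), via global average pooling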