 Implementation for timm by / Copyright 2024, Fredo Guan
 """
 
+from collections import OrderedDict
 from functools import partial
 from typing import List, Final, Optional, Tuple
 
@@ -51,6 +52,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:  # [B, C, H, W] -> [B, C, H, W]
         x = self.norm(x)
         return x
 
+
+
 class ConvProj(nn.Module):
     def __init__(
         self,
@@ -65,7 +68,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.dim = dim
-
+
+        # FIXME not working, bn layer outputs are incorrect
+        '''
         self.conv_q = ConvNormAct(
             dim,
             dim,
@@ -78,7 +83,7 @@ def __init__(
             act_layer=act_layer
         )
 
-        # TODO fuse kv conv?
+        # TODO fuse kv conv? don't wanna do weight remap
         # TODO if act_layer is id and not cls_token (gap model?), is later projection in attn necessary?
 
         self.conv_k = ConvNormAct(
@@ -104,6 +109,40 @@ def __init__(
             norm_layer=norm_layer,
             act_layer=act_layer
         )
+        '''
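+        # note (not in the original diff): the OrderedDict keys 'conv'/'bn' mirror ConvNormAct's
+        # child module names, so pretrained state_dict keys should stay compatible with the swap.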
+        self.conv_q = nn.Sequential(OrderedDict([
+            ('conv', nn.Conv2d(
+                dim,
+                dim,
+                kernel_size=kernel_size,
+                padding=padding,
+                stride=stride_q,
+                bias=bias,
+                groups=dim
+            )),
+            ('bn', nn.BatchNorm2d(dim)),]))
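+        # conv_k / conv_v below differ from conv_q only in stride: CvT uses stride_kv > 1 to
+        # downsample the key/value maps and reduce attention cost.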
+        self.conv_k = nn.Sequential(OrderedDict([
+            ('conv', nn.Conv2d(
+                dim,
+                dim,
+                kernel_size=kernel_size,
+                padding=padding,
+                stride=stride_kv,
+                bias=bias,
+                groups=dim
+            )),
+            ('bn', nn.BatchNorm2d(dim)),]))
+        self.conv_v = nn.Sequential(OrderedDict([
+            ('conv', nn.Conv2d(
+                dim,
+                dim,
+                kernel_size=kernel_size,
+                padding=padding,
+                stride=stride_kv,
+                bias=bias,
+                groups=dim
+            )),
+            ('bn', nn.BatchNorm2d(dim)),]))
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         B, C, H, W = x.shape
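
For reference, a minimal standalone sketch (not part of the diff) of what the rewritten projection computes. The make_proj helper, the concrete values for dim/kernel_size/padding/stride_q/stride_kv, and the token-flattening step are illustrative assumptions, not code from this commit:

from collections import OrderedDict

import torch
import torch.nn as nn

def make_proj(dim: int, kernel_size: int, padding: int, stride: int, bias: bool = False) -> nn.Sequential:
    # Same layout as the diff's conv_q/conv_k/conv_v: depthwise conv (groups=dim) + BatchNorm.
    return nn.Sequential(OrderedDict([
        ('conv', nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=padding,
                           stride=stride, bias=bias, groups=dim)),
        ('bn', nn.BatchNorm2d(dim)),
    ]))

dim, kernel_size, padding = 64, 3, 1   # illustrative values, not from the diff
stride_q, stride_kv = 1, 2             # CvT-style: keep q resolution, downsample k/v

conv_q = make_proj(dim, kernel_size, padding, stride_q)
conv_k = make_proj(dim, kernel_size, padding, stride_kv)
conv_v = make_proj(dim, kernel_size, padding, stride_kv)

x = torch.randn(2, dim, 14, 14)             # [B, C, H, W]
q, k, v = conv_q(x), conv_k(x), conv_v(x)   # q: [2, 64, 14, 14]; k, v: [2, 64, 7, 7]

# An attention layer consuming these would flatten spatial dims to tokens (assumed step):
q_tokens = q.flatten(2).transpose(1, 2)     # [B, H*W, C]
print(q.shape, k.shape, v.shape, q_tokens.shape)

With stride_kv = 2, k and v carry a quarter as many tokens as q, which is the attention-cost saving the separate stride_q/stride_kv parameters exist to provide.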