Why 'ws 1 for stand attention' in your GroupAttention code?
kejie-cn opened this issue · 5 comments
I noticed that in your implementation of GroupAttention in gvt.py, you have the comment 'ws 1 for stand attention':
```python
class GroupAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ws=1, sr_ratio=1.0):
        """
        ws 1 for stand attention
        """
        super(GroupAttention, self).__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
```
However, I think ws means the window size. If ws=1, then self-attention is only performed within a 1x1 window, which is not standard self-attention.
This is an implementation choice.
We use ws=1 to stand for standard attention. The resolution of the feature map in the last stage is 7x7, so we perform standard self-attention there because it is cheap at that stage.
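To make the convention concrete, here is a minimal self-contained sketch (illustrative names, not the exact gvt.py code) of treating ws=1 as a flag that selects standard global attention, rather than as a literal 1x1 window:

```python
import torch
import torch.nn as nn

class StandardAttention(nn.Module):
    """Plain multi-head self-attention over all H*W tokens (illustrative)."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x, H, W):
        # x: (B, H*W, C); every token attends to every other token.
        out, _ = self.attn(x, x, x)
        return out

class WindowedAttention(nn.Module):
    """Self-attention restricted to non-overlapping ws x ws windows (illustrative)."""
    def __init__(self, dim, num_heads=8, ws=7):
        super().__init__()
        self.ws = ws
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x, H, W):
        B, N, C = x.shape
        ws = self.ws
        # Partition the (H, W) token grid into ws x ws windows.
        x = x.reshape(B, H // ws, ws, W // ws, ws, C)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws, C)
        out, _ = self.attn(x, x, x)
        # Reverse the window partition.
        out = out.reshape(B, H // ws, W // ws, ws, ws, C)
        out = out.permute(0, 1, 3, 2, 4, 5).reshape(B, N, C)
        return out

def make_attention(dim, num_heads, ws):
    # ws == 1 is used as a flag meaning "standard (global) attention",
    # not as a 1x1 window.
    if ws == 1:
        return StandardAttention(dim, num_heads)
    return WindowedAttention(dim, num_heads, ws)

# Example: a 7x7 last-stage feature map, where global attention is cheap.
x = torch.randn(2, 7 * 7, 64)
print(make_attention(64, num_heads=8, ws=1)(x, 7, 7).shape)  # torch.Size([2, 49, 64])
print(make_attention(64, num_heads=8, ws=7)(x, 7, 7).shape)  # torch.Size([2, 49, 64])
```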
But standard self-attention should use a window size equal to the feature size (ws=7 in the last stage).
In the actual implementation, ws=7 does not work in the last stage. Please check the code.
@cxxgtxy In your paper, the last stage uses only GSA. For 224 classification, the last-stage feature size is 7x7, so ws=7 (LSA) and ws=1 (GSA) are equivalent. But for detection or segmentation, the last-stage feature size may not be 7x7, so ws=7 (LSA) and ws=1 (GSA) are not equivalent. Does this mean you use LSA and GSA at the same time in the last stage?
It's a good question. The feature map in the detection and segmentation tasks is indeed larger than 7x7. In the implementation, we use ws=1 (GSA) in the last stage, the same as for classification.
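For reference, a minimal self-contained sketch of global sub-sampled attention (illustrative, not the exact gvt.py code): queries come from the full feature map while keys/values are spatially down-sampled by sr_ratio, so global attention stays affordable even when the last-stage map is larger than 7x7; with sr_ratio=1 (typical for the last stage) it reduces to plain global attention.

```python
import torch
import torch.nn as nn

class GlobalSubsampledAttention(nn.Module):
    """Sketch of GSA: queries attend to a spatially sub-sampled key/value map."""
    def __init__(self, dim, num_heads=8, sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Strided conv shrinks the key/value grid by sr_ratio in each direction.
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape  # N == H * W
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).transpose(1, 2)
        if self.sr_ratio > 1:
            # Down-sample the token grid before computing keys/values.
            x_ = x.transpose(1, 2).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).transpose(1, 2)
            x_ = self.norm(x_)
        else:
            x_ = x
        kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads)
        k, v = kv.permute(2, 0, 3, 1, 4).unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)

# Detection-style last-stage map (e.g. about 25x42 for an 800x1344 input at stride 32):
x = torch.randn(2, 25 * 42, 64)
print(GlobalSubsampledAttention(64, num_heads=8, sr_ratio=1)(x, 25, 42).shape)  # (2, 1050, 64)
```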
Please see the code:

```python
for k in range(len(depths)):
    _block = nn.ModuleList([block_cls(
        dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k],
        qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
        attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
        sr_ratio=sr_ratios[k], ws=1 if i % 2 == 1 else wss[k])
        for i in range(depths[k])])
    self.blocks.append(_block)
    cur += depths[k]
```
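Here `ws=1 if i % 2 == 1 else wss[k]` interleaves the two attention types within each stage: even-indexed blocks use locally-grouped attention with window size `wss[k]`, while odd-indexed blocks use `ws=1`, i.e. global sub-sampled attention, matching the alternating LSA/GSA design described in the paper.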