Visual-Attention-Network/SegNeXt

Why is there no activation function in the attention module?

RicoSuaveGuapo opened this issue · 1 comment

Thanks for your excellent work. I have a quick question about the model structure.

In

class AttentionModule(BaseModule):
    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
        self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        self.conv2_1 = nn.Conv2d(
            dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(
            dim, dim, (21, 1), padding=(10, 0), groups=dim)
        self.conv3 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)

        attn_0 = self.conv0_1(attn)
        attn_0 = self.conv0_2(attn_0)

        attn_1 = self.conv1_1(attn)
        attn_1 = self.conv1_2(attn_1)

        attn_2 = self.conv2_1(attn)
        attn_2 = self.conv2_2(attn_2)

        attn = attn + attn_0 + attn_1 + attn_2
        attn = self.conv3(attn)

        return attn * u

we can see that there are convolutions, an element-wise sum, and an element-wise product. However, there is no activation function between these operations. In other words, without a non-linear activation, these ops could be reduced to a single linear (matrix) operation.
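To make that concrete, here is a minimal sketch (my own, not from the repository) showing that a cascade of bias-free depth-wise convolutions collapses into a single convolution whose kernel is the convolution of the individual kernels. The kernel sizes and the compose helper are illustrative assumptions, chosen to mirror one strip-conv branch above:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins for one depth-wise channel of conv0 (5x5) followed by one
# strip-conv branch (1x7 then 7x1); bias-free, so every op is purely linear.
k0 = torch.randn(1, 1, 5, 5)
k1 = torch.randn(1, 1, 1, 7)
k2 = torch.randn(1, 1, 7, 1)

x = torch.randn(1, 1, 40, 40)

# Apply the three convolutions one after another (valid padding for clarity).
y_seq = F.conv2d(F.conv2d(F.conv2d(x, k0), k1), k2)

def compose(ka, kb):
    # Merge two kernels into one by convolving them (kb is flipped because
    # F.conv2d actually computes cross-correlation).
    pad = (kb.shape[-2] - 1, kb.shape[-1] - 1)
    return F.conv2d(ka, kb.flip(-1, -2), padding=pad)

# One merged 11x11 kernel reproduces the whole cascade in a single conv.
k_merged = compose(compose(k0, k1), k2)
y_merged = F.conv2d(x, k_merged)

print(k_merged.shape)                              # torch.Size([1, 1, 11, 11])
print(torch.allclose(y_seq, y_merged, atol=1e-4))  # True

The parallel branches in AttentionModule are summed, and a sum of convolutions is again a single convolution with the summed (center-aligned) kernels, which is exactly the merge discussed in the reply below.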

I understand that the SpatialAttention module, which wraps the attention module, has a GELU, so non-linearity can be provided by it.

class SpatialAttention(BaseModule):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = AttentionModule(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)
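For context, here is a minimal sketch of how the forward pass presumably strings these layers together, inferred from the __init__ above (the exact code in the repository may differ slightly):

    def forward(self, x):
        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)           # GELU right before the attention block
        x = self.spatial_gating_unit(x)  # AttentionModule: only linear ops + gating product
        x = self.proj_2(x)
        return x + shortcut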

But I cannot figure out the reason for using only linear ops inside the attention module. Is there a good reason for this, or am I simply missing something here?

Good question.

I have asked myself the same question and tried to merge them into a single 21 x 21 kernel. However, I cannot merge them into a 21 x 1 and a 1 x 21 kernel.

Actually, merging them into a 21 x 21 kernel is more expensive than the current version.
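For a rough sense of that cost difference, here is a back-of-the-envelope count (my own, not from the paper) of multiply-accumulates per output pixel for one depth-wise channel, ignoring the 1 x 1 conv3 and biases:

# Depth-wise work per output pixel, per channel:
# the factorized 5x5 + strip convolutions used in AttentionModule ...
factorized = 5 * 5 + (1 * 7 + 7 * 1) + (1 * 11 + 11 * 1) + (1 * 21 + 21 * 1)
# ... versus one dense 21x21 depth-wise kernel.
dense_21x21 = 21 * 21

print(factorized)                # 103
print(dense_21x21)               # 441
print(dense_21x21 / factorized)  # ~4.3x more work per pixel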