zhiyuanyou/SAFECount

Query - Implementation of weighted feature aggregation

Closed this issue · 2 comments

Hi,

I am trying to understand the implementation of Equations 5, 6, and 7 from your paper, and would be thankful for your help.


According to your paper, the support feature fS is flipped before being used to convolve R. However, a 1x1 convolution (line 247 below) is applied to the support feature fS before Equation 5 is computed. Are you using the projected support feature fS in Equation 5?

feats = 0
for idx, value in enumerate(values):
    if self.pool.type == "max":
        value = F.adaptive_max_pool2d(
            value, self.pool.size, return_indices=False
        )
    else:
        value = F.adaptive_avg_pool2d(value, self.pool.size)
    attn = attns[:, idx, :, :].unsqueeze(1)  # [head,1,h,w]
    value = self.in_conv(value)
    value = value.contiguous().view(
        self.head, self.head_dim, h_p, w_p
    )  # [head,c,h,w]
    feat_list = []
    for w, v in zip(attn, value):
        feat = F.conv2d(
            F.pad(w.unsqueeze(0), pad), v.unsqueeze(1).flip(2, 3)
        )  # [1,c,h,w]
        feat_list.append(feat)
    feat = torch.cat(feat_list, dim=0)  # [head,c,h,w]
    feats += feat
assert list(feats.size()) == [self.head, self.head_dim, h_q, w_q]
feats = feats.contiguous().view(1, self.embed_dim, h_q, w_q)  # [1,c,h,w]
feats = self.out_conv(feats)
return feats
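As a side note on the flip in the code above: `F.conv2d` actually computes cross-correlation, so flipping the kernel along its two spatial dimensions (`flip(2, 3)`) is what turns it into a mathematical convolution. A minimal standalone sketch (tensors here are illustrative, not from the repo):

```python
import torch
import torch.nn.functional as F

# F.conv2d computes cross-correlation, not convolution.
# Flipping the kernel spatially (flip(2, 3)) yields true convolution,
# which is why the support feature is flipped before being used as
# the conv2d weight.
x = torch.arange(9.0).view(1, 1, 3, 3)          # input map
k = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])  # 1x1x2x2 kernel

xcorr = F.conv2d(x, k)                 # cross-correlation
true_conv = F.conv2d(x, k.flip(2, 3))  # mathematical convolution
# top-left entries: xcorr -> 0*1+1*2+3*3+4*4 = 27
#                   true_conv -> 0*4+1*3+3*2+4*1 = 13
```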

Additionally, there is another 1x1 convolution (line 261) applied after Equation 6. Are you using the projected fR in Equation 7?

Yes, you are right.

This implementation is modified according to nn.MultiheadAttention().

Since we want to implement this module in a multi-head form, we need to project a feature (C x H x W) into N sub-features (N x C/N x H x W), one per head. We then perform weighted feature aggregation within each head, and finally aggregate the features of all heads together.
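The projection-and-split step described above can be sketched as follows. This is a minimal illustration with hypothetical names and sizes, not the repo's code; the 1x1 convolutions stand in for the `in_conv`/`out_conv` projections, and the per-head aggregation itself is elided:

```python
import torch
import torch.nn as nn

C, N, H, W = 8, 2, 4, 4  # embed dim, number of heads, spatial size
in_conv = nn.Conv2d(C, C, kernel_size=1)   # 1x1 projection before the head split
out_conv = nn.Conv2d(C, C, kernel_size=1)  # 1x1 projection after merging heads

feat = torch.randn(1, C, H, W)
proj = in_conv(feat)                # [1, C, H, W], projected feature
heads = proj.view(N, C // N, H, W)  # [N, C/N, H, W], one sub-feature per head
# ... weighted feature aggregation would happen independently per head here ...
merged = heads.contiguous().view(1, C, H, W)  # concatenate heads back to C channels
out = out_conv(merged)              # [1, C, H, W]
```

This mirrors the reshaping pattern of `nn.MultiheadAttention()`: the channel dimension is split across heads, each head operates on its own C/N-channel slice, and a final projection mixes the heads back together.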

Thanks for your help. It makes perfect sense now.