Shifting attention-calculating windows
Hello, sir. Unfortunately, another question has popped up.
I've followed your shifting code, and it seems to differ from my understanding of the paper.
I understood the window-shifting behavior of the original paper as the black arrow in the image below (self-attention is computed among the elements inside the bold lines). The left red arrow points to the result of patch-wise rolling, and the right red arrow points to the result of rolling the entire feature map.
In my opinion, self-attention should be computed according to the top-right figure; therefore, the boxes in the bottom-right figure should be used (green dotted lines separate the sub-windows), each of which preserves a region of the top-right figure.
Please let me know if I have misunderstood your code or something in the paper. Thanks a lot!
Additionally, this is how I mimicked your code:
import torch
from einops import rearrange

# 4x4 feature map holding the values 1..16
A = torch.Tensor(list(range(1, 17))).view(1, 4, 4)
# divide A into four 2x2 patches, stacked along the last dimension
A_patched = A.view(4, 2, 2).permute(1, 2, 0).view(1, 2, 2, 4)
# roll the patches (patch-wise shift)
A_patched_rolled = torch.roll(A_patched, shifts=(-1, -1), dims=(1, 2))
# rearrange A so the element layout matches the image above
A_rearranged = rearrange(A, 'a (b c) (d e)->a (b d) (c e)', b=2, d=2)
# roll the entire feature map instead of the patches
A_rearranged_rolled = torch.roll(A_rearranged, shifts=(-1, -1), dims=(1, 2))
A_rearranged_rolled2 = torch.roll(A_rearranged, shifts=(1, 1), dims=(1, 2))
where A can be considered as a 4x4 feature map (though the element order does not match the image above), A_patched is the patch-divided version of A, and A_patched_rolled is the patch-wise shifted version of A_patched, following torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2)) in your code. A_rearranged is rearranged to match the image above.
>>> A
tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.],
         [13., 14., 15., 16.]]])
>>> A_patched
tensor([[[[ 1.,  5.,  9., 13.],
          [ 2.,  6., 10., 14.]],

         [[ 3.,  7., 11., 15.],
          [ 4.,  8., 12., 16.]]]])
>>> A_patched_rolled
tensor([[[[ 4.,  8., 12., 16.],
          [ 3.,  7., 11., 15.]],

         [[ 2.,  6., 10., 14.],
          [ 1.,  5.,  9., 13.]]]])
>>> A_rearranged
tensor([[[ 1.,  2.,  5.,  6.],
         [ 3.,  4.,  7.,  8.],
         [ 9., 10., 13., 14.],
         [11., 12., 15., 16.]]])
>>> A_rearranged_rolled
tensor([[[ 4.,  7.,  8.,  3.],
         [10., 13., 14.,  9.],
         [12., 15., 16., 11.],
         [ 2.,  5.,  6.,  1.]]])
>>> A_rearranged_rolled2
tensor([[[16., 11., 12., 15.],
         [ 6.,  1.,  2.,  5.],
         [ 8.,  3.,  4.,  7.],
         [14.,  9., 10., 13.]]])
Thanks again for reporting potential errors in the code.
I agree with you on how the rolling should be computed and that rolling over the patches results in a wrong feature map. I think, however, that the code is already doing this: in WindowAttention we receive as input the output of the PatchMerging module, which is a feature map of shape (B x H x W x D), where B is the batch size, H the height, W the width, and D the hidden dimension. We first compute the cyclic_shift of this feature map (rolling along the H and W dimensions), and only later, in
q, k, v = map(lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d', h=h, w_h=self.window_size, w_w=self.window_size), qkv)
do we create the windows and move to a (B x Heads x Windows x Elements per Window x Head Dimension) layout. The same holds for the reverse shift, since we only shift back after
out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)', h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
which again moves from (B x Heads x Windows x Elements per Window x Head Dimension) back to the (B x H x W x D) layout, so the reverse shift also operates on the full feature map. It is therefore the same as your A_rearranged example, just with an additional feature dimension.
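To make this concrete, here is a minimal sketch, with illustrative values only (window_size = 2 and displacement = 1 to match your 4x4 toy example, the sign of the shift chosen to match your A_rearranged_rolled, and the heads split dropped from the rearrange pattern quoted above). It shows that the cyclic shift is applied to the whole (B x H x W x D) feature map first, and the windows are only carved out afterwards:

import torch
from einops import rearrange

window_size, displacement = 2, 1  # illustrative values for the 4x4 toy example

# (B, H, W, D) feature map with D = 1, holding the values 1..16
x = torch.arange(1., 17.).view(1, 4, 4, 1)

# cyclic shift of the entire feature map along H and W
x_shifted = torch.roll(x, shifts=(-displacement, -displacement), dims=(1, 2))

# window partition: (B, H, W, D) -> (B, num_windows, window_size * window_size, D)
windows = rearrange(x_shifted, 'b (nw_h w_h) (nw_w w_w) d -> b (nw_h nw_w) (w_h w_w) d',
                    w_h=window_size, w_w=window_size)

print(windows[..., 0])
# tensor([[[ 6.,  7., 10., 11.],
#          [ 8.,  5., 12.,  9.],
#          [14., 15.,  2.,  3.],
#          [16., 13.,  4.,  1.]]])

Each row above is one window, carved out of the already-shifted feature map, which is exactly the behavior your rolling of the full map (rather than of the individual patches) illustrates.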
Thanks for the perfect reply. I really should study harder.
Thank you! XD
P.S. I think I was confused between the initial patch partitioning (4x4) and the attention windowing (7x7). Sorry for bothering you.