berniwal/swin-transformer-pytorch

Shifting attention-calculating windows


Hello, sir. A question popped up again, unfortunately.

I've followed your shifting code, and it seems to differ from (my understanding of) the paper.
I understood the window shifting in the original paper as the black arrow in the image below (self-attention is computed among the elements inside the bold lines). The left red arrow points to the result of rolling patch-wise, and the right red arrow points to the result of rolling the entire feature map.

In my opinion, self-attention should be computed as in the top-right figure. Therefore the boxes of the bottom-right figure (where the green dotted lines separate the sub-windows) should be used, since each of their regions preserves the corresponding region of the top-right figure.

Please let me know if I misunderstood your code or something in the paper. Thanks a lot!

Additionally, this is how I mimicked your code:

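```python
import torch
from einops import rearrange

# 4x4 feature map holding the values 1..16, shape (1, 4, 4)
A = torch.Tensor(list(range(1, 17))).view(1, 4, 4)
# Split A into a 2x2 grid with four values per cell, shape (1, 2, 2, 4)
A_patched = A.view(4, 2, 2).permute(1, 2, 0).view(1, 2, 2, 4)
# Roll over the 2x2 grid (patch-wise), as torch.roll(x, shifts=(self.displacement, ...)) does
A_patched_rolled = torch.roll(A_patched, shifts=(-1, -1), dims=(1, 2))
# Reorder A so that each row holds one 2x2 block of A, matching the image
A_rearranged = rearrange(A, 'a (b c) (d e)->a (b d) (c e)', b=2, d=2)
# Roll the entire feature map by -1 and by +1 along both spatial dimensions
A_rearranged_rolled = torch.roll(A_rearranged, shifts=(-1, -1), dims=(1, 2))
A_rearranged_rolled2 = torch.roll(A_rearranged, shifts=(1, 1), dims=(1, 2))
```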

Here A can be considered a 4x4 feature map (although its element order does not match the image above), A_patched is a divided version of A, and A_patched_rolled is the patch-wise shifted version of A_patched, following torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2)) in your code. A_rearranged is A rearranged to match the image above.

[Images: A_patched and A_patched_rolled]

```
>>> A
tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.],
         [13., 14., 15., 16.]]])
>>> A_patched
tensor([[[[ 1.,  5.,  9., 13.],
          [ 2.,  6., 10., 14.]],

         [[ 3.,  7., 11., 15.],
          [ 4.,  8., 12., 16.]]]])
>>> A_patched_rolled
tensor([[[[ 4.,  8., 12., 16.],
          [ 3.,  7., 11., 15.]],

         [[ 2.,  6., 10., 14.],
          [ 1.,  5.,  9., 13.]]]])
>>> A_rearranged
tensor([[[ 1.,  2.,  5.,  6.],
         [ 3.,  4.,  7.,  8.],
         [ 9., 10., 13., 14.],
         [11., 12., 15., 16.]]])
>>> A_rearranged_rolled
tensor([[[ 4.,  7.,  8.,  3.],
         [10., 13., 14.,  9.],
         [12., 15., 16., 11.],
         [ 2.,  5.,  6.,  1.]]])
>>> A_rearranged_rolled2
tensor([[[16., 11., 12., 15.],
         [ 6.,  1.,  2.,  5.],
         [ 8.,  3.,  4.,  7.],
         [14.,  9., 10., 13.]]])
```

Thanks again for reporting potential errors in the code.

I agree with you on how the rolling should be computed, and that rolling over the patches results in a wrong feature map. However, I think the code already does this: WindowAttention receives as input the output of the PatchMerging module, a feature map of shape (B x H x W x D), where B is the batch size, H the height, W the width, and D the hidden dimension. We first compute the cyclic_shift of this feature map (rolling along the H and W dimensions), and only later, in

```python
q, k, v = map(lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d', h=h, w_h=self.window_size, w_w=self.window_size), qkv)
```

do we create the windows and move to a (B x Heads x Windows x Elements in Window x Head Dimension) layout. The same holds for the reverse shift, since we only shift back after:

```python
out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)', h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
```

which moves back from (B x Heads x Windows x Elements in Window x Head Dimension) to (B x H x W x D), so the back shift again operates on the full feature map. It is therefore the same as your A_rearranged example, but with an additional feature dimension.

Thanks for the perfect reply. I really should study harder.

Thank you! XD

P.S. I think I confused the patch partition (4x4 initially) with the attention windows (7x7). Sorry for bothering you.