DerrickXuNu/CoBEVT

Clarification Needed on Fused Axial Attention in FAX module

Closed this issue · 2 comments

In the local-global attention block of the CrossViewSwapAttention class, I noticed that there are two rearrange operations applied to the key tensor. From my understanding, these two operations seem to cancel each other out, since they appear to reshape the key tensor first into a global feature map and then back into the original window-partitioned shape. Could you help explain the purpose of these operations? Why does the key tensor need to be reshaped twice in this way?

    # local-to-local cross-attention
    query = rearrange(query, 'b n d (x w1) (y w2) -> b n x y w1 w2 d',
                      w1=self.q_win_size[0], w2=self.q_win_size[1])  # window partition
    key = rearrange(key, 'b n d (x w1) (y w2) -> b n x y w1 w2 d',
                      w1=self.feat_win_size[0], w2=self.feat_win_size[1])  # window partition
    val = rearrange(val, 'b n d (x w1) (y w2) -> b n x y w1 w2 d',
                      w1=self.feat_win_size[0], w2=self.feat_win_size[1])  # window partition
    query = rearrange(self.cross_win_attend_1(query, key, val,
                                            skip=rearrange(x,
                                                        'b d (x w1) (y w2) -> b x y w1 w2 d',
                                                         w1=self.q_win_size[0], w2=self.q_win_size[1]) if self.skip else None),
                   'b x y w1 w2 d -> b (x w1) (y w2) d')    # reverse window to feature (restore the original shape)

    query = query + self.mlp_1(self.prenorm_1(query))

    x_skip = query
    query = repeat(query, 'b x y d -> b n x y d', n=n)              # b n x y d

    # local-to-global cross-attention
    query = rearrange(query, 'b n (x w1) (y w2) d -> b n x y w1 w2 d',
                      w1=self.q_win_size[0], w2=self.q_win_size[1])  # window partition
    # TODO: aren't these two rearranges just canceling each other out?
    key = rearrange(key, 'b n x y w1 w2 d -> b n (x w1) (y w2) d')  # reverse window to feature
    key = rearrange(key, 'b n (w1 x) (w2 y) d -> b n x y w1 w2 d',
                    w1=self.feat_win_size[0], w2=self.feat_win_size[1])  # grid partition
    val = rearrange(val, 'b n x y w1 w2 d -> b n (x w1) (y w2) d')  # reverse window to feature
    val = rearrange(val, 'b n (w1 x) (w2 y) d -> b n x y w1 w2 d',
                    w1=self.feat_win_size[0], w2=self.feat_win_size[1])  # grid partition
    query = rearrange(self.cross_win_attend_2(query,
                                              key,
                                              val,
                                              skip=rearrange(x_skip,
                                                        'b (x w1) (y w2) d -> b x y w1 w2 d',
                                                        w1=self.q_win_size[0],
                                                        w2=self.q_win_size[1])
                                              if self.skip else None),
                   'b x y w1 w2 d  -> b (x w1) (y w2) d')  # reverse grid to feature

It is different for sure. In FAX-SA:

        # x: b l c h w
        # mask: b h w 1 l
        # window attention -> grid attention
        mask_swap = mask

        # mask b h w 1 l -> b x y w1 w2 1 L
        mask_swap = rearrange(mask_swap,
                              'b (x w1) (y w2) e l -> b x y w1 w2 e l',
                              w1=self.window_size, w2=self.window_size)
        x = rearrange(x, 'b m d (x w1) (y w2) -> b m x y w1 w2 d',
                      w1=self.window_size, w2=self.window_size)
        x = self.window_attention(x, mask=mask_swap)
        x = self.window_ffd(x)
        x = rearrange(x, 'b m x y w1 w2 d -> b m d (x w1) (y w2)')

        # grid attention
        mask_swap = mask
        mask_swap = rearrange(mask_swap,
                              'b (w1 x) (w2 y) e l -> b x y w1 w2 e l',
                              w1=self.window_size, w2=self.window_size)
        x = rearrange(x, 'b m d (w1 x) (w2 y) -> b m x y w1 w2 d',
                      w1=self.window_size, w2=self.window_size)
        x = self.grid_attention(x, mask=mask_swap)
        x = self.grid_ffd(x)
        x = rearrange(x, 'b m x y w1 w2 d -> b m d (w1 x) (w2 y)')

The feature is first grouped as 'b m d (x w1) (y w2) -> b m x y w1 w2 d', which performs attention within each local window. It is then grouped as 'b m d (w1 x) (w2 y) -> b m x y w1 w2 d', which performs attention globally, because w1 and w2 now index positions spread across the whole feature map (a grid spanning all windows) rather than a contiguous local window.
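
A minimal sketch (not repo code) of the difference between the two groupings, assuming a toy 4x4 single-view map, window size 2, and only torch + einops. Here w1 and w2 are merged into one axis purely to make the printout easier to read; the repo keeps them separate.

    import torch
    from einops import rearrange

    h = w = 4
    win = 2
    # Label every spatial position with a unique id so we can see which ids
    # end up on the attention axes (w1, w2) under each grouping.
    x = torch.arange(h * w).reshape(1, 1, 1, h, w)  # b m d h w

    # Window partition: '(x w1) (y w2)' puts contiguous 2x2 patches on (w1, w2),
    # so attention over (w1, w2) is local.
    windows = rearrange(x, 'b m d (x w1) (y w2) -> b m x y (w1 w2) d', w1=win, w2=win)
    print(windows[0, 0, 0, 0, :, 0])  # tensor([0, 1, 4, 5]) -> one contiguous window

    # Grid partition: '(w1 x) (w2 y)' puts positions strided across the whole map
    # on (w1, w2), so attention over (w1, w2) is global (dilated).
    grid = rearrange(x, 'b m d (w1 x) (w2 y) -> b m x y (w1 w2) d', w1=win, w2=win)
    print(grid[0, 0, 0, 0, :, 0])  # tensor([0, 2, 8, 10]) -> one element from each window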

Similarly, in FAX-CA, you can see the key and value use 'b n d (x w1) (y w2) -> b n x y w1 w2 d' for local attention, while later they become 'b n (w1 x) (w2 y) d -> b n x y w1 w2 d' for grid (global) attention. Different from FAX-SA, in FAX-CA the query always uses the local window partition, which gives the best performance, but you can change it to grid attention as well.
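
To answer the original question directly, here is a quick check (again only a sketch with toy shapes, and feat_win standing in for self.feat_win_size) that the two key rearranges are not inverses: flattening the window partition and re-partitioning with '(w1 x) (w2 y)' produces a different grouping, so the operations do not cancel.

    import torch
    from einops import rearrange

    b, n, d, H, W = 1, 1, 1, 8, 8
    feat_win = (2, 2)  # stand-in for self.feat_win_size

    key0 = torch.arange(H * W).reshape(b, n, d, H, W)

    # First rearrange in the block: local window partition of the key.
    key_win = rearrange(key0, 'b n d (x w1) (y w2) -> b n x y w1 w2 d',
                        w1=feat_win[0], w2=feat_win[1])

    # The two rearranges questioned in the issue:
    key_flat = rearrange(key_win, 'b n x y w1 w2 d -> b n (x w1) (y w2) d')
    key_grid = rearrange(key_flat, 'b n (w1 x) (w2 y) d -> b n x y w1 w2 d',
                         w1=feat_win[0], w2=feat_win[1])

    # If the two operations cancelled, key_grid would equal key_win. It does not:
    print(torch.equal(key_grid, key_win))  # False
    print(key_win[0, 0, 0, 0, :, :, 0])    # local window:  [[0, 1], [8, 9]]
    print(key_grid[0, 0, 0, 0, :, :, 0])   # global grid:   [[0, 4], [32, 36]]

In other words, the pair of rearranges switches the key/value from a local-window grouping to a global-grid grouping, so that cross_win_attend_2 attends globally while the query stays window-partitioned.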