Clarification Needed on Fused Axial Attention in FAX module
In the local-to-global attention block of the CrossViewSwapAttention class, I noticed two consecutive rearrange operations applied to the key tensor (see the snippet below). From my understanding, these two operations seem to cancel each other out: they appear to reshape the key tensor first back into the full feature map and then into the original window-partitioned shape again. Could you explain the purpose of these operations? Why does the key tensor need to be reshaped twice in this way?
# local-to-local cross-attention
query = rearrange(query, 'b n d (x w1) (y w2) -> b n x y w1 w2 d',
w1=self.q_win_size[0], w2=self.q_win_size[1]) # window partition
key = rearrange(key, 'b n d (x w1) (y w2) -> b n x y w1 w2 d',
w1=self.feat_win_size[0], w2=self.feat_win_size[1]) # window partition
val = rearrange(val, 'b n d (x w1) (y w2) -> b n x y w1 w2 d',
w1=self.feat_win_size[0], w2=self.feat_win_size[1]) # window partition
query = rearrange(self.cross_win_attend_1(query, key, val,
skip=rearrange(x,
'b d (x w1) (y w2) -> b x y w1 w2 d',
w1=self.q_win_size[0], w2=self.q_win_size[1]) if self.skip else None),
'b x y w1 w2 d -> b (x w1) (y w2) d') # reverse window to feature (restore the original shape)
query = query + self.mlp_1(self.prenorm_1(query))
x_skip = query
query = repeat(query, 'b x y d -> b n x y d', n=n) # b n x y d
# local-to-global cross-attention
query = rearrange(query, 'b n (x w1) (y w2) d -> b n x y w1 w2 d',
w1=self.q_win_size[0], w2=self.q_win_size[1]) # window partition
# Todo: don't these two operations just cancel each other out?
key = rearrange(key, 'b n x y w1 w2 d -> b n (x w1) (y w2) d') # reverse window to feature
key = rearrange(key, 'b n (w1 x) (w2 y) d -> b n x y w1 w2 d',
w1=self.feat_win_size[0], w2=self.feat_win_size[1]) # grid partition
val = rearrange(val, 'b n x y w1 w2 d -> b n (x w1) (y w2) d') # reverse window to feature
val = rearrange(val, 'b n (w1 x) (w2 y) d -> b n x y w1 w2 d',
w1=self.feat_win_size[0], w2=self.feat_win_size[1]) # grid partition
query = rearrange(self.cross_win_attend_2(query,
key,
val,
skip=rearrange(x_skip,
'b (x w1) (y w2) d -> b x y w1 w2 d',
w1=self.q_win_size[0],
w2=self.q_win_size[1])
if self.skip else None),
'b x y w1 w2 d -> b (x w1) (y w2) d') # reverse window partition back to feature
They are different for sure. In FAX-SA:
# x: b l c h w
# mask: b h w 1 l
# window attention -> grid attention
mask_swap = mask
# mask: b h w 1 l -> b x y w1 w2 1 l
mask_swap = rearrange(mask_swap,
'b (x w1) (y w2) e l -> b x y w1 w2 e l',
w1=self.window_size, w2=self.window_size)
x = rearrange(x, 'b m d (x w1) (y w2) -> b m x y w1 w2 d',
w1=self.window_size, w2=self.window_size)
x = self.window_attention(x, mask=mask_swap)
x = self.window_ffd(x)
x = rearrange(x, 'b m x y w1 w2 d -> b m d (x w1) (y w2)')
# grid attention
mask_swap = mask
mask_swap = rearrange(mask_swap,
'b (w1 x) (w2 y) e l -> b x y w1 w2 e l',
w1=self.window_size, w2=self.window_size)
x = rearrange(x, 'b m d (w1 x) (w2 y) -> b m x y w1 w2 d',
w1=self.window_size, w2=self.window_size)
x = self.grid_attention(x, mask=mask_swap)
x = self.grid_ffd(x)
x = rearrange(x, 'b m x y w1 w2 d -> b m d (w1 x) (w2 y)')
The feature map is first grouped as 'b m d (x w1) (y w2) -> b m x y w1 w2 d', which performs attention within each local window. It is then grouped as 'b m d (w1 x) (w2 y) -> b m x y w1 w2 d', which performs attention globally: with the (w1 x) pattern, w1 and w2 stride across the whole feature map, so each attention group samples one position from every window (grid attention). See the sketch below.
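A minimal sketch of the two groupings on a toy 6x6 map with window size 2 (names and shapes here are illustrative, not from the repo):
import torch
from einops import rearrange

H = W = 6
win = 2                                            # toy window_size
pos = torch.arange(H * W).reshape(1, 1, 1, H, W)   # b m d h w, values = flat position

# window partition: each (x, y) group is a contiguous 2x2 patch
local = rearrange(pos, 'b m d (x w1) (y w2) -> b m x y (w1 w2) d', w1=win, w2=win)
print(local[0, 0, 0, 0, :, 0])    # tensor([0, 1, 6, 7])   -> neighboring pixels

# grid partition: each (x, y) group strides across the whole map
glob = rearrange(pos, 'b m d (w1 x) (w2 y) -> b m x y (w1 w2) d', w1=win, w2=win)
print(glob[0, 0, 0, 0, :, 0])     # tensor([0, 3, 18, 21]) -> spread over the map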
Similarly, in FAX-CA you can see that the key and value are rearranged as 'b n d (x w1) (y w2) -> b n x y w1 w2 d' for local attention, while later they become 'b n (w1 x) (w2 y) d -> b n x y w1 w2 d' for grid attention. So the two rearranges on the key do not cancel out: the first undoes the window partition, and the second applies the different, strided grid pattern ((w1 x) instead of (x w1)). Unlike FAX-SA, in FAX-CA the query always stays window-partitioned, which gives the best performance, but you can switch it to grid attention as well. A toy trace below makes this concrete.
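A minimal sketch of that key trace, again on a toy 6x6 map (win stands in for feat_win_size; all names and shapes are illustrative, not from the repo):
import torch
from einops import rearrange

H = W = 6
win = 2                                          # stands in for feat_win_size
key = torch.arange(H * W).reshape(1, 1, H, W, 1) # b n h w d, values = flat position

# the window partition from the local-to-local branch
key = rearrange(key, 'b n (x w1) (y w2) d -> b n x y w1 w2 d', w1=win, w2=win)
print(key[0, 0, 0, 0, :, :, 0].flatten())   # tensor([0, 1, 6, 7])   local patch

# the two ops in question: undo the window partition...
key = rearrange(key, 'b n x y w1 w2 d -> b n (x w1) (y w2) d')
# ...then apply the *different* grid pattern (w1 x), not (x w1)
key = rearrange(key, 'b n (w1 x) (w2 y) d -> b n x y w1 w2 d', w1=win, w2=win)
print(key[0, 0, 0, 0, :, :, 0].flatten())   # tensor([0, 3, 18, 21]) strided group
The group at (x=0, y=0) changes from a contiguous patch to positions spread across the map, so the pair of rearranges is clearly not the identity.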