lucidrains/vit-pytorch

Layernorm in Cross attention

turtleman99 opened this issue · 4 comments

I'm wondering why we don't need layernorm for K and V, but need it for Q in cross attention. Is there any paper I can refer to? Thanks a lot.

def forward(self, x, context = None, kv_include_self = False):
    b, n, _, h = *x.shape, self.heads
    x = self.norm(x)                 # layernorm is applied to the query input only
    context = default(context, x)    # falls back to self attention if no context is given

    if kv_include_self:
        context = torch.cat((x, context), dim = 1) # cross attention requires CLS token includes itself as key / value

    qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

    dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

    attn = self.attend(dots)
    attn = self.dropout(attn)

    out = einsum('b h i j, b h j d -> b h i d', attn, v)
    out = rearrange(out, 'b h n d -> b n (h d)')
    return self.to_out(out)
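
For reference, this forward() appears to be from the Attention module in cross_vit.py, and the residual connection is added by the caller, outside the module. A rough usage sketch (simplified for illustration; the constructor arguments, shapes, and variable names here are assumptions, not code from the repo):

import torch

# `Attention` is the class whose forward() is quoted above
attn = Attention(dim = 192, heads = 3, dim_head = 64)

cls_token    = torch.randn(1, 1, 192)    # the query stream being updated
patch_tokens = torch.randn(1, 196, 192)  # context tokens from the other branch

# pre-layernorm: the norm is applied to the query input inside forward(),
# while the residual connection is added outside
cls_token = attn(cls_token, context = patch_tokens, kv_include_self = True) + cls_token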

It has been a while, but can you check whether the context being received is coming from another transformer with a final layernorm?

Oh, this is a good point. My x and context are actually features from pre-trained ViT models. I believe I can remove the layernorm for x and context, right?

@turtleman99 if you are following the same scheme as CrossViT, then the code is correct and stays the same: x is layernormed (pre-layernorm configuration + residual), while context is left alone since it is already layernormed by your pretrained ViT model. You would only need to layernorm the context if the context also cross attends to x and is updated, as in the ISAB architecture, but that isn't what the CrossViT authors did.
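
To make the two cases above concrete, here is a minimal sketch using nn.MultiheadAttention as a stand-in for the repo's Attention module (the dimensions, names, and the reverse-direction layer are illustrative assumptions, not vit-pytorch code):

import torch
from torch import nn

dim = 128
x       = torch.randn(2, 1, dim)    # token(s) being updated, e.g. a CLS token
context = torch.randn(2, 196, dim)  # features from a pretrained ViT whose last op is a LayerNorm

# CrossViT scheme: pre-layernorm on the query branch only, residual on x.
# context is used as-is for keys / values because it is already normalized upstream.
norm_x = nn.LayerNorm(dim)
x_to_c = nn.MultiheadAttention(dim, num_heads = 8, batch_first = True)

out, _ = x_to_c(norm_x(x), context, context)
x = x + out

# ISAB-like variant (not what CrossViT does): the context is also updated by cross
# attending back to x, so its residual output is no longer normalized, and it then
# needs its own pre-layernorm before being consumed again as keys / values.
norm_c = nn.LayerNorm(dim)
c_to_x = nn.MultiheadAttention(dim, num_heads = 8, batch_first = True)

back, _ = c_to_x(norm_c(context), norm_x(x), norm_x(x))
context = context + back

out, _ = x_to_c(norm_x(x), norm_c(context), norm_c(context))  # only now does the context need its own norm
x = x + out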

Gotcha. Thank you so much for your explanations! :)