Problem in position embedding
jmercat opened this issue · 8 comments
Line 129 in 619a8b3
It seems to me that the rotary position embedding is being applied along the head dimension (dim -2) of the q and k tensors instead of the sequence dimension (dim 1).
I think the head and sequence dimensions should be swapped before calling the position embedding.
(see https://github.com/facebookresearch/xformers/blob/748c159096d4f9fcfe3eaf22801e5aed4777210b/xformers/components/positional_embedding/rotary.py#L85)
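To make the layout issue concrete, here is a minimal sketch (the shapes are made up, and I'm assuming q/k are laid out as (batch, seq_len, n_heads, head_dim) as in this code base):

```python
import torch

batch, seq_len, n_heads, head_dim = 2, 16, 8, 64
q = torch.randn(batch, seq_len, n_heads, head_dim)

# xformers' RotaryEmbedding rotates along dim -2, so with this layout it would
# treat the 8 heads as positions instead of the 16 tokens at dim 1.
# Permuting first puts the sequence where the rotation expects it:
q_for_rope = q.permute(0, 2, 1, 3)  # (batch, n_heads, seq_len, head_dim)
```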
What I'm proposing is simply to re-write RotaryWithCast as follows:
```python
# Assuming the base class is the RotaryEmbedding from the xformers module linked above:
from xformers.components.positional_embedding.rotary import RotaryEmbedding

class RotaryWithCast(RotaryEmbedding):
    def forward(self, q, k, v):
        # Move the sequence dimension to dim -2, where xformers applies the rotation.
        q, k = super().forward(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3))
        # Restore the original (batch, seq_len, n_heads, head_dim) layout.
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        return q.to(v.dtype), k.to(v.dtype), v
```
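For reference, a quick shape/dtype check of this version could look like the following (the dim_model=64 value and all shapes are just illustrative, and I'm assuming the class above and torch are in scope):

```python
import torch

rope = RotaryWithCast(dim_model=64)  # dim_model should match head_dim
q = torch.randn(2, 16, 8, 64)        # (batch, seq_len, n_heads, head_dim)
k = torch.randn(2, 16, 8, 64)
v = torch.randn(2, 16, 8, 64, dtype=torch.bfloat16)

q_out, k_out, v_out = rope(q, k, v)
assert q_out.shape == (2, 16, 8, 64)          # original layout preserved
assert q_out.dtype == k_out.dtype == v.dtype  # q/k are cast to v's dtype
```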
Good catch! The blow-up curves you are seeing are similar to the ones we were seeing before we introduced qk-norm for the smaller models. Will do some testing with this fix on my end as well. Would you like to open a PR?
Wow, amazing catch! We really appreciate this.
We've added your name to the README because this is a very substantial bug catch. It's interesting that our first 1B/7B runs do pretty well even without proper positional embeddings, but we should fix this going forward.
Great code base by the way. It's a pleasure to read.
Thanks for proposing to include me. I could open a PR, but it's probably simpler for you to just include what I wrote (or a better version... I haven't tested whether calling contiguous would make a difference).
Looking into a way to implement this directly with the xformers API. Thanks so much @jmercat!
Actually, moving that line before the call to view would be enough.
Line 129 in 9b3ca53
The problem actually seems to be upstream in xformers. Opened an issue here: facebookresearch/xformers#841