RotaryEmbedding applied to the incorrect channel dimension
sagadre opened this issue · 3 comments
🐛 Bug
Input tensors to attention must be in format [B, M, H, K], where B is the batch size, M the sequence length, H the number of heads, and K the embedding size per head, as documented here.
Hence positional embeddings (e.g., rotary embeddings) should be applied to dim=1. However, in the RotaryEmbedding class, dim=-2 is being passed, which corresponds to dim=2, as seen here:
```python
def forward(
    self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
        k, seq_dimension=-2  # should be seq_dimension=1, or no argument at all, since the default value is correct
    )

    return (
        apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
        apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
    )
```
Additional context
Thanks to @jmercat who found symptoms of this problem downstream of xformers!
Hi @sagadre
Here is my understanding.
Before splitting your tensor into H heads, the shape of the tensor is [B, M, D], where B is the batch size, M is the sequence length, and D is the embedding dim.
After the split into heads, the shape is [B, M, H, K], where B is the batch size, M the sequence length, H the number of heads, and K the embedding size per head.
If you use RotaryEmbedding before the split into heads, seq_dimension=-2 corresponds to seq_dimension=1, which is M. This is actually correct.
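A quick sanity check of that reading (plain PyTorch, not xformers code; sizes are made up): on a pre-split 3-D tensor, -2 and 1 name the same axis.

```python
import torch

B, M, D = 2, 16, 32               # batch, sequence, model dim (illustrative values)
x = torch.randn(B, M, D)          # pre-split layout [B, M, D]

# On a 3-D tensor, dim=-2 and dim=1 are the same axis: the sequence length M.
assert x.shape[-2] == x.shape[1] == M
```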
Are you applying rotary embedding before or after splitting the input into H heads?
Hi again @sagadre
It looks like you are right!
I looked at MultiHeadDispatch in xformers, which relies on RotaryEmbedding, and indeed it is used after the split into H heads.
See
Hi again @sagadre
If you look at the unit test for rotary embedding, the input shape is (BATCH, HEADS, SEQ, EMB) and not (BATCH, SEQ, HEADS, EMB):
xformers/tests/test_rotary_embeddings.py
Line 61 in 748c159
So there is a transpose here for rotary embedding:
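For completeness, a hedged sketch of what such a caller-side transpose could look like. This is not the snippet referenced above: the import path and the RotaryEmbedding(dim_model=...) constructor are assumptions based on the xformers source at the time, and the tensor sizes are made up.

```python
import torch
from xformers.components.positional_embedding.rotary import RotaryEmbedding  # assumed import path

B, M, H, K = 2, 16, 4, 8
q = torch.randn(B, M, H, K)        # documented attention layout [B, M, H, K]
k = torch.randn(B, M, H, K)

rotary = RotaryEmbedding(dim_model=K)  # assumed constructor signature

# Move heads in front of the sequence so that seq_dimension=-2 really is the
# sequence axis, i.e. (BATCH, SEQ, HEADS, EMB) -> (BATCH, HEADS, SEQ, EMB).
q_rot, k_rot = rotary(q.transpose(1, 2), k.transpose(1, 2))

# Transpose back to the documented (BATCH, SEQ, HEADS, EMB) layout.
q_rot, k_rot = q_rot.transpose(1, 2), k_rot.transpose(1, 2)

assert q_rot.shape == (B, M, H, K)
```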