facebookresearch/xformers

RotaryEmbedding applied to the incorrect channel dimension

sagadre opened this issue · 3 comments

๐Ÿ› Bug

Input tensors to attention must be in the format [B, M, H, K], where B is the batch size, M the sequence length, H the number of heads, and K the embedding size per head, as documented here.

Hence, positional embeddings (e.g., rotary embeddings) should be applied along dim=1. However, in the RotaryEmbedding class, seq_dimension=-2 is being passed, which corresponds to dim=2, as seen here.

def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
            k, seq_dimension=-2  # should be seq_dimension=1, or the argument should be omitted since the default is already correct
        )

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
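For illustration, a minimal shape check (a sketch with made-up sizes, not xformers internals) of why dim=-2 does not point at the sequence axis for the documented [B, M, H, K] layout:

import torch

# Hypothetical sizes, just to inspect which axis each index selects.
B, M, H, K = 2, 128, 4, 8            # batch, seq len, heads, head dim
x = torch.randn(B, M, H, K)

print(x.shape[1])    # 128 -> M, the sequence length (dim=1)
print(x.shape[-2])   # 4   -> H, the number of heads (dim=-2)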

Additional context

Thanks to @jmercat, who found symptoms of this problem downstream of xformers!

Hi @sagadre

Here is my understanding.

Before splitting your tensor into H heads, the shape of the tensor is [B, M, D], where B is batch size, M is sequence length, D is embedding dim.
After the split into heads, the shape is [B, M, H, K], where B is the batch size, M the sequence length, H the number of heads, and K the embedding size per head.

If you use RotaryEmbedding before the split into heads, seq_dimension=-2 corresponds to seq_dimension=1, which is M. This is actually correct in that case.
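A quick shape check of that pre-split case (a sketch, assuming the [B, M, D] layout described above):

import torch

B, M, D = 2, 128, 32                 # batch, seq len, model dim (before head split)
x = torch.randn(B, M, D)

# For a 3D tensor, dim=-2 and dim=1 are the same axis: the sequence length M.
assert x.shape[-2] == x.shape[1] == M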

Are you applying rotary embedding before or after splitting the input into H heads?

Hi again @sagadre

It looks like you are right!

I looked at MultiHeadDispatch in xformers, which relies on RotaryEmbedding, and indeed it is used after the split into H heads.

See

q, k = self.rotary_embeddings(q=q, k=k)
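Assuming _update_cos_sin_tables sizes its cached tables from k.shape[seq_dimension] (an assumption about the internals, suggested by the method name), at this call site the cache would be built for the head count rather than the sequence length:

import torch

B, M, H, K = 2, 128, 8, 64           # batch, seq len, heads, head dim
k = torch.randn(B, M, H, K)          # layout after MultiHeadDispatch splits heads

print(k.shape[-2])   # 8   -> H, what seq_dimension=-2 reads (assumed cache length)
print(k.shape[1])    # 128 -> M, the length the cos/sin tables should actually cover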

Hi again @sagadre

If you look at the unit test for rotary embedding, the input shape is (BATCH, HEADS, SEQ, EMB) and not (BATCH, SEQ, HEADS, EMB):

(BATCH, HEADS, SEQ, EMB), device=device, dtype=dtype

So there is a transpose here for rotary embedding:

return t.view(B, S, H, Hs).transpose(1, 2)
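A short shape check of that test layout (a sketch with made-up sizes): after the transpose, the sequence axis lands at dim=-2, which is why seq_dimension=-2 matches the unit test's (BATCH, HEADS, SEQ, EMB) convention but not the [B, M, H, K] convention used in MultiHeadDispatch.

import torch

B, S, H, Hs = 2, 128, 4, 8           # batch, seq len, heads, head dim
t = torch.randn(B, S, H * Hs)

x = t.view(B, S, H, Hs).transpose(1, 2)   # (B, S, H, Hs) -> (B, H, S, Hs)
assert x.shape == (B, H, S, Hs)
assert x.shape[-2] == S                   # seq_dimension=-2 now selects SEQ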