RotaryEmbedding applied to the incorrect channel dimension
sagadre opened this issue
🐛 Bug
Input tensors to attention must be in the format `[B, M, H, K]`, where `B` is the batch size, `M` the sequence length, `H` the number of heads, and `K` the embedding size per head, as documented here.
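To make the index arithmetic concrete, here is a small sketch with arbitrary, hypothetical sizes. In a `[B, M, H, K]` tensor the sequence length is at index 1, while the negative index `-2` resolves to index 2, which holds the number of heads. The sizes are chosen so that `M != H`; if they were equal, the mix-up would not change any shape.

```python
import torch

# Hypothetical sizes: M (sequence length) deliberately differs from H (heads)
B, M, H, K = 2, 16, 4, 8

# A query tensor in the [B, M, H, K] layout expected by attention
q = torch.randn(B, M, H, K)

seq_len = q.shape[1]     # dim=1 is the sequence dimension M
wrong_len = q.shape[-2]  # on a 4-D tensor, dim=-2 is dim=2, i.e. H
print(seq_len, wrong_len)  # 16 4
```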
Hence positional embeddings (e.g., rotary embeddings) should be applied along the sequence dimension, `dim=1`. However, the `RotaryEmbedding` class passes `seq_dimension=-2`, which for a 4-D tensor resolves to `dim=2` (the heads dimension), as seen here:
```python
from typing import Tuple

import torch

# Excerpt: forward() method of the RotaryEmbedding class in xformers
def forward(
    self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
        # Should be seq_dimension=1, or no argument should be passed,
        # since the default value is already correct
        k, seq_dimension=-2
    )
    return (
        apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
        apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
    )
```
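For comparison, a minimal standalone sketch of rotary embeddings applied along `dim=1` of a `[B, M, H, K]` tensor. This is not the xformers implementation: the helper names `rotate_half`, `rotary_cos_sin`, and `apply_rotary_bmhk` are hypothetical, and the sketch assumes the common "rotate half" RoPE convention with base 10000. The point it illustrates is that the cos/sin tables are built with one row per position along `M` and reshaped to `[1, M, 1, K]`, so they broadcast over the batch and heads dimensions.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into halves (x1, x2) and return (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotary_cos_sin(seq_len: int, dim: int, base: float = 10000.0):
    # One frequency per pair of channels (assumed standard RoPE schedule)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, inv_freq)         # [seq_len, dim / 2]
    emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, dim]
    return emb.cos(), emb.sin()


def apply_rotary_bmhk(x: torch.Tensor) -> torch.Tensor:
    # x is [B, M, H, K]: positions are indexed by dim=1 (M), not dim=-2 (H)
    seq_len, head_dim = x.shape[1], x.shape[-1]
    cos, sin = rotary_cos_sin(seq_len, head_dim)
    # [M, K] -> [1, M, 1, K] so each position broadcasts over batch and heads
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    return x * cos + rotate_half(x) * sin


q = torch.randn(2, 16, 4, 8)
q_rot = apply_rotary_bmhk(q)
```

In this sketch the rotation at position 0 is the identity and every rotation preserves vector norms, so `q_rot[:, 0]` equals `q[:, 0]` and per-position norms are unchanged. Building the tables from `x.shape[-2]` instead would give `H` rows (here 4) rather than `M` rows (here 16).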
Additional context
Thanks to @jmercat who found symptoms of this problem downstream of xformers!