huggingface/transformers

[LLaMA] Rotary positional embedding differs with official implementation

lytning98 opened this issue · 9 comments

transformers implements the LLaMA model's Rotary Positional Embedding (RoPE) as follows:

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

This is GPT-NeoX-style RoPE. But Meta's official model implementation adopts GPT-J-style RoPE, which processes the query and key vectors in an interleaved way instead of splitting them into two halves (as in the rotate_half method).

Meta's official repo implements RoPE as (full code link):

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
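
To make the pairing difference concrete, here is a small standalone illustration (not from either codebase; rotate_every_two is just an illustrative name for the interleaved rotation):

import torch

def rotate_half(x):
    # GPT-NeoX / transformers style: dimension i is paired with dimension i + d/2
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def rotate_every_two(x):
    # GPT-J / Meta style: dimension 2i is paired with dimension 2i + 1 (interleaved)
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

x = torch.arange(8.0)
print(rotate_half(x))       # tensor([-4., -5., -6., -7.,  0.,  1.,  2.,  3.])
print(rotate_every_two(x))  # tensor([-1.,  0., -3.,  2., -5.,  4., -7.,  6.])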

I'm confused by this difference. Since transformers.LlamaModel can directly load weights converted from the officially released checkpoint, won't this lead to inconsistent inference results? Is this difference expected?

same confusion

same confusion

@santiweide Params of some layers are re-permuted while converting the weights in the official conversion script. Check:

# permute for sliced rotary
def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
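
For what it's worth, here is a minimal single-head sketch (not taken from either repo; d, pos, and perm are made up for illustration) of why this permutation reconciles the two formulations: applying the transformers rotate_half-style rotation to the permuted vector yields exactly a permutation of Meta's interleaved result, so the q·k attention scores are identical.

import torch

torch.manual_seed(0)
d = 8                                   # head dim, kept tiny for illustration
pos = 3                                 # an arbitrary token position
q = torch.randn(d)

# same frequency schedule as both implementations
inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
theta = pos * inv_freq                  # (d/2,)

# Meta / GPT-J style: rotate interleaved pairs (2i, 2i+1) via complex multiplication
q_c = torch.view_as_complex(q.reshape(-1, 2))
meta_out = torch.view_as_real(q_c * torch.polar(torch.ones_like(theta), theta)).flatten()

# transformers / GPT-NeoX style applied to the *permuted* vector: pairs (i, i + d/2)
perm = torch.cat([torch.arange(0, d, 2), torch.arange(1, d, 2)])  # what `permute` does to one head
q_p = q[perm]
cos = torch.cat([theta.cos(), theta.cos()])
sin = torch.cat([theta.sin(), theta.sin()])
hf_out = q_p * cos + torch.cat([-q_p[d // 2 :], q_p[: d // 2]]) * sin  # rotate_half inlined

# the two outputs differ only by that same permutation, so dot products (attention scores) match
print(torch.allclose(hf_out, meta_out[perm]))  # True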

Ohhh, thank you. We are converting Megatron weights to FT weights, so we will check the weight shapes then.

Awesome, thanks for clarifying this!

Awesome, thanks for clarifying this!

Thanks for the detailed illustration!!!

Thank you @lytning98, your answer saved my life.

May I ask the purpose behind this process?

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

I mean, why not use the interleaved pairs as in Meta's official LLaMA?
@zphang @ArthurZucker.

Thanks in advance.

https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2

A few reasons that have already been mentioned:

  • first and foremost, and I can't stress this enough, licence
  • second, Eleuther's RoPE formulation (the one we are using) is equivalent, and maybe has one less operation, which makes it more optimised (see the sketch below)
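
A rough, illustrative sketch of that last point (not the library code): with the half-split pairing, the cos/sin tables are built with a single contiguous concatenation, whereas the interleaved pairing in real arithmetic needs a repeat_interleave (Meta's code sidesteps this by working in complex numbers).

import torch

dim, seq_len = 8, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, dim/2)

# half-split (GPT-NeoX / transformers) pairing: index i and i + dim/2 share an angle
cos_neox = torch.cat((freqs, freqs), dim=-1).cos()            # (seq_len, dim)

# interleaved (GPT-J / Meta) pairing in real arithmetic: each angle is repeated
# for indices 2i and 2i + 1, and the rotation itself needs strided slicing
cos_gptj = torch.repeat_interleave(freqs, 2, dim=-1).cos()    # (seq_len, dim)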