naver/croco

Is there a bug in the PyTorch RoPE code?

Closed this issue · 2 comments

x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]

Following the original implementation of RoPE (https://github.com/ZhuiyiTechnology/roformer), I think this line should be:

x1, x2 = x[..., ::2], x[..., 1::2]
return torch.cat((-x2, x1), dim=-1)
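
For illustration, here is a minimal sketch of the two pairing conventions being compared. The function names `rotate_half_split` and `rotate_half_interleaved` are hypothetical, and the `cat`-based return in the first function is an assumption about what follows the quoted line. The second function writes its result back in interleaved order (stack + flatten), which is one common way to keep the output layout consistent with the input; it is a sketch, not code taken from either repository.

```python
import torch


def rotate_half_split(x: torch.Tensor) -> torch.Tensor:
    # Non-interleaved convention: element i is paired with element i + d/2.
    # Assumption: the return statement mirrors the usual cat-based form.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def rotate_half_interleaved(x: torch.Tensor) -> torch.Tensor:
    # Interleaved convention: element 2i is paired with element 2i + 1.
    x1, x2 = x[..., ::2], x[..., 1::2]
    # Stack the pairs on a new last axis, then flatten so the output
    # stays in the same interleaved order as the input.
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


x = torch.arange(1.0, 7.0)  # values: [1, 2, 3, 4, 5, 6]

print(rotate_half_split(x))        # values: [-4, -5, -6, 1, 2, 3]
print(rotate_half_interleaved(x))  # values: [-2, 1, -4, 3, -6, 5]
```

Both functions rotate each 2D pair by 90 degrees; they differ only in which elements of the last dimension form a pair, so the cos/sin tables have to be laid out to match whichever convention is used.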

Hi @ewrfcas

I'm not too sure which piece of code you're referring to. We based our implementation on the default implementation in FlashAttention's RoPE (interleaved=False):
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L13-L19

Anyway, converting from the interleaved layout to the non-interleaved one is easy (just a view + permute).
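
As a rough sketch of that conversion, the snippet below reorders the last dimension between the two layouts and checks that applying RoPE in either convention gives the same result up to that reordering. All function names, the base of 10000, and the tensor sizes are assumptions chosen for the example, not code from this repository or from FlashAttention. It uses `reshape` rather than `view` so that it also works on non-contiguous tensors.

```python
import torch


def interleaved_to_split(x: torch.Tensor) -> torch.Tensor:
    # [a0, b0, a1, b1, ...] -> [a0, a1, ..., b0, b1, ...]
    d = x.shape[-1]
    return x.reshape(*x.shape[:-1], d // 2, 2).transpose(-1, -2).reshape(*x.shape[:-1], d)


def split_to_interleaved(x: torch.Tensor) -> torch.Tensor:
    # [a0, a1, ..., b0, b1, ...] -> [a0, b0, a1, b1, ...]
    d = x.shape[-1]
    return x.reshape(*x.shape[:-1], 2, d // 2).transpose(-1, -2).reshape(*x.shape[:-1], d)


def rotate_half_split(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def rotate_half_interleaved(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


torch.manual_seed(0)
seq, d = 4, 8  # hypothetical sizes for the demo
x_inter = torch.randn(seq, d)

# One rotation angle per (position, pair); base 10000 is an assumed default.
pos = torch.arange(seq, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
angles = torch.outer(pos, inv_freq)  # shape (seq, d // 2)

# Interleaved convention: each angle is repeated for two adjacent elements.
cos_i = angles.cos().repeat_interleave(2, dim=-1)
sin_i = angles.sin().repeat_interleave(2, dim=-1)
out_inter = x_inter * cos_i + rotate_half_interleaved(x_inter) * sin_i

# Non-interleaved convention: the angles are tiled across the two halves.
cos_s = torch.cat((angles.cos(), angles.cos()), dim=-1)
sin_s = torch.cat((angles.sin(), angles.sin()), dim=-1)
x_split = interleaved_to_split(x_inter)
out_split = x_split * cos_s + rotate_half_split(x_split) * sin_s

# Same rotation, different ordering of the last dimension.
assert torch.allclose(interleaved_to_split(out_inter), out_split)
assert torch.equal(split_to_interleaved(x_split), x_inter)
```

Because the same reordering applies to both queries and keys, the dot products used for attention scores are unaffected by the choice of layout. The layout does matter when loading weights trained under the other convention.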

Thanks for your reply. I will take a closer look at the code in FlashAttention!