lucidrains/PaLM-rlhf-pytorch

Possibly incorrect creation of Rotary Embeddings

AndyBarcia opened this issue · 1 comments

Disclaimer: I have no idea how this codebase works. I was just trying to implement Rotary Embeddings on my own for a personal project, using the class defined in palm.py as a starting point.

The thing is, I'm not sure the current implementation of Rotary Embeddings is correct. Specifically, I don't think the following lines are correct:

x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)

Because for Rotary Embeddings we want to swap each pair of adjacent elements and negate the even elements (i.e., turn [1,2,3,4,5,6] into [-2,1,-4,3,-6,5]). But the code above swaps the two halves of the tensor and negates the half that ends up first (i.e., it turns [1,2,3,4,5,6] into [-4,-5,-6,1,2,3]).
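To make the difference concrete, here is a minimal sketch of both rotations on plain Python lists. The lists stand in for the last dimension of a tensor, and `rotate_adjacent_pairs` is a hypothetical name for the interleaved variant described above, not a function from palm.py.

```python
def rotate_half(x):
    """List version of the quoted snippet: split into halves x1, x2 and return (-x2, x1)."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + x1


def rotate_adjacent_pairs(x):
    """Hypothetical interleaved variant: each adjacent pair (a, b) becomes (-b, a)."""
    out = []
    for a, b in zip(x[0::2], x[1::2]):
        out.extend([-b, a])
    return out


x = [1, 2, 3, 4, 5, 6]
half_result = rotate_half(x)            # [-4, -5, -6, 1, 2, 3]
pair_result = rotate_adjacent_pairs(x)  # [-2, 1, -4, 3, -6, 5]
print(half_result, pair_result)
```

Both functions rotate two-element pairs by 90 degrees. They differ only in which elements form a pair: element i with element i + d/2 in the first case, and elements 2j and 2j + 1 in the second.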

Is the code incorrect or is there something I'm missing?

It is correct, because the rotary embedding was concatenated the same way: https://github.com/lucidrains/PaLM-rlhf-pytorch/blob/main/palm_rlhf_pytorch/palm.py#L83
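A rough way to check this claim is to compare attention scores under both layouts. The sketch below uses plain Python lists and makes several assumptions: the inverse frequencies follow the usual `base ** (-2j / dim)` schedule, the half-split angles are built as `freqs + freqs` (mirroring a `cat((freqs, freqs))`), and the helper names other than `rotate_half` are hypothetical. It is an illustration, not the code from palm.py.

```python
import math


def rotate_half(x):
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rotate_adjacent_pairs(x):
    out = []
    for a, b in zip(x[0::2], x[1::2]):
        out.extend([-b, a])
    return out


def inv_freqs(dim, base=10000.0):
    # Assumed standard schedule: one frequency per rotated pair.
    return [base ** (-2 * j / dim) for j in range(dim // 2)]


def angles_half(pos, dim):
    freqs = [pos * f for f in inv_freqs(dim)]
    return freqs + freqs  # same layout as concatenating (freqs, freqs)


def angles_interleaved(pos, dim):
    return [pos * f for f in inv_freqs(dim) for _ in range(2)]


def apply_rotary(x, angles, rotate):
    # x * cos(angle) + rotate(x) * sin(angle), element by element
    return [v * math.cos(a) + r * math.sin(a) for v, r, a in zip(x, rotate(x), angles)]


def to_interleaved(x):
    # Reorder [a0, a1, ..., b0, b1, ...] into [a0, b0, a1, b1, ...]
    half = len(x) // 2
    return [v for pair in zip(x[:half], x[half:]) for v in pair]


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def score_half(q, k, m, n):
    d = len(q)
    return dot(apply_rotary(q, angles_half(m, d), rotate_half),
               apply_rotary(k, angles_half(n, d), rotate_half))


def score_interleaved(q, k, m, n):
    d = len(q)
    qi, ki = to_interleaved(q), to_interleaved(k)
    return dot(apply_rotary(qi, angles_interleaved(m, d), rotate_adjacent_pairs),
               apply_rotary(ki, angles_interleaved(n, d), rotate_adjacent_pairs))


q = [0.3, -1.2, 0.7, 2.0, -0.5, 1.1]
k = [1.5, 0.4, -0.9, 0.2, 0.8, -1.3]

# Same score in both layouts, once features are reordered consistently.
layouts_agree = math.isclose(score_half(q, k, 5, 2), score_interleaved(q, k, 5, 2), abs_tol=1e-9)
# The score depends only on the relative offset between the two positions.
relative_only = math.isclose(score_half(q, k, 5, 2), score_half(q, k, 12, 9), abs_tol=1e-9)
print(layouts_agree, relative_only)
```

Under these assumptions the two conventions differ only by a fixed permutation of the feature dimension, so the half-split rotation works as long as the angles are laid out to match. Mixing the half-split rotation with interleaved angles (or the reverse) would pair the wrong elements.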