JunnYu/FLASHQuad_pytorch

关于rope的实现

Closed this issue · 4 comments

非常感谢你的实现,有个问题想咨询下,
在gau.py文件中,rope的实现如下:

def rope(x, dim):
    """RoPE position embedding."""
    shape = x.shape
    if isinstance(dim, int):
        dim = [dim]
    spatial_shape = [shape[i] for i in dim]
    total_len = 1
    for i in spatial_shape:
        total_len *= i
    position = torch.reshape(
        torch.arange(total_len, dtype=x.dtype,
                     device=x.device), spatial_shape
    )
    for i in range(dim[-1] + 1, len(shape) - 1, 1):
        position = position.unsqueeze(-1)
    half_size = shape[-1] // 2
    freq_seq = -torch.arange(half_size, dtype=x.dtype, device=x.device) / float(
        half_size
    )
    inv_freq = 10000 ** freq_seq
    sinusoid = torch.einsum("...,d->...d", position, inv_freq)
    sin = sinusoid.sin()
    cos = sinusoid.cos()
    x1, x2 = torch.chunk(x, 2, dim=-1)

    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

而在https://github.com/JunnYu/RoFormer_pytorch/blob/roformer_v2/src/roformer/modeling_roformer.py
rope的实现如下:

    def apply_rotary(x, sinusoidal_pos):
        sin, cos = sinusoidal_pos
        x1, x2 = x[..., 0::2], x[..., 1::2]
        # 如果是旋转query key的话,下面这个直接cat就行,因为要进行矩阵乘法,最终会在这个维度求和。(只要保持query和key的最后一个dim的每一个位置对应上就可以)
        # torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
        # 如果是旋转value的话,下面这个stack后再flatten才可以,因为训练好的模型最后一个dim是两两之间交替的。
        return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)

主要的区别是x1, x2的定义,前者是按照前一半和后一半划分,后者是按奇数项和偶数项划分,前者实现的并不是rope,我测试过两种实现,效果出入较大,不知道是不是我理解有误。

  • 最后一个dim维度无论咋分都一样啊,苏神选用的是奇数项和偶数项划分(我为了能顺利加载苏神的权重,必须要与他一致才行),而谷歌的那个使用的是前一半和后一半划分(从零训练你前一半后一半都一样啊)。
  • 个人感觉如果都是从零训练,最后一个dim你无论怎么切分,都不会影响最终的结果的,这最后一个维度没有顺序关系的吧。
  • 还有你是怎么测试这两种实现的,在啥任务上?

测试的任务是roberta预训练,使用了两种编码方式后同期loss差1左右;感觉应该没有太大区别,但是测试下来确实不太一样。

那我就不清楚了,可以去问问苏神,这个问题

ok,谢谢解答