syncdoth/RetNet

Changelog of official implementation

donglixp opened this issue · 5 comments

Thanks for the well-written package! RetNet's official implementation has had several updates at https://github.com/microsoft/unilm/blob/master/retnet/README.md#changelog .

It's an honor to get the acknowledgment from the author himself! I'm definitely planning to reference the official implementation more in the next release. Thanks for the pointers!

Update: Currently on track to reimplement based on the torchscale version.

Parallel and recurrent are done, but I'm having some issues with chunkwise.

Problem Description:

I'll leave a note on what keeps chunkwise from being equivalent to the parallel or recurrent forward. (The same problem is present in the official torchscale code too.)

The decays are normalized in RetNetRelPos, as shown here. However, this leads to decays that are scaled differently from the parallel or recurrent case.

In parallel, the decay mask with slen=4 looks something like this:

(A)

[1    0   0  0]
[r    1   0  0]
[r^2  r   1  0]
[r^3  r^2 r  1]

and after dividing by scale, it becomes

(B)

[1    0   0  0]   /  [1               ]
[r    1   0  0]   /  [r + 1           ]
[r^2  r   1  0]   /  [r^2 + r + 1     ]
[r^3  r^2 r  1]   /  [r^3 + r^2 + r + 1]
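
To make the normalization concrete, here is a minimal sketch (not the library code) that builds (A) for a single head with decay value r and then divides each row by its sum to get (B):

import torch

def normalized_decay_mask(r: float, slen: int) -> torch.Tensor:
    idx = torch.arange(slen, dtype=torch.float)
    mask = torch.tril(r ** (idx[:, None] - idx[None, :]))  # (A): entry (i, j) is r^(i-j) for j <= i
    return mask / mask.sum(dim=-1, keepdim=True)           # (B): row i divided by sum(r**k for k in range(i+1))

print(normalized_decay_mask(0.9, slen=4))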

For chunkwise with slen=4, chunk_size=2, the decay_mask is the following after scaling:

(C)

[1    0]  /  [1    ]
[r    1]  /  [r + 1]
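
With the same hedged helper as above, (C) is just the slen=chunk_size case, so each row inside a chunk is normalized by at most r + 1:

print(normalized_decay_mask(0.9, slen=2))  # per-chunk mask (C): rows scaled by 1 and r + 1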

Since cross_decay = r^2, the cross-chunk aggregation formula $KV_{i} = KV_{i-1} \cdot D + KV_{i}$ (with $D$ = cross_decay = $r^2$) means the decay applied to the previous chunk becomes

(D)

  [r   + 1  ]  /  [r + 1]  *   [ KV_{i-1} ] * r^2   +    [r + 1]  /  [r + 1]  *  [  KV_i  ]

= [r^3 + r^2]  /  [r + 1]  *   [ KV_{i-1} ]         +    [r + 1]  /  [r + 1]  *  [  KV_i  ]


= [r^3 + r^2, r + 1] / [r + 1]   *   [[ KV_{i-1} ]
                                      [   KV_i   ]]

= [r^3   r^2  r   1] / [r + 1]   *   [[   KV_0   ]
                                      [   KV_1   ]
                                      [   KV_2   ]
                                      [   KV_3   ]]

Compare (B) with (D) and notice that while the row vector has the correct powers of r, the scale dividing it should be sum(r**i for i in range(4)), but it is only sum(r**i for i in range(2)).
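
For a concrete check of the same toy setup (one decay value r = 0.9, slen = 4, chunk_size = 2), the sketch below rebuilds the effective decay row for the last position both ways; it mirrors the derivation above, not the actual implementation.

import torch

r = 0.9

# (B): last row of the parallel mask, normalized by sum(r**i for i in range(4))
parallel_row = torch.tensor([r**3, r**2, r, 1.0])
parallel_row = parallel_row / parallel_row.sum()

# (D): same row reconstructed from the chunkwise pieces. Within a chunk the mask
# is normalized by sum(r**i for i in range(2)) = r + 1, and the previous chunk
# enters through KV_i = KV_{i-1} * cross_decay + KV_i with cross_decay = r**2.
chunk_scale = r + 1.0
cross_decay = r ** 2
chunkwise_row = torch.tensor([r * cross_decay, cross_decay, r, 1.0]) / chunk_scale

print(parallel_row)   # divided by r^3 + r^2 + r + 1
print(chunkwise_row)  # divided only by r + 1, hence the scale mismatch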

Hi @syncdoth, thanks for the great implementation!
I think the difference between chunkwise and the others comes from the group norm eps.
I checked that when layernorm_eps is 0, all results are the same.
And the author of the paper said that after training, the validation PPL is almost the same, so it's not a big problem.

Hi @N0r9st, the mismatched results come from the group norm eps. If eps=0, the chunkwise representation is mathematically the same as the parallel one. You can try that.

Another reason is that the initialization of Retention is small, which amplifies the difference. However, after training, the validation PPL will be almost the same.

You can check this issue: microsoft/torchscale#77
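
To see why the eps matters, here is a small standalone sketch (not the RetNet code itself): the two paths feed the group norm inputs that differ only by a per-position scale, and a pure scale cancels in group norm exactly only when eps = 0; with small activations and eps > 0 the scale leaks through.

import torch
import torch.nn.functional as F

x = torch.randn(1, 8) * 0.01     # small activations, as with a small init
scaled = 3.0 * x                 # same values up to a global scale factor

for eps in (0.0, 1e-5):
    a = F.group_norm(x, num_groups=4, eps=eps)
    b = F.group_norm(scaled, num_groups=4, eps=eps)
    print(eps, (a - b).abs().max().item())
# eps = 0.0  -> ~0 (the scale cancels exactly)
# eps = 1e-5 -> clearly nonzero, because eps is not negligible next to the tiny variance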

import torch

# RetNetConfig, MultiScaleRetention, and RetNetRelPos come from this repo's
# modeling code; import them from wherever they are defined locally.

if __name__ == '__main__':
    bsz = 1
    seq_len = 6
    hidden_size = 8

    inputs = torch.randn(bsz, seq_len, hidden_size)
    attention_mask = None

    config = RetNetConfig(
        vocab_size=51200,
        initializer_range=0.02,
        is_decoder=True,
        pad_token_id=1,
        eos_token_id=1,
        output_retentions=False,
        use_cache=True,
        forward_impl='parallel',
        activation_fn="swish",
        dropout=0.0,  # dropout probability
        activation_dropout=0.0,  # dropout probability after activation in FFN.
        decoder_embed_dim=hidden_size,  # decoder embedding dimension
        decoder_value_embed_dim=hidden_size,  # decoder value embedding dimension
        decoder_ffn_embed_dim=hidden_size,  # decoder embedding dimension for FFN
        decoder_layers=1,  # num decoder layers
        decoder_retention_heads=4,  # num decoder retention heads
        decoder_normalize_before=True,  # apply layer_norm before each decoder block
        embedding_layer_norm=False,  # add layer_norm to embedding
        no_scale_embedding=False,  # if True, don't scale embeddings
        recurrent_chunk_size=3,
        use_glu=True,  # use GLU instead of FFN
        z_loss_coeff=0.0,  # coefficient for z loss: TODO: 1e-4
        use_lm_decay=False,
        deepnorm=False,
        subln=False,
        layer_norm_eps=0,
        tie_word_embeddings=True,
    )

    attn = MultiScaleRetention(config)
    attn.eval()
    pos = RetNetRelPos(config)

    pos_mode = "parallel"
    pos_out = pos.forward(
        slen=seq_len, retention_mask=attention_mask, forward_impl=pos_mode, get_decay_scale=True,
    )
    parallel_outputs = attn.forward(
        hidden_states=inputs, retention_mask=attention_mask, forward_impl=pos_mode, rel_pos=pos_out, use_cache=True,
    )[0]

    pos_mode = "recurrent"
    recurrent_outputs = []
    past_key_value = None
    for i in range(seq_len):
        pos_out = pos.forward(
            slen=i, retention_mask=attention_mask, forward_impl=pos_mode, get_decay_scale=True,
        )
        if attention_mask is not None:
            attn_out = attn.forward(
                hidden_states=inputs[:, i:i + 1], retention_mask=attention_mask[:, i:i + 1],
                forward_impl=pos_mode, rel_pos=pos_out, past_key_value=past_key_value, use_cache=True
            )
        else:
            attn_out = attn.forward(
                hidden_states=inputs[:, i:i + 1], retention_mask=None,
                forward_impl=pos_mode, rel_pos=pos_out, past_key_value=past_key_value, use_cache=True
            )
        past_key_value = attn_out[1]
        recurrent_outputs.append(attn_out[0])
    recurrent_outputs = torch.cat(recurrent_outputs, dim=1)

    pos_mode = "chunkwise"
    pos_out = pos.forward(
        slen=seq_len, retention_mask=attention_mask, forward_impl=pos_mode, get_decay_scale=True,
        recurrent_chunk_size=config.recurrent_chunk_size
    )
    chunked_outputs = attn.forward(
        hidden_states=inputs, retention_mask=None,
        forward_impl=pos_mode, rel_pos=pos_out, use_cache=True,
    )[0]

    print("parallel", parallel_outputs)
    print("========================================")
    print("recurrent", recurrent_outputs)
    print("========================================")
    print("chunkwise", chunked_outputs)
parallel tensor([[[ 0.0661,  0.0118, -0.0007, -0.0371,  0.0423, -0.0271, -0.0077,
           0.0454],
         [ 0.1802, -0.0137,  0.0067,  0.0878,  0.0504,  0.0784,  0.0541,
           0.0457],
         [ 0.1028,  0.0411,  0.0235, -0.0271,  0.0276, -0.0670, -0.0984,
           0.0819],
         [ 0.0462,  0.0080,  0.1094,  0.0696, -0.0110, -0.1273, -0.0945,
           0.0626],
         [ 0.0629, -0.0184,  0.1094,  0.0507, -0.0508, -0.0019, -0.0762,
           0.0003],
         [-0.0948,  0.0072, -0.1295, -0.0596,  0.0233,  0.0751,  0.0730,
          -0.0602]]], grad_fn=<UnsafeViewBackward0>)
========================================
recurrent tensor([[[ 0.0661,  0.0118, -0.0007, -0.0371,  0.0423, -0.0271, -0.0077,
           0.0454],
         [ 0.1802, -0.0137,  0.0067,  0.0878,  0.0504,  0.0784,  0.0541,
           0.0457],
         [ 0.1028,  0.0411,  0.0235, -0.0271,  0.0276, -0.0670, -0.0984,
           0.0819],
         [ 0.0462,  0.0080,  0.1094,  0.0696, -0.0110, -0.1273, -0.0945,
           0.0626],
         [ 0.0629, -0.0184,  0.1094,  0.0507, -0.0508, -0.0019, -0.0762,
           0.0003],
         [-0.0948,  0.0072, -0.1295, -0.0596,  0.0233,  0.0751,  0.0730,
          -0.0602]]], grad_fn=<CatBackward0>)
========================================
chunkwise tensor([[[ 0.0661,  0.0118, -0.0007, -0.0371,  0.0423, -0.0271, -0.0077,
           0.0454],
         [ 0.1802, -0.0137,  0.0067,  0.0878,  0.0504,  0.0784,  0.0541,
           0.0457],
         [ 0.1028,  0.0411,  0.0235, -0.0271,  0.0276, -0.0670, -0.0984,
           0.0819],
         [ 0.0462,  0.0080,  0.1094,  0.0696, -0.0110, -0.1273, -0.0945,
           0.0626],
         [ 0.0629, -0.0184,  0.1094,  0.0507, -0.0508, -0.0019, -0.0762,
           0.0003],
         [-0.0948,  0.0072, -0.1295, -0.0596,  0.0233,  0.0751,  0.0730,
          -0.0602]]], grad_fn=<UnsafeViewBackward0>)

@hyunwoongko that's true! I can also confirm that setting groupnorm_eps to 0 or a small number (1e-15) removes the differences in outputs.

One remaining problem might be that the kv_cache from the chunkwise forward is still not exactly the same (group norm doesn't affect this). I'm sure there's a way around it; I haven't thought about it enough yet 😊
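
As a rough way to look at that gap, continuing the script above, one could also grab the cache from the chunkwise call (the script only keeps index [0]) and compare it with the cache left by the last recurrent step. Treating both caches as dicts of tensors is an assumption about this repo's cache layout and may need adjusting.

# Grab the chunkwise cache as well; pos_out is still the chunkwise rel_pos here.
chunk_cache = attn.forward(
    hidden_states=inputs, retention_mask=None,
    forward_impl="chunkwise", rel_pos=pos_out, use_cache=True,
)[1]

# `past_key_value` is the cache after the last recurrent step; the dict-of-tensors
# layout assumed below may differ from the repo's actual cache format.
for name in past_key_value:
    gap = (past_key_value[name] - chunk_cache[name]).abs().max()
    print(name, gap.item())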