Changelog of official implementation
donglixp opened this issue · 5 comments
Thanks for the well-written package! The RetNet's official implementation had several updates at https://github.com/microsoft/unilm/blob/master/retnet/README.md#changelog .
It's an honor to get the acknowledgment from the author himself! I'm definitely planning to reference the official implementation more in the next release. Thanks for the pointers!
Update: I'm currently on track to reimplement this based on the torchscale version.
Parallel and Recurrent are done, but having some issues with chunkwise.
Problem Description:
I'll leave a note on what keeps chunkwise from being equivalent to the parallel or recurrent forward. (Actually, the same problem is present in the official torchscale code too.)
The decays are normalized in RetNetRelPos, as in here. However, this leads to differently scaled decays than in parallel or recurrent.
In parallel, the decay with slen=4 looks something like this:
(A)
[1 0 0 0]
[r 1 0 0]
[r^2 r 1 0]
[r^3 r^2 r 1]
and after dividing by scale, it becomes
(B)
[1 0 0 0] / [1 ]
[r 1 0 0] / [r + 1 ]
[r^2 r 1 0] / [r^2 + r + 1 ]
[r^3 r^2 r 1] / [r^3 +r^2 + r + 1]
For chunkwise with slen=4, chunk_size=2, the decay_mask is the following after scaling:
(C)
[1 0] / [1 ]
[r 1] / [r + 1]
Since cross_decay=r^2, the cross-chunk aggregation formula gives:
(D)
[r + 1 ] / [r + 1] * [ KV_{i-1} ] * r^2 + [r + 1] / [r + 1] * [ KV_i ]
= [r^3 + r^2] / [r + 1] * [ KV_{i-1} ] + [r + 1] / [r + 1] * [ KV_i ]
= [r^3 + r^2, r + 1] / [r + 1] * [[ KV_{i-1} ]
[ KV_i ]]
= [r^3 r^2 r 1] / [r + 1] * [[ KV_0 ]
[ KV_1 ]
[ KV_2 ]
[ KV_3 ]]
Compare (B) with (D) and notice that while the row vector has the correct number of exponents, the scale that divides it should be sum(r**i for i in range(4)), but it is just sum(r**i for i in range(2)).
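The mismatch between (B) and (D) can be checked numerically. The sketch below is plain Python, not code from this package or torchscale; `parallel_row` and `chunkwise_row` are hypothetical helpers that follow the simplified picture above (each row divided by a sum of decay powers). Under that assumption, the two rows differ only by a constant factor, which works out to 1 + r^2 for slen=4, chunk_size=2.

```python
def parallel_row(r, slen):
    """Last row of the slen x slen decay mask (A), normalized as in (B)."""
    row = [r ** (slen - 1 - j) for j in range(slen)]
    scale = sum(row)  # sum(r**i for i in range(slen))
    return [w / scale for w in row]


def chunkwise_row(r, slen, chunk_size):
    """Effective weights on the last position as derived in (D)."""
    row = [r ** (slen - 1 - j) for j in range(slen)]
    scale = sum(r ** i for i in range(chunk_size))  # only the chunk-local sum
    return [w / scale for w in row]


r = 0.9
p = parallel_row(r, 4)
c = chunkwise_row(r, 4, 2)
ratios = [ci / pi for ci, pi in zip(c, p)]
print(ratios)  # every entry is approximately 1 + r**2 = 1.81
```

Because the factor is the same for every entry of the row, it rescales that position's output as a whole rather than changing the relative weights.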
Hi @syncdoth, thanks for the great implementation!
I think the difference between chunkwise and the others comes from the group norm eps.
I checked that when layernorm_eps is 0, all results are the same.
The author of the paper also said that after training the validation PPL is almost the same, so it's not a big problem.
Hi @N0r9st, the incorrect results come from group norm eps. If eps=0, the chunk representation is mathematically the same as the parallel one. You can try that.
Another reason is that the initialization of Retention is small, which amplifies the difference. However, after training, the validation ppl will be almost the same.
You can check this issue: microsoft/torchscale#77
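To illustrate why eps matters here, the following is a minimal sketch in plain Python, not the package's GroupNorm. `normalize` is a hypothetical single-group normalization (subtract the mean, divide by sqrt(var + eps)), and the rescaling factor 1.81 and the sample activations are made-up values. With eps=0 a constant rescaling of the input cancels out; with eps>0 it does not, and the gap grows as the activations get smaller.

```python
import math


def normalize(x, eps):
    """Normalize one group: (x - mean) / sqrt(var + eps)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def max_gap(x, scale, eps):
    """Largest difference between normalizing x and normalizing scale * x."""
    a = normalize(x, eps)
    b = normalize([scale * v for v in x], eps)
    return max(abs(u - v) for u, v in zip(a, b))


x = [0.01, -0.02, 0.03, 0.005]  # small activations, as with a small init
x_large = [100 * v for v in x]   # same pattern at a larger magnitude

print(max_gap(x, 1.81, eps=0.0))         # zero up to float rounding
print(max_gap(x, 1.81, eps=1e-5))        # clearly non-zero
print(max_gap(x_large, 1.81, eps=1e-5))  # much smaller than the line above
```

This is consistent with the two points quoted above: a differently scaled chunkwise output is harmless when eps=0, and a small initialization makes the eps term matter more.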
import torch

# RetNetConfig, MultiScaleRetention and RetNetRelPos are assumed to be in scope already,
# e.g. when this block is appended to the package's modeling file.

if __name__ == '__main__':
    bsz = 1
    seq_len = 6
    hidden_size = 8
    inputs = torch.randn(bsz, seq_len, hidden_size)
    attention_mask = None
    config = RetNetConfig(
        vocab_size=51200,
        initializer_range=0.02,
        is_decoder=True,
        pad_token_id=1,
        eos_token_id=1,
        output_retentions=False,
        use_cache=True,
        forward_impl='parallel',
        activation_fn="swish",
        dropout=0.0,  # dropout probability
        activation_dropout=0.0,  # dropout probability after activation in FFN.
        decoder_embed_dim=hidden_size,  # decoder embedding dimension
        decoder_value_embed_dim=hidden_size,  # decoder value embedding dimension
        decoder_ffn_embed_dim=hidden_size,  # decoder embedding dimension for FFN
        decoder_layers=1,  # num decoder layers
        decoder_retention_heads=4,  # num decoder retention heads
        decoder_normalize_before=True,  # apply layer_norm before each decoder block
        embedding_layer_norm=False,  # add layer_norm to embedding
        no_scale_embedding=False,  # if True, dont scale embeddings
        recurrent_chunk_size=3,
        use_glu=True,  # use GLU instead of FFN
        z_loss_coeff=0.0,  # coefficient for z loss: TODO: 1e-4
        use_lm_decay=False,
        deepnorm=False,
        subln=False,
        layer_norm_eps=0,  # eps=0 so that all three forward modes agree
        tie_word_embeddings=True,
    )
    attn = MultiScaleRetention(config)
    attn.eval()
    pos = RetNetRelPos(config)

    # Parallel forward over the whole sequence.
    pos_mode = "parallel"
    pos_out = pos.forward(
        slen=seq_len, retention_mask=attention_mask, forward_impl=pos_mode, get_decay_scale=True,
    )
    parallel_outputs = attn.forward(
        hidden_states=inputs, retention_mask=attention_mask, forward_impl=pos_mode, rel_pos=pos_out, use_cache=True,
    )[0]

    # Recurrent forward, one token at a time, carrying past_key_value.
    pos_mode = "recurrent"
    recurrent_outputs = []
    past_key_value = None
    for i in range(seq_len):
        pos_out = pos.forward(
            slen=i, retention_mask=attention_mask, forward_impl=pos_mode, get_decay_scale=True,
        )
        if attention_mask is not None:
            attn_out = attn.forward(
                hidden_states=inputs[:, i:i + 1], retention_mask=attention_mask[:, i:i + 1],
                forward_impl=pos_mode, rel_pos=pos_out, past_key_value=past_key_value, use_cache=True
            )
        else:
            attn_out = attn.forward(
                hidden_states=inputs[:, i:i + 1], retention_mask=None,
                forward_impl=pos_mode, rel_pos=pos_out, past_key_value=past_key_value, use_cache=True
            )
        past_key_value = attn_out[1]
        recurrent_outputs.append(attn_out[0])
    recurrent_outputs = torch.cat(recurrent_outputs, dim=1)

    # Chunkwise forward with chunks of config.recurrent_chunk_size tokens.
    pos_mode = "chunkwise"
    pos_out = pos.forward(
        slen=seq_len, retention_mask=attention_mask, forward_impl=pos_mode, get_decay_scale=True,
        recurrent_chunk_size=config.recurrent_chunk_size
    )
    chunked_outputs = attn.forward(
        hidden_states=inputs, retention_mask=None,
        forward_impl=pos_mode, rel_pos=pos_out, use_cache=True,
    )[0]

    print("parallel", parallel_outputs)
    print("========================================")
    print("recurrent", recurrent_outputs)
    print("========================================")
    print("chunkwise", chunked_outputs)
parallel tensor([[[ 0.0661, 0.0118, -0.0007, -0.0371, 0.0423, -0.0271, -0.0077,
0.0454],
[ 0.1802, -0.0137, 0.0067, 0.0878, 0.0504, 0.0784, 0.0541,
0.0457],
[ 0.1028, 0.0411, 0.0235, -0.0271, 0.0276, -0.0670, -0.0984,
0.0819],
[ 0.0462, 0.0080, 0.1094, 0.0696, -0.0110, -0.1273, -0.0945,
0.0626],
[ 0.0629, -0.0184, 0.1094, 0.0507, -0.0508, -0.0019, -0.0762,
0.0003],
[-0.0948, 0.0072, -0.1295, -0.0596, 0.0233, 0.0751, 0.0730,
-0.0602]]], grad_fn=<UnsafeViewBackward0>)
========================================
recurrent tensor([[[ 0.0661, 0.0118, -0.0007, -0.0371, 0.0423, -0.0271, -0.0077,
0.0454],
[ 0.1802, -0.0137, 0.0067, 0.0878, 0.0504, 0.0784, 0.0541,
0.0457],
[ 0.1028, 0.0411, 0.0235, -0.0271, 0.0276, -0.0670, -0.0984,
0.0819],
[ 0.0462, 0.0080, 0.1094, 0.0696, -0.0110, -0.1273, -0.0945,
0.0626],
[ 0.0629, -0.0184, 0.1094, 0.0507, -0.0508, -0.0019, -0.0762,
0.0003],
[-0.0948, 0.0072, -0.1295, -0.0596, 0.0233, 0.0751, 0.0730,
-0.0602]]], grad_fn=<CatBackward0>)
========================================
chunkwise tensor([[[ 0.0661, 0.0118, -0.0007, -0.0371, 0.0423, -0.0271, -0.0077,
0.0454],
[ 0.1802, -0.0137, 0.0067, 0.0878, 0.0504, 0.0784, 0.0541,
0.0457],
[ 0.1028, 0.0411, 0.0235, -0.0271, 0.0276, -0.0670, -0.0984,
0.0819],
[ 0.0462, 0.0080, 0.1094, 0.0696, -0.0110, -0.1273, -0.0945,
0.0626],
[ 0.0629, -0.0184, 0.1094, 0.0507, -0.0508, -0.0019, -0.0762,
0.0003],
[-0.0948, 0.0072, -0.1295, -0.0596, 0.0233, 0.0751, 0.0730,
-0.0602]]], grad_fn=<UnsafeViewBackward0>)
@hyunwoongko that's true! I can also confirm that setting groupnorm_eps to 0 or a small number (1e-15) removes the differences in outputs.
One remaining problem might be that the kv_cache from the chunkwise forward is still not exactly the same (groupnorm doesn't affect this). I'm sure there is a way around it, I just haven't thought about it enough yet 😊
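As a reference point for the kv_cache question, here is a minimal scalar sketch, not the package's code: `recurrent_states` and `chunkwise_states` are hypothetical helpers, keys and values are single floats, and no decay normalization is applied. In that unnormalized setting the state carried across chunks matches the recurrent state at each chunk boundary, so any scale folded into the cached state would show up as a mismatch. Whether that is what happens in this implementation is an assumption, not something established in this thread.

```python
def recurrent_states(k, v, r):
    """State after each step: S_n = r * S_{n-1} + k_n * v_n."""
    s, states = 0.0, []
    for kn, vn in zip(k, v):
        s = r * s + kn * vn
        states.append(s)
    return states


def chunkwise_states(k, v, r, chunk_size):
    """State after each chunk: S_i = r**B * S_{i-1} + sum_j r**(B-1-j) * k_j * v_j."""
    s, states = 0.0, []
    for start in range(0, len(k), chunk_size):
        kc = k[start:start + chunk_size]
        vc = v[start:start + chunk_size]
        b = len(kc)
        s = r ** b * s + sum(r ** (b - 1 - j) * kc[j] * vc[j] for j in range(b))
        states.append(s)
    return states


k = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75]  # made-up scalar keys
v = [1.0, 0.5, -0.5, 2.0, -1.0, 0.25]   # made-up scalar values
r, chunk = 0.9, 3

rec = recurrent_states(k, v, r)
chk = chunkwise_states(k, v, r, chunk)
ends = [rec[i] for i in range(chunk - 1, len(k), chunk)]  # recurrent state at chunk ends
print(ends)
print(chk)  # matches the line above up to float rounding
```

Comparing the cached states in the same way, at chunk boundaries only, may help narrow down where the two caches start to diverge.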