Chunk recurrent representation incorrect results
N0r9st opened this issue · 7 comments
I believe there is some kind of normalization mistake in chunk recurrent retention. Its output does not match the output of simple recurrent or parallel retention. Recurrent retention also does not match chunk recurrent with chunk = 1. Parallel and recurrent retention match closely enough.
Code for reproduction is below. I am also using my own RetNet config since the default one is not working; I can share it if necessary.
import torch
# from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder
from retnet_config import RetNetConfig
# from retnet import RetNetDecoder
config = RetNetConfig(vocab_size=384)
retnet = RetNetDecoder(config)
x = torch.arange(12)[None, ...]
x_e = torch.randn((1, 12, 384))
retnet.retnet_rel_pos.recurrent_chunk_size = 3
retention_rel_pos = retnet.retnet_rel_pos(12, False, chunkwise_recurrent=True)
out_rec, _ = retnet.layers[0](x_e, None, chunkwise_recurrent=True, retention_rel_pos=retention_rel_pos)
retnet.retnet_rel_pos.recurrent_chunk_size = 12
retention_rel_pos = retnet.retnet_rel_pos(12, False, chunkwise_recurrent=False)
out_par, _ = retnet.layers[0](x_e, None, chunkwise_recurrent=False, retention_rel_pos=retention_rel_pos)
# chunkwise vs. parallel: the difference is large after the first chunk_size tokens
print((out_rec - out_par).abs().mean(-1))
retnet.retnet_rel_pos.recurrent_chunk_size = 1
state0 = {}
outs = []
for i in range(12):
    retention_rel_pos = retnet.retnet_rel_pos(i + 1, activate_recurrent=True)
    outs.append(
        retnet.layers[0](x_e[:, i:i+1], state0, chunkwise_recurrent=False, retention_rel_pos=retention_rel_pos)
    )
outsr = torch.cat([x[0] for x in outs], dim=1)
# recurrent vs. parallel: the difference is small, ~1e-8
print((outsr - out_par).abs().mean(-1))
I managed to somehow match the 1-chunk and recurrent retention by playing with normalization, and I managed to match the retention hidden states of chunk recurrent and recurrent for arbitrary chunk size (they are transposed lol), but I failed to match the output for arbitrary chunk size.
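For reference, the three forms compared in this issue can be written out for a single head in plain Python. This is only a minimal sketch: it assumes a scalar decay `gamma`, leaves out the rotation, scaling and group norm that the real layer applies, and the function names are made up for illustration, so it is not torchscale's implementation. It only shows that the parallel, recurrent and chunkwise forms agree mathematically before any normalization is involved.

```python
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def retention_parallel(q, k, v, gamma):
    # o_n = sum_{m <= n} gamma^(n - m) * (q_n . k_m) * v_m
    out = []
    for n in range(len(q)):
        o = [0.0] * len(v[0])
        for m in range(n + 1):
            w = gamma ** (n - m) * dot(q[n], k[m])
            o = [oi + w * vi for oi, vi in zip(o, v[m])]
        out.append(o)
    return out

def retention_recurrent(q, k, v, gamma):
    # S_n = gamma * S_{n-1} + k_n^T v_n ;  o_n = q_n S_n
    d_k, d_v = len(k[0]), len(v[0])
    state = [[0.0] * d_v for _ in range(d_k)]
    out = []
    for qn, kn, vn in zip(q, k, v):
        state = [[gamma * state[i][j] + kn[i] * vn[j] for j in range(d_v)]
                 for i in range(d_k)]
        out.append([sum(qn[i] * state[i][j] for i in range(d_k)) for j in range(d_v)])
    return out

def retention_chunkwise(q, k, v, gamma, chunk):
    # Parallel form inside each chunk, recurrent state carried across chunks.
    d_k, d_v = len(k[0]), len(v[0])
    state = [[0.0] * d_v for _ in range(d_k)]
    out = []
    for start in range(0, len(q), chunk):
        qc, kc, vc = (t[start:start + chunk] for t in (q, k, v))
        inner = retention_parallel(qc, kc, vc, gamma)
        for j, qn in enumerate(qc):
            decay = gamma ** (j + 1)  # distance from the end of the previous chunk
            cross = [decay * sum(qn[i] * state[i][c] for i in range(d_k))
                     for c in range(d_v)]
            out.append([a + b for a, b in zip(inner[j], cross)])
        b = len(qc)
        state = [[gamma ** b * state[i][c]
                  + sum(gamma ** (b - 1 - j) * kc[j][i] * vc[j][c] for j in range(b))
                  for c in range(d_v)]
                 for i in range(d_k)]
    return out

def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))

random.seed(0)
tokens, dim, gamma = 12, 4, 0.9  # arbitrary sizes chosen for the sketch
q, k, v = ([[random.gauss(0, 1) for _ in range(dim)] for _ in range(tokens)]
           for _ in range(3))

out_par = retention_parallel(q, k, v, gamma)
out_rec = retention_recurrent(q, k, v, gamma)
out_chunk = retention_chunkwise(q, k, v, gamma, chunk=3)
print(max_abs_diff(out_par, out_rec), max_abs_diff(out_par, out_chunk))
```

Both printed differences should be at floating-point rounding level, so a large mismatch in the real layer has to come from something outside these formulas.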
Try setting eps=0.0 in group normalization. This may help.
@XintianHan no, this does not work. I do not think it is possible to fix this by hyperparameter tuning.
@sunyt32 maybe you could look into this? I saw in the latest commits that you already worked on some fixes in the chunk recurrent representation, so you may be more familiar with the math there. Or if you prove my statement about a normalization mistake wrong, that would be great too.
Hi @N0r9st, the incorrect results come from group norm eps. If eps=0, the chunk representation is mathematically the same as the parallel one. You can try that.
Another reason is that the initialization of Retention is small, which amplifies the difference. However, after training, the validation perplexity will be almost the same.
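The effect of eps can be illustrated in isolation. The sketch below assumes that the chunkwise path hands the normalization layer the same vector as the parallel path but at a different internal scale; the factor 8 and the helper `rms_norm` are hypothetical and stand in for whatever the real layer does. With eps=0 the normalization is scale invariant and both inputs give the same output, while a nonzero eps breaks that invariance, and more so when the activations are small.

```python
import math

def rms_norm(x, eps):
    # Divide by the root mean square; no learnable weight, for simplicity.
    mean_sq = sum(xi * xi for xi in x) / len(x)
    return [xi / math.sqrt(mean_sq + eps) for xi in x]

def max_abs_diff(a, b):
    return max(abs(x - y) for x, y in zip(a, b))

# Small activations, as with a small initialization (values are made up).
x = [0.003, -0.001, 0.002, -0.004]
# Same direction, different scale (hypothetical factor).
x_rescaled = [xi / 8.0 for xi in x]

diff_eps0 = max_abs_diff(rms_norm(x, 0.0), rms_norm(x_rescaled, 0.0))
diff_eps = max_abs_diff(rms_norm(x, 1e-6), rms_norm(x_rescaled, 1e-6))
print("eps=0:   ", diff_eps0)  # rounding level at most
print("eps=1e-6:", diff_eps)   # clearly nonzero
```

Because the squared activations here are on the order of 1e-6, an eps of 1e-6 is no longer negligible next to them, which is consistent with the remark that a small initialization amplifies the difference.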
@sunyt32 setting layernorm_eps=0 in the config does not help. The error is about the same.
prints from my code with eps=0:
tensor([[0.0000, 0.0000, 0.0000, 0.0459, 0.0510, 0.0479, 0.0481, 0.0472, 0.0522,
0.0444, 0.0582, 0.0539]], grad_fn=<MeanBackward1>)
tensor([[6.5687e-08, 7.1295e-08, 7.4986e-08, 6.8375e-08, 7.7678e-08, 7.0286e-08,
7.0873e-08, 9.2880e-08, 7.8289e-08, 7.6868e-08, 8.4775e-08, 7.2299e-08]],
grad_fn=<MeanBackward1>)
and with default eps (1e-6):
tensor([[0.0000, 0.0000, 0.0000, 0.0545, 0.0493, 0.0538, 0.0536, 0.0464, 0.0571,
0.0372, 0.0568, 0.0560]], grad_fn=<MeanBackward1>)
tensor([[7.3882e-08, 7.9706e-08, 7.7125e-08, 7.6495e-08, 6.8850e-08, 8.4901e-08,
7.9160e-08, 7.7276e-08, 7.1198e-08, 7.8673e-08, 7.5728e-08, 7.2352e-08]],
grad_fn=<MeanBackward1>)
The most concerning part for me is that the chunk representation with chunk=1 does not match the recurrent representation, while the recurrent and parallel representations match nicely.
But it does match the recurrent representation if the state is normed the other way.
I think they use RMSNorm now. Did you check which norm you are using? If it's RMSNorm, you probably need to set eps=0.0 in RMSNorm.
@XintianHan @sunyt32 Guys you were right - layernorm_eps=0 does the job. In my reply above I did try setting layernorm_eps=0, but I still got incorrect results: I think I made a mistake in the code or something. My example in the header of the issue gives an error of about 1e-8 everywhere when I set the eps parameter.
Big thanks and I am closing this issue!