Chunkwise retention giving different output
The implementation of the chunkwise retention paradigm on the chunkwise-real branch gives outputs that differ from those of the other two paradigms.
It appears there may be a mistake in equation (7) of the paper on which the implementation is based. A pull request fixing this and producing outputs consistent with the other two paradigms would be greatly appreciated.
This can be reproduced by running `python src/tests.py`, with the following stdout:
FFF
======================================================================
FAIL: test_retnet (__main__.TestRetNet)
verify that the three implementations of RetNet are identical
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/jamie/Repos/RetNet/src/tests.py", line 137, in test_retnet
self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) # fails
AssertionError: False is not true
======================================================================
FAIL: test_multiscale (__main__.TestRetention)
verify that the three implementations of MultiScaleRetention are identical
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/jamie/Repos/RetNet/src/tests.py", line 86, in test_multiscale
self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) # fails
AssertionError: False is not true
======================================================================
FAIL: test_simple (__main__.TestRetention)
verify that the three implementations of SimpleRetention are identical
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/jamie/Repos/RetNet/src/tests.py", line 45, in test_simple
assert torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5) # fails
AssertionError
----------------------------------------------------------------------
Ran 3 tests in 0.098s
FAILED (failures=3)
You could also refer to microsoft/torchscale@bf65397
@donglixp Thanks so much for your comment; it was critical to solving this issue.
There was also another term that is omitted from equation (7) in the paper but is present in the torchscale implementation. Please see line 85 of retention.py:
r_i = (K.transpose(-1, -2) @ (V * D[-1].view(1, chunk_size, 1))) + (self.gamma ** chunk_size) * r_i_1
In particular:
D[-1].view(1, chunk_size, 1)
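For anyone hitting the same mismatch, here is a minimal, self-contained sketch (plain tensors, single head, no gating or normalisation, and not using the repo's SimpleRetention class; the sizes T, B, d and the value of gamma are arbitrary choices for illustration) showing that once the D[-1] decay is applied to V in the cross-chunk state update, the chunkwise recurrence reproduces the parallel form:

import torch

torch.manual_seed(0)
T, B, d = 8, 4, 16          # sequence length, chunk size, head dim (arbitrary for this demo)
gamma = 0.9
Q, K, V = (torch.randn(T, d) for _ in range(3))

# Parallel form: Y = (Q K^T * D_full) V, with D_full[n, m] = gamma^(n-m) for n >= m, else 0
n = torch.arange(T, dtype=torch.float)
D_full = (gamma ** (n[:, None] - n[None, :])) * (n[:, None] >= n[None, :]).float()
Y_parallel = (Q @ K.T * D_full) @ V

# Chunkwise form
b = torch.arange(B, dtype=torch.float)
D = (gamma ** (b[:, None] - b[None, :])) * (b[:, None] >= b[None, :]).float()  # intra-chunk decay
R = torch.zeros(d, d)       # cross-chunk state
outs = []
for i in range(0, T, B):
    Qc, Kc, Vc = Q[i:i+B], K[i:i+B], V[i:i+B]
    inner = (Qc @ Kc.T * D) @ Vc                      # within-chunk retention
    cross = (Qc @ R) * (gamma ** (b + 1))[:, None]    # contribution of earlier chunks
    outs.append(inner + cross)
    # State update: the V * D[-1] factor is the decay term missing from the old equation (7)
    R = Kc.T @ (Vc * D[-1][:, None]) + (gamma ** B) * R
Y_chunkwise = torch.cat(outs)

assert torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)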
Equation (7) of the latest arXiv paper (https://arxiv.org/pdf/2307.08621v4.pdf) fixed the issue.
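For reference, the corrected chunkwise form as I read equation (7) in v4 (notation adapted slightly; B is the chunk size and b indexes rows within a chunk):

R_i = K_{[i]}^\top \left( V_{[i]} \odot \zeta \right) + \gamma^{B} R_{i-1}, \qquad \zeta_{bj} = \gamma^{B-b-1}

\mathrm{Retention}(X_{[i]}) = \left( Q_{[i]} K_{[i]}^\top \odot D \right) V_{[i]} + \left( Q_{[i]} R_{i-1} \right) \odot \xi, \qquad \xi_{bj} = \gamma^{b+1}

The \zeta factor applied to V_{[i]} corresponds to the D[-1].view(1, chunk_size, 1) term discussed above.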