Jamie-Stirling/RetNet

Chunkwise retention giving different output

Closed this issue · 4 comments

The implementation of the chunkwise retention paradigm on the chunkwise-real branch gives outputs that differ from the other two paradigms.

It appears there may be a mistake in equation (7) of the paper on which the implementation was based. A pull request fixing this so that the outputs are consistent with the other two paradigms would be greatly appreciated.

This can be reproduced by running `python src/tests.py`, with stdout:

FFF
======================================================================
FAIL: test_retnet (__main__.TestRetNet)
verify that the three implementations of RetNet are identical
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jamie/Repos/RetNet/src/tests.py", line 137, in test_retnet
    self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) # fails
AssertionError: False is not true

======================================================================
FAIL: test_multiscale (__main__.TestRetention)
verify that the three implementations of MultiScaleRetention are identical
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jamie/Repos/RetNet/src/tests.py", line 86, in test_multiscale
    self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) # fails
AssertionError: False is not true

======================================================================
FAIL: test_simple (__main__.TestRetention)
verify that the three implementations of SimpleRetention are identical
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jamie/Repos/RetNet/src/tests.py", line 45, in test_simple
    assert torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5) # fails
AssertionError

----------------------------------------------------------------------
Ran 3 tests in 0.098s

FAILED (failures=3)

You could also refer to microsoft/torchscale@bf65397

@donglixp Thanks so much for your comment, it was critical to solving this issue.

There is also another term that is omitted from equation (7) in the paper but is present in the torchscale implementation. Please see line 85 of retention.py:

r_i = (K.transpose(-1, -2) @ (V * D[-1].view(1, chunk_size, 1))) + (self.gamma ** chunk_size) * r_i_1

In particular:

D[-1].view(1, chunk_size, 1)
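
As an illustration of what that factor does, here is a minimal sketch of the cross-chunk state update, assuming unbatched tensors with `K` of shape `(chunk_size, d_k)`, `V` of shape `(chunk_size, d_v)`, and a previous state `r_i_1` of shape `(d_k, d_v)`. The repository's actual code is batched and per-head, so this is a simplification rather than the real implementation:

```python
import torch

def chunk_state_update(K, V, gamma, r_i_1):
    """Sketch of the cross-chunk state update with inner-chunk decay on V.

    Shapes (simplified, unbatched):
      K: (chunk_size, d_k), V: (chunk_size, d_v), r_i_1: (d_k, d_v).
    """
    chunk_size = K.shape[0]
    # Last row of the inner-chunk decay matrix D: position j in the chunk
    # still has gamma ** (chunk_size - 1 - j) of decay to apply before the
    # chunk boundary. This is the D[-1] factor in the quoted line.
    inner_decay = gamma ** (chunk_size - 1 - torch.arange(chunk_size, dtype=V.dtype))
    # Scale each value row by its remaining decay, fold it into the state,
    # and decay the previous state by gamma ** chunk_size over the chunk.
    return K.transpose(-1, -2) @ (V * inner_decay.view(chunk_size, 1)) \
        + (gamma ** chunk_size) * r_i_1
```

Without the `inner_decay` factor (the term missing from the original equation (7)), every position in the chunk would enter the state with weight 1, so earlier positions would be insufficiently decayed, which is exactly the mismatch the tests above detect.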

Equation (7) of the latest version of the arXiv paper (https://arxiv.org/pdf/2307.08621v4.pdf) fixes the issue.
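
For reference, the corrected cross-chunk recurrence in that revision is, up to indexing conventions, of the form below (with chunk size $B$), which matches the repository line quoted above:

$$R_i = K_{[i]}^\top \left( V_{[i]} \odot \zeta \right) + \gamma^{B} R_{i-1}, \qquad \zeta_{bj} = \gamma^{B-b-1}$$

Here $\zeta$ is the within-chunk decay applied to the values (the D[-1] factor above), and $\gamma^{B}$ is the decay applied to the previous state across the whole chunk.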