wrong implementation for autoregressive self-attention
Sleepychord opened this issue · 10 comments
Hi, I found that you used fast_transformers's CUDA kernel, but it does not contain the normalization part, which needs a cumsum outside of `CausalDotProduct` (in `causal_linear_attention`).
If I'm not missing something, the result of your code should be wrong... but I am not 100% sure.
It also seems that `causal_linear_attention_noncuda` is wrong. Are you missing a `1. /` in `D_inv = torch.einsum('...nd,...nd->...n', q, k.cumsum(dim=-2))`?
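For concreteness, here is a minimal sketch of the normalization being asked about, assuming `q` and `k` have already been passed through the positive feature map (the function name and shapes are illustrative, not the repo's exact code):

```python
import torch

def causal_linear_attention_noncuda_sketch(q, k, v):
    # q, k: (..., n, d) feature-mapped queries/keys; v: (..., n, e) values

    # denominator: per-position q_i . (sum of k_j for j <= i) -- note the `1. /`
    k_cumsum = k.cumsum(dim=-2)                                   # (..., n, d)
    D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum)   # (..., n)

    # numerator: running sum of outer products k_j v_j^T
    context = torch.einsum('...nd,...ne->...nde', k, v)
    context = context.cumsum(dim=-3)                              # (..., n, d, e)

    # contract with queries and normalize each position
    out = torch.einsum('...nde,...nd,...n->...ne', context, q, D_inv)
    return out
```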
@Sleepychord oh my, yes, indeed there was a bug in the non-cuda version, thanks for catching that! https://github.com/lucidrains/performer-pytorch/releases/tag/0.11.2
as for EPFL's code, if you look at equation (4) in https://arxiv.org/pdf/2006.16236.pdf, i do believe they normalize
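For reference, the normalized causal form that paper works with is roughly the following (paraphrased rather than quoted from a specific equation number; φ is the feature map). The denominator is exactly the `q · cumsum(k)` term under discussion:

```latex
% causal linear attention with normalization (paraphrased from Katharopoulos et al., 2020)
V'_i = \frac{\phi(Q_i)^\top \sum_{j \le i} \phi(K_j)\, V_j^\top}
            {\phi(Q_i)^\top \sum_{j \le i} \phi(K_j)}
```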
@Sleepychord they keep track of the cumulative sum within the CUDA code, that is how they make it memory efficient
Actually they do not (I read the CUDA code before posting this issue). Just as I said above, they apply the cumsum outside the CUDA code, in a function called `causal_linear_attention`, because cumsum is not an expensive operation.
But you can easily solve it by adding a line of code similar to the first line in `causal_linear_attention_noncuda`.
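A minimal sketch of what that one-line fix might look like on the CUDA path, assuming `CausalDotProduct.apply(q, k, v)` returns the unnormalized causal numerator and that `q` and `k` are already feature-mapped (the wrapper name and the small `eps` guard are illustrative additions):

```python
import torch
from fast_transformers.causal_product import CausalDotProduct  # assumed import path

def causal_linear_attention_sketch(q, k, v, eps = 1e-6):
    # unnormalized numerator from the CUDA kernel:
    # out_i = phi(q_i)^T * sum_{j <= i} phi(k_j) v_j^T
    out = CausalDotProduct.apply(q, k, v)

    # normalization done outside the kernel, as in fast_transformers:
    # D_i = phi(q_i) . cumsum(phi(k))_i
    D_inv = 1. / (torch.einsum('...nd,...nd->...n', q, k.cumsum(dim=-2)) + eps)

    return torch.einsum('...ne,...n->...ne', out, D_inv)
```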
@Sleepychord omg! you are right https://github.com/idiap/fast-transformers/blob/d82aca326eb228fafdf2942c9c26c343f6d1b8ef/fast_transformers/attention/causal_linear_attention.py#L95 Fail on my part lol
Ok, will fix! Thank you so much!
@Sleepychord could you review #41?
Is that a Yu-Gi-Oh character?
Ahhhh, yeah, Yusei Fudo, the hero of 5D's.