[Question] very small attention scores
pfeatherstone opened this issue · 7 comments
This is a general question when training transformers, but is related to a specific question on this repo.
In my tests I've found, curiously, that the attention layer's dot product matrix (sometimes called "scores", sometimes "dots") has very small values. In my case, all the non-masked elements have values in the range [10^-9, 10^-8], while the masked elements have the value -3.4028*10^+38 as expected. The attention map, which is just softmax(dots), yields 0 for the masked elements and then exactly the same value for every non-masked element, presumably due to numerical imprecision. Has anybody seen this before? If so, what are potential workarounds? This repo offers the attn_qk_norm setting, which just L2 normalizes the queries and keys. Does this solve the problem? Many thanks
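For context, here's a standalone sketch of what I'm seeing (plain PyTorch, not this repo's code): when the raw scores are all within ~1e-8 of each other, the softmax can't tell them apart, whereas l2-normalizing q and k and applying a temperature keeps the dots O(1). The shapes and the temperature constant below are just illustrative.

```python
import torch
import torch.nn.functional as F

# toy example: pathologically small q/k activations, like what I'm seeing
q = torch.randn(1, 8, 128, 64) * 1e-4
k = torch.randn(1, 8, 128, 64) * 1e-4
scale = q.shape[-1] ** -0.5

dots = torch.einsum('bhid,bhjd->bhij', q, k) * scale
print(dots.abs().max())                   # ~1e-8: all scores effectively identical
print(F.softmax(dots, dim=-1)[0, 0, 0])   # ~uniform 1/128 for every position

# cosine-sim style: l2-normalize q and k, then apply a temperature
# (the 10. here is an illustrative constant, not this repo's default)
q_n, k_n = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
dots_n = torch.einsum('bhid,bhjd->bhij', q_n, k_n) * 10.
print(dots_n.abs().max())                 # O(1), so the softmax can discriminate
```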
I'll give it a go thanks
Funny how some publications can just be a case of: add conv there and see what happens
@lucidrains Good news, using attn_qk_norm seems to have solved my problem. Now all my attention "scores"/"dots" are order O(1), except for the masked elements which are -3.4028*10^+38 as expected. So the softmaxed attention map is now looking much more sensible.
It might be worth mentioning in the README that attn_qk_norm has this nice property. You already mention it can help with overflowing, but it seems it can also help with underflowing, or whatever this is.
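For reference, this is roughly how I'm turning it on (flag names as I read them from the README; worth double-checking against the current version):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_qk_norm = True,   # l2-normalize queries and keys before the dot product
        attn_flash = True      # keeping flash attention enabled
    )
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)   # (1, 1024, 20000)
```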
Unfortunately, talking_heads isn't compatible with flash attention, and I can't afford not to use flash attention. I also had a look at sparse_topk, thinking that would also help, but again, not compatible with flash attention. Makes sense.
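For anyone else who lands here: both of those options act on the materialised attention matrix, which flash attention never forms, hence the incompatibility. A rough sketch of what they do (plain PyTorch, not the repo's actual implementation):

```python
import torch
import torch.nn.functional as F

b, h, n = 1, 8, 128
dots = torch.randn(b, h, n, n)                 # the matrix flash attention never materialises

# talking heads: learned mixing across the head dimension pre- and post-softmax
pre_mix, post_mix = torch.randn(h, h), torch.randn(h, h)
attn = F.softmax(torch.einsum('bhij,hg->bgij', dots, pre_mix), dim=-1)
attn = torch.einsum('bhij,hg->bgij', attn, post_mix)

# sparse_topk: mask out everything but the top-k scores per query before the softmax
topk = 8
kth = dots.topk(topk, dim=-1).values[..., -1:]  # k-th largest score per query row
sparse = dots.masked_fill(dots < kth, -torch.finfo(dots.dtype).max)
attn_sparse = F.softmax(sparse, dim=-1)
```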
@pfeatherstone nice! yea i'm bullish on cosine sim attention. Tero Karras recently used it in his new u-net with great results
Makes you wonder, what percentage of a model is just some kind of normalization. Probably quite high. That seems like a flaw. Someone needs to invent a new neural network architecture where normalization is like < 1% of your layers.
What's the state of https://github.com/lucidrains/flash-cosine-sim-attention ? I like the idea of fusing flash attention with l2-normalized q/k.
Also, did you consider using https://github.com/NVIDIA/cutlass for the CUDA backend? I think Tri Dao used that library for Flash Attention 2, and it allowed him to write much more concise and ultimately better code (according to a podcast interview).