Correct interaction between CLS token and RoPE
oasidorshin opened this issue · 5 comments
Sorry for asking here, but I couldn't find an answer to this either in the papers or in this repo.
What is the correct interaction between the CLS token and RoPE (or other positional encoding schemes)? Currently I just add the CLS token manually as the first token, but does its position matter in this case?
@oasidorshin hey Oleg! so the answer to this is that no one knows, and i haven't read any papers trying to make this work. it may turn out that the network just figures it out (the CLS token learns to ignore the rotations). you may be in a position to be the first to explore this and share your findings
however, the correct way would be a month's work fusing the rotation of queries and keys into the flash attention kernel. i imagine passing some hyperparameters that build a should-rotate mask, with rotations between the CLS token and all other tokens omitted. you can also do this manually within Attention by breaking off the CLS token from the queries and keys and building up the pre-softmax attention matrix that way
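a rough sketch of that manual route, just to make it concrete. everything here is illustrative: `rotate` stands in for whichever rotary module you use, the CLS token is assumed to sit at position 0, and this is the naive pre-softmax construction, not a fused kernel

```python
import torch

def attention_with_cls_excluded_from_rope(q, k, v, rotate, scale):
    # q, k, v: (batch, heads, seq, dim_head), with the CLS token at position 0
    # rotate: hypothetical callable applying RoPE to a (batch, heads, n, dim_head) tensor

    cls_q, rest_q = q[:, :, :1], q[:, :, 1:]
    cls_k, rest_k = k[:, :, :1], k[:, :, 1:]

    # rotated copies are only used for token <-> token interactions
    rot_q, rot_k = rotate(rest_q), rotate(rest_k)

    # CLS row: the unrotated CLS query against all unrotated keys
    cls_row = torch.einsum('b h i d, b h j d -> b h i j', cls_q, k)

    # token rows: unrotated token queries against the CLS key,
    # rotated token queries against rotated token keys
    tok_to_cls = torch.einsum('b h i d, b h j d -> b h i j', rest_q, cls_k)
    tok_to_tok = torch.einsum('b h i d, b h j d -> b h i j', rot_q, rot_k)
    tok_rows = torch.cat((tok_to_cls, tok_to_tok), dim = -1)

    # assemble the full pre-softmax matrix, then attend as usual
    sim = torch.cat((cls_row, tok_rows), dim = -2) * scale
    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
```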
@oasidorshin another alternative is to just use the CLS token to pool the representations across all tokens at the penultimate layer through cross attention. i've seen this used in some vision transformers with success
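for reference, something along these lines (purely illustrative, not code from this repo): a single learned CLS query cross-attends to the token representations at the penultimate layer, so it never touches RoPE at all

```python
import torch
from torch import nn

class CLSAttentionPool(nn.Module):
    # a learned CLS query cross-attends to all token representations,
    # so the CLS token itself never passes through rotary embeddings
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.cls_query = nn.Parameter(torch.randn(dim))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, tokens):                       # tokens: (batch, seq, dim)
        b = tokens.shape[0]
        q = self.to_q(self.cls_query).expand(b, 1, -1)
        k, v = self.to_kv(tokens).chunk(2, dim = -1)

        # split heads: (batch, seq, inner_dim) -> (batch, heads, seq, dim_head)
        split = lambda t: t.reshape(b, -1, self.heads, t.shape[-1] // self.heads).transpose(1, 2)
        q, k, v = map(split, (q, k, v))

        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)

        out = out.transpose(1, 2).reshape(b, 1, -1)
        return self.to_out(out).squeeze(1)           # pooled (batch, dim) representation
```

you would drop this in after the second-to-last block and feed the pooled output to your classification head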
@lucidrains Thanks for the great answers! Yeah, I agree that the most correct way would be to create special masks that disable rotations for CLS tokens, but that seems very complicated to do.
For people from the future: I just add both the CLS and memory tokens at the beginning, and it works quite well with RoPE (at least nothing breaks and the model learns well), though note that I'm working with custom sequences rather than text. I will add more here if I find anything else.
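For concreteness, a minimal sketch of that setup (names and defaults are just illustrative): the CLS and memory tokens are learned parameters prepended to the sequence, and RoPE then simply rotates them as the first positions.

```python
import torch
from torch import nn

class PrependedTokens(nn.Module):
    # prepend a learned CLS token plus a few memory tokens; downstream,
    # RoPE rotates them as positions 0..num_mem like any other tokens
    def __init__(self, dim, num_mem = 4):
        super().__init__()
        self.cls = nn.Parameter(torch.randn(1, 1, dim))
        self.mem = nn.Parameter(torch.randn(1, num_mem, dim))

    def forward(self, x):  # x: (batch, seq, dim)
        b = x.shape[0]
        prefix = torch.cat((self.cls, self.mem), dim = 1).expand(b, -1, -1)
        return torch.cat((prefix, x), dim = 1)
```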
@lucidrains Good idea about using the CLS token only at the penultimate layer, by the way, I'm going to remember that.
@oasidorshin sounds good, if you do discover that CLS tokens function well without much relative positional engineering, that is tweet worthy