Would using FlashAttention(2) help?
jmschrei opened this issue · 4 comments
Hello!
Thank you for all your contributions.
I noticed that you have a built-in implementation of attention. I've been running into some memory limitations when using the full model. Do you think it would run faster and/or use less memory if the built-in attention were swapped out for FlashAttention?
Thanks!
yup, definitely
in fact, i think everyone is now forced to use flash attention for genomic training, given the SOTA results from hyena dna
for the pretrained model, the attention in enformer unfortunately cannot be swapped out for flash, because it uses an older style of relative positional encoding (shaw's). those relative-position logits are a learned per-position-pair term added to the full attention matrix, and the fused flash kernel never materializes that matrix, so there is nowhere to inject them (rough sketch below). however, if you are training from scratch, you should definitely fork and choose a flash-attention compatible RPE
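roughly, the issue looks like this - toy shapes only, and `rel_pos_logits` is just a stand-in for the shaw-style query / relative-position interaction, not enformer's actual module:

```python
import torch
import torch.nn.functional as F

b, h, n, d = 1, 8, 1024, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# shaw-style relative attention: a learned per-(i, j) positional term is added
# to the attention logits, so the full n x n score matrix has to exist in memory
rel_pos_logits = torch.randn(b, h, n, n)   # stand-in for the q - rel-pos interaction
scores = q @ k.transpose(-2, -1) * d ** -0.5 + rel_pos_logits
out_explicit = scores.softmax(dim = -1) @ v

# a flash-style fused kernel never materializes that n x n matrix - which is the
# whole point, and also why there is no place to inject the shaw-style term
out_fused = F.scaled_dot_product_attention(q, k, v)
```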
Thanks for the clarification. Would you mind going back in time and telling them to use a different relative positional encoding?
Sorry for the naive question, but what makes the positional encodings from Enformer incompatible with flash attention, and relatedly, which ones could work instead (also in the genomic setting)?
I'm behind on this lit, so any insight would be greatly appreciated, thanks!
@Vejni rotary embeddings are the tried and tested solution - the positions get baked into the queries and keys before attention, so the attention itself stays plain and can go through the fused flash kernel unchanged (minimal sketch below)
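for anyone landing here later, a minimal from-scratch sketch of the idea - the helper names `rotary_angles` and `apply_rotary` are illustrative, not from this repo:

```python
import torch
import torch.nn.functional as F

def rotary_angles(seq_len, dim, base = 10000):
    # one angle per (position, frequency) pair, covering half the head dimension
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len).float()
    freqs = torch.einsum('i,j->ij', t, inv_freq)   # (seq_len, dim // 2)
    return freqs.cos(), freqs.sin()

def apply_rotary(x, cos, sin):
    # x: (batch, heads, seq_len, head_dim) - rotate consecutive channel pairs
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim = -1).flatten(-2)

b, h, n, d = 1, 8, 1024, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

cos, sin = rotary_angles(n, d)
q, k = apply_rotary(q, cos, sin), apply_rotary(k, cos, sin)

# positions are already encoded in q and k, so no bias is needed and the fused
# (flash) kernel can be used - it dispatches on gpu with fp16 / bf16 inputs
out = F.scaled_dot_product_attention(q, k, v)
```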