lucidrains/enformer-pytorch

Would using FlashAttention(2) help?

jmschrei opened this issue · 4 comments

Hello!

Thank you for all your contributions.

I noticed that you have a built-in implementation of attention. I've been running into some memory constraints when using the full model. Do you think it would run faster / use less memory if the built-in attention were swapped out for FlashAttention?

Thanks!

yup, definitely

in fact, i think everyone is now forced to use flash attention for genomic training, given the SOTA results from hyena dna

for the pretrained model, the attention in enformer unfortunately cannot be swapped out for flash attention, because it uses an older style of relative positional encoding (Shaw's, where a learned positional term is added to the attention logits). however, if you are training from scratch, you should definitely fork and choose a flash-attention-compatible relative positional encoding
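to make the incompatibility concrete, here's a simplified sketch (not the actual enformer code; the function name and the precomputed `rel_bias` tensor are just for illustration, and Shaw's full formulation has the queries interact with relative-position embeddings rather than a fixed bias). the key point is the same either way: the positional term has to be added to the logits before the softmax, which means materializing the full seq_len × seq_len attention matrix that fused flash kernels are designed to avoid

```python
import torch

def attention_with_relative_bias(q, k, v, rel_bias):
    # q, k, v:  (batch, heads, seq_len, dim_head)
    # rel_bias: (heads, seq_len, seq_len) learned positional term (simplified)
    scale = q.shape[-1] ** -0.5
    logits = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
    logits = logits + rel_bias          # positional term must enter pre-softmax
    attn = logits.softmax(dim=-1)       # full (seq_len x seq_len) matrix is materialized here
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
```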

Thanks for the clarification. Would you mind going back in time and telling them to use a different relative positional encoding?

Vejni commented

Sorry for the naive question, but what makes the positional encodings from Enformer incompatible with flash attention, and relatedly, which alternatives would work instead (particularly in the genomic setting)?

I'm behind on this lit, so any insight would be greatly appreciated, thanks!

@Vejni rotary embeddings are the tried and tested solution
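they play nicely with flash attention because they only rotate the queries and keys before the attention call, so no per-pair bias needs to enter the softmax and the fused kernel stays usable. a minimal sketch, assuming PyTorch >= 2.0 so `F.scaled_dot_product_attention` can dispatch to its flash backend (helper names here are just for illustration; in practice you'd probably reach for a library implementation of rotary embeddings):

```python
import torch
import torch.nn.functional as F

def rotary_freqs(seq_len, dim_head, device, base=10000):
    # standard rotary frequencies, shape (seq_len, dim_head)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim_head, 2, device=device).float() / dim_head))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.einsum('i, j -> i j', t, inv_freq)
    return torch.cat((freqs, freqs), dim=-1)

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x, freqs):
    # rotate queries / keys in place of any additive positional bias
    return x * freqs.cos() + rotate_half(x) * freqs.sin()

def flash_compatible_attention(q, k, v):
    # q, k, v: (batch, heads, seq_len, dim_head)
    freqs = rotary_freqs(q.shape[-2], q.shape[-1], q.device)
    q, k = apply_rotary(q, freqs), apply_rotary(k, freqs)
    # no attn_mask / bias needed, so the fused (flash) backend is eligible
    return F.scaled_dot_product_attention(q, k, v)
```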