lucidrains/PaLM-rlhf-pytorch

speed up with flash attn in A6000?

wac81 opened this issue · 2 comments

wac81 commented

Please check this:
https://www.reddit.com/r/StableDiffusion/comments/xmr3ic/speed_up_stable_diffusion_by_50_using_flash/

But in my case, the PaLM model does not speed up on an A6000 when the flash attention parameter is enabled.

PyTorch 2.0 Flash Attention requires an SM80 architecture. The A6000 has an SM86 architecture, so it is not currently supported. And just to clarify again, you cannot use a dim_head above 128.
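A minimal sketch of how one might gate the flash attention flag on the GPU's compute capability before building the model. The PaLM constructor call follows the repo's README; the exact argument values, and the use of `dim_head` and `flash_attn` as keyword arguments, are assumptions taken from that README, and the capability check itself is not part of the repo.

```python
import torch
from palm_rlhf_pytorch import PaLM

# A100 reports (8, 0) i.e. SM80; A6000 reports (8, 6) i.e. SM86
major, minor = torch.cuda.get_device_capability()
use_flash = (major, minor) == (8, 0)  # only enable flash attention where it is supported

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    dim_head = 64,          # keep dim_head at or below 128 for the flash kernel
    flash_attn = use_flash
).cuda()

seq = torch.randint(0, 20000, (1, 2048)).cuda()
loss = palm(seq, return_loss = True)
loss.backward()
```

With `use_flash` left False on an SM86 card, the model falls back to the regular attention path, which is why no speedup shows up there.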

wac81 commented

PyTorch 2.0 Flash Attention requires an SM80 architecture. The A6000 has an SM86 architecture, so it is not currently supported. And just to clarify again, you cannot use a dim_head above 128.

Thank you very much.