nshepperd/flash_attn_jax

Will it support grouped-query attention?

Closed this issue · 2 comments

flash_mha(q,k,v,softmax_scale=None, is_causal=False, window_size=(-1,-1))
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)

Could it accept inputs like the above, where k and v have nheads_k heads instead of nheads?

I worked on this today and it seems to work now, so I'll push a new release with it once I've verified that all tests pass.

Check https://github.com/nshepperd/flash_attn_jax/releases/tag/v0.2.0, it should work now. You can also install it with pip now: pip install flash-attn-jax==0.2.0. Note that the version on PyPI is only built for CUDA 12.3 because... idk, PyPI is limited like that.
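
For anyone reading this later, here is a rough sketch of what the grouped shapes in the question mean. It is a plain NumPy reference, not the library's kernel, and it runs without a GPU. The function name grouped_attention_reference is made up for illustration. It assumes that nheads is a multiple of nheads_k and that query head i reads key/value head i // (nheads // nheads_k); that mapping is the usual grouped-query convention, and this thread does not confirm it for flash_mha. The window_size argument is left out.

```python
import numpy as np


def grouped_attention_reference(q, k, v, softmax_scale=None, is_causal=False):
    """Reference grouped-query attention in plain NumPy (hypothetical helper).

    q:    (batch_size, seqlen, nheads, headdim)
    k, v: (batch_size, seqlen, nheads_k, headdim)
    Assumes nheads % nheads_k == 0.
    """
    _, seqlen, nheads, headdim = q.shape
    nheads_k = k.shape[2]
    if nheads % nheads_k != 0:
        raise ValueError("nheads must be a multiple of nheads_k")
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(headdim)

    # Assumed mapping: each key/value head is shared by a block of
    # consecutive query heads, so repeat it along the head axis.
    group = nheads // nheads_k
    k = np.repeat(k, group, axis=2)
    v = np.repeat(v, group, axis=2)

    # Scores have shape (batch, nheads, seqlen_q, seqlen_k).
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale
    if is_causal:
        # Position i may only attend to positions <= i.
        mask = np.tril(np.ones((seqlen, seqlen), dtype=bool))
        scores = np.where(mask, scores, -np.inf)

    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    # Output goes back to (batch, seqlen, nheads, headdim).
    return np.einsum("bhqk,bkhd->bqhd", weights, v)


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 5, 4, 8))  # nheads = 4
k = rng.standard_normal((2, 5, 2, 8))  # nheads_k = 2
v = rng.standard_normal((2, 5, 2, 8))
out = grouped_attention_reference(q, k, v)
print(out.shape)
```

The output has the same shape as q. On a machine with a supported GPU, one way to sanity-check a release would be to compare flash_mha(q, k, v) against a reference like this within a floating-point tolerance.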