ModuleNotFoundError: No module named `flash_attn_jax.flash_api`
Closed this issue · 6 comments
VachanVY commented
nshepperd commented
Did you install one of the released wheels with pip?
VachanVY commented
How can I do that? (I tried pip install flash_att_jax but got an error.)
VachanVY commented
I actually cloned the repo, because I couldn't understand how to install it from the instructions below:
To install: For now, download the appropriate release from the releases page and install it with pip.
Could you please tell me the steps to install it properly? Thanks.
VachanVY commented
@nshepperd could you please tell me how to install the released wheels with pip?
nshepperd commented
Go here: https://github.com/nshepperd/flash_attn_jax/releases/tag/v0.1.0a3
Then find the wheel file that matches the Python version and CUDA version you're using with JAX, and install it with `pip install` followed by the path to the downloaded `.whl` file.
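To pick the right file from the release page, you need your interpreter's Python tag (the `cpXY` part of a wheel filename). Below is a minimal, stdlib-only sketch, assuming the release wheels follow the standard PEP 427 wheel naming; the helper names are hypothetical and not part of flash_attn_jax. It does not check the CUDA version, because how the release filenames encode it isn't stated here, so compare that by hand against the CUDA version your JAX install uses. `flash_api_available()` simply tries to import the module from the issue title, as a quick check that the install worked.

```python
import importlib
import sys


def python_tag() -> str:
    """CPython wheel tag for the running interpreter, e.g. 'cp311' on Python 3.11."""
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def parse_wheel_name(filename: str) -> dict:
    """Split a wheel filename into the fields defined by PEP 427:
    {name}-{version}(-{build})?-{python}-{abi}-{platform}.whl
    """
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel file: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:
        name, version, python, abi, plat = parts
        build = None
    elif len(parts) == 6:
        name, version, build, python, abi, plat = parts
    else:
        raise ValueError(f"unexpected wheel filename: {filename}")
    return {"name": name, "version": version, "build": build,
            "python": python, "abi": abi, "platform": plat}


def matching_wheels(filenames, py_tag=None):
    """Keep only wheels whose Python tag includes py_tag (default: this interpreter).

    A tag field may hold several dot-separated tags, e.g. 'cp310.cp311'.
    """
    py_tag = py_tag or python_tag()
    return [f for f in filenames
            if py_tag in parse_wheel_name(f)["python"].split(".")]


def flash_api_available() -> bool:
    """True if the module named in the issue title can be imported."""
    try:
        importlib.import_module("flash_attn_jax.flash_api")
        return True
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        return False


if __name__ == "__main__":
    print("Python wheel tag:", python_tag())
    print("flash_attn_jax.flash_api importable:", flash_api_available())
```

For example, given two hypothetical filenames, `matching_wheels(["flash_attn_jax-0.1.0a3-cp310-cp310-linux_x86_64.whl", "flash_attn_jax-0.1.0a3-cp311-cp311-linux_x86_64.whl"], "cp311")` keeps only the second one, which you would then install with `pip install` followed by its path.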
VachanVY commented
Thanks!