nshepperd/flash_attn_jax

ModuleNotFoundError: No module named `flash_attn_jax.flash_api`


import flash_attn_jax.flash_api as flash_api
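One way to narrow this error down is to ask Python whether it can locate each module at all. The sketch below uses only the standard library's `importlib.util.find_spec`. The helper name `module_status` is hypothetical. It is an assumption, not something stated in this thread, that `flash_api` is a compiled extension which a plain clone of the repo does not contain.

```python
import importlib.util


def module_status(name: str) -> str:
    """Report whether `name` can be located (hypothetical diagnostic helper)."""
    try:
        # For a dotted name, find_spec imports the parent package first.
        spec = importlib.util.find_spec(name)
    except ImportError:
        # Raised when the parent package (e.g. `flash_attn_jax`) itself fails to import.
        return "missing (parent package could not be imported)"
    return "found" if spec is not None else "missing"


if __name__ == "__main__":
    for name in ("flash_attn_jax", "flash_attn_jax.flash_api"):
        print(f"{name}: {module_status(name)}")
```

If `flash_attn_jax` is found (for example, because you are running from the cloned source directory) but `flash_attn_jax.flash_api` is not, that points to an unbuilt or uninstalled package rather than a typo in the import.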

Did you install one of the released wheels with pip?

How can I do that? (I tried `pip install flash_att_jax` but got an error.)

I actually cloned the repo, because I couldn't work out how to install it from the instructions below:

To install: For now, download the appropriate release from the releases page and install it with pip.

Can you please tell me the steps to properly install it? Thanks

@nshepperd, could you please tell me how I can install the released wheels with pip?

Go here: https://github.com/nshepperd/flash_attn_jax/releases/tag/v0.1.0a3
Then find the wheel file that matches your Python version and the CUDA version you're using with JAX, and install it with `pip install`.
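To pick the right file, it helps to know your interpreter's wheel tag and which CUDA build of JAX you have. The sketch below is a heuristic, not part of flash_attn_jax: all helper names are hypothetical, and the distribution names `jax-cuda11-plugin` / `jax-cuda12-plugin` and the `+cuda...` local version tag on some GPU `jaxlib` builds are assumptions that vary across JAX releases. The filename check follows the standard wheel naming scheme (PEP 427), not any naming specific to this project's release assets.

```python
import subprocess
import sys
from importlib import metadata
from typing import List, Optional


def python_tag() -> str:
    # Wheel filenames encode the CPython version as e.g. "cp311" for Python 3.11.
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def installed_version(dist: str) -> Optional[str]:
    """Installed version of distribution `dist`, or None if it is not installed."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def cuda_hints() -> List[str]:
    """Heuristic CUDA hints from how jaxlib / a JAX CUDA plugin were installed (assumed names)."""
    hints = []
    jaxlib = installed_version("jaxlib")
    if jaxlib is not None and "cuda" in jaxlib:
        # Some GPU jaxlib builds carry the CUDA version in a local tag such as "+cuda12...".
        hints.append(f"jaxlib {jaxlib}")
    for dist in ("jax-cuda11-plugin", "jax-cuda12-plugin"):
        version = installed_version(dist)
        if version is not None:
            hints.append(f"{dist} {version}")
    return hints


def wheel_matches_python(filename: str) -> bool:
    """True if the wheel's python tag field includes this interpreter's tag."""
    if not filename.endswith(".whl"):
        return False
    # PEP 427 layout: name-version(-build)?-pythontag-abitag-platformtag.whl
    parts = filename[: -len(".whl")].split("-")
    if len(parts) < 5:
        return False
    # Compressed tag sets such as "py2.py3" are dot-separated.
    return python_tag() in parts[-3].split(".")


def pip_install_command(wheel_path: str) -> List[str]:
    # `sys.executable -m pip` installs into the same interpreter that will import the package.
    return [sys.executable, "-m", "pip", "install", wheel_path]


if __name__ == "__main__":
    print("Python tag:", python_tag())
    print("CUDA hints:", cuda_hints() or "none found")
    if len(sys.argv) > 1:
        wheel = sys.argv[1]
        if wheel_matches_python(wheel):
            subprocess.run(pip_install_command(wheel), check=True)
        else:
            print(f"{wheel} does not look built for {python_tag()}")
```

Running the installer through `sys.executable -m pip` rather than a bare `pip` avoids a common source of `ModuleNotFoundError`: installing the wheel into a different Python environment from the one that runs JAX.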

Thanks!