v2.1.2 causes ImportError with Rotary
The changes introduced in b28ec23 cause an ImportError:
from flash_attn.layers.rotary import apply_rotary_emb_func
File "/opt/conda/lib/python3.10/site-packages/flash_attn/layers/rotary.py", line 8, in <module>
from flash_attn.ops.triton.rotary import apply_rotary
ModuleNotFoundError: No module named 'flash_attn.ops.triton'
I would expect any added dependencies to be installed with the pip install. Currently using:
pip install flash-attn --no-build-isolation
pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary
Reverting to v2.1.1 solves this issue for me:
pip install flash-attn==2.1.1 --no-build-isolation
pip install git+https://github.com/HazyResearch/flash-attention.git@v2.1.1#subdirectory=csrc/rotary
Then we just need to wait until the fix gets merged into the main branch.
I think more than just the __init__.py file is missing. Creating that file wasn't enough to fix importing the rotary embeddings for me.
I think the entirety of ops/triton is missing from the PyPI distribution.
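A quick way to check whether the subpackage actually shipped with the installed wheel (assuming flash-attn is installed in the active environment) is to list the packaged files or retry the failing import directly:
pip show -f flash-attn | grep triton
python -c "from flash_attn.ops.triton.rotary import apply_rotary; print('flash_attn.ops.triton is present')"
If the grep prints nothing and the import still raises ModuleNotFoundError, the ops/triton files were indeed left out of the distribution on PyPI.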
Should be fixed in v2.2.0 (might take CI 9-10 hours to finish compiling all the CUDA wheels).
v2.2.0 was released, and ops/triton is successfully included :)
But I am getting the rotary issue #523 with a manually compiled 2.2.0, since the PyPI wheels are not ready yet.
Version 2.1.2.post3 has the same problem.
pip install flash-attn==2.1.1 --no-build-isolation hangs for a very long time in China, so could you provide another way to get the wheel that does not go through GitHub (which is blocked)?
> could you provide another way to get the wheel that does not go through GitHub?
We welcome contributions on this front.
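One possible workaround, sketched under the assumption that a domestic PyPI mirror is reachable and that a prebuilt wheel can be obtained through some other channel (the wheel filename below is only a placeholder):
pip install flash-attn==2.1.1 --no-build-isolation -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install ./flash_attn-<version>-<python-tag>-<platform-tag>.whl
The first command pulls the package from a mirror instead of pypi.org; the second installs a wheel that was downloaded once by other means, so no GitHub access is needed at install time.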