Dao-AILab/flash-attention

How to install flash_attn with torch==2.1.0

foreverpiano opened this issue · 4 comments

I tried installing directly (pip) and building from source. Both fail with the traceback below.
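For context, the two attempts look roughly like this (a sketch from memory; the exact flags are assumptions):

# attempt 1: prebuilt wheel, reusing the torch already installed in the environment
pip install flash-attn --no-build-isolation

# attempt 2: build from source
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention
python setup.py install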

  File "/home/dhl/miniconda3/envs/torch2.1/lib/python3.10/site-packages/torch/ao/quantization/__init__.py", line 3, 
in <module>
    from .fake_quantize import *  # noqa: F403
  File "/home/dhl/miniconda3/envs/torch2.1/lib/python3.10/site-packages/torch/ao/quantization/fake_quantize.py", lin
e 8, in <module>
    from torch.ao.quantization.observer import (
  File "/home/dhl/miniconda3/envs/torch2.1/lib/python3.10/site-packages/torch/ao/quantization/observer.py", line 15,
 in <module>
    from torch.ao.quantization.utils import (
  File "/home/dhl/miniconda3/envs/torch2.1/lib/python3.10/site-packages/torch/ao/quantization/utils.py", line 12, in
 <module>
    from torch.fx import Node 
  File "/home/dhl/miniconda3/envs/torch2.1/lib/python3.10/site-packages/torch/fx/__init__.py", line 83, in <module>
    from .graph_module import GraphModule
  File "/home/dhl/miniconda3/envs/torch2.1/lib/python3.10/site-packages/torch/fx/graph_module.py", line 8, in <modul
e>
    from .graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
  File "/home/dhl/miniconda3/envs/torch2.1/lib/python3.10/site-packages/torch/fx/graph.py", line 2, in <module>
    from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
  File "/home/dhl/miniconda3/envs/torch2.1/lib/python3.10/site-packages/torch/fx/node.py", line 39, in <module>
    _ops.aten.sym_constrain_range_for_size.default,
  File "/home/dhl/miniconda3/envs/torch2.1/lib/python3.10/site-packages/torch/_ops.py", line 761, in __getattr__
    raise AttributeError(
AttributeError: '_OpNamespace' 'aten' object has no attribute 'sym_constrain_range_for_size'

env

accelerate==0.30.1
aiofiles==23.2.1
aiohttp==3.9.5
aiosignal==1.3.1
altair==5.3.0
annotated-types==0.7.0
anyio==4.4.0
async-timeout==4.0.3
attrs==23.2.0
certifi==2022.12.7
charset-normalizer==2.1.1
click==8.1.7
contourpy==1.2.1
cpm-kernels==1.0.11
cycler==0.12.1
distro==1.9.0
dnspython==2.6.1
docker-pycreds==0.4.0
email_validator==2.1.1
exceptiongroup==1.2.1
fastapi==0.111.0
fastapi-cli==0.0.4
ffmpy==0.3.2
filelock==3.13.1
fonttools==4.52.1
frozenlist==1.4.1
fschat==0.2.36
fsspec==2024.2.0
gitdb==4.0.11
GitPython==3.1.43
gradio==4.31.5
gradio_client==0.16.4
h11==0.14.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.23.2
idna==3.4
importlib_resources==6.4.0
Jinja2==3.1.3
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
-e git+https://github.com/RulinShao/LongChat-dev@3677918c376a6f5debddf1f2d74987e1b3ed93e4#egg=longchat
markdown-it-py==3.0.0
markdown2==2.4.13
MarkupSafe==2.1.5
matplotlib==3.9.0
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
networkx==3.2.1
nh3==0.2.17
numpy==1.26.3
openai==1.30.3
orjson==3.10.3
packaging==24.0
pandas==2.2.2
pillow==10.2.0
platformdirs==4.2.2
prompt-toolkit==3.0.43
protobuf==4.25.3
psutil==5.9.8
pydantic==2.7.1
pydantic_core==2.18.2
pydub==0.25.1
Pygments==2.18.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-multipart==0.0.9
pytz==2024.1
PyYAML==6.0.1
referencing==0.35.1
regex==2024.5.15
requests==2.32.2
rich==13.7.1
rpds-py==0.18.1
ruff==0.4.5
safetensors==0.4.3
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==2.3.1
setproctitle==1.3.3
shellingham==1.5.4
shortuuid==1.0.13
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
starlette==0.37.2
svgwrite==1.4.3
sympy==1.12
tiktoken==0.7.0
tokenizers==0.19.1
tomlkit==0.12.0
toolz==0.12.1
torch==2.1.0+cu121
torchaudio==2.1.0+cu121
torchvision==0.16.0+cu121
tqdm==4.66.4
transformers==4.41.1
triton==2.1.0
typer==0.12.3
typing_extensions==4.9.0
tzdata==2024.1
ujson==5.10.0
urllib3==2.2.1
uvicorn==0.29.0
uvloop==0.19.0
wandb==0.17.0
watchfiles==0.22.0
wavedrom==2.0.3.post3
wcwidth==0.2.13
websockets==11.0.3
yarl==1.9.4

Looks like the error message is from within pytorch.
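If it helps, a quick sanity check of the torch install (the commands below are only a suggestion; this kind of AttributeError often shows up when the Python sources and the compiled parts of torch in the environment don't match, e.g. after a partial upgrade or a mixed install):

# confirm which torch is resolved and that only one copy is installed
python -c "import torch; print(torch.__version__, torch.version.cuda)"
pip show torch
# reproduce the failing import chain directly
python -c "import torch.fx"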

I built a new conda environment, first installed torch 2.1.0 with CUDA 12.1, and then either installed flash_attn from PyPI or built it from scratch. Both show the same error. How can I fix it? The thing is, I need to use torch<2.1.1.
@tridao
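For reference, the environment setup was roughly this (a sketch; the cu121 index URL is the standard PyTorch wheel index, and the env name is just what I used):

conda create -n torch2.1 python=3.10 -y
conda activate torch2.1
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121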

Sorry, I can't help with fixing the torch error.

Have you successfully deployed flash_attn on a lower version of torch before? If so, could you share the script or log? I could use it as a reference.
I mean that building from an older commit of the flash-attention source may work, but I don't know which commit to choose.
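If pinning an older flash-attn release is the way to go, a sketch of building a tagged release from source (the tag below is only an example, not a known-good combination):

git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention
git tag --list                        # inspect available release tags
git checkout v2.3.3                   # example tag; pick one from around the torch 2.1.0 era
MAX_JOBS=4 python setup.py install    # MAX_JOBS caps parallel compile jobs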