Dao-AILab/flash-attention

Need `tests/__init__.py` for `hopper/test_flash_attn.py`

Opened this issue · 2 comments

Description

Hi,

When I ran hopper/test_flash_attn.py for FA3, I encountered the following error.

~/flash-attention$ pytest -q -s hopper/test_flash_attn.py

=================================== ERRORS ===================================
_______________________ ERROR collecting test_flash_attn.py _______________________
ImportError while importing test module '/app-home/hancheol/flash-attention/hopper/test_flash_attn.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
../anaconda3/envs/mllm-engine-py-3.11-cuda-12.4/lib/python3.11/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
hopper/test_flash_attn.py:14: in <module>
    from tests.test_util import (
E   ModuleNotFoundError: No module named 'tests.test_util'
=========================== short test summary info ===========================
ERROR hopper/test_flash_attn.py
!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!
1 error in 1.58s

It seems that pytest can't import the tests/test_util.py module.

After I created tests/__init__.py, it worked as follows.

~/flash-attention$ touch tests/__init__.py
~/flash-attention$ pytest -q -s hopper/test_flash_attn.py
torch.float8_e4m3fn
causal False
local False
gqa_parallel False
...
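
By the way, an alternative workaround is to put the repository root on sys.path from a conftest.py instead of adding a package marker. This is just a sketch of a generic pytest workaround, not something the repo ships; the hopper/conftest.py file and its path handling below are my assumptions.

    # hopper/conftest.py (hypothetical): make the repo root importable so that
    # `from tests.test_util import ...` in hopper/test_flash_attn.py resolves.
    import sys
    from pathlib import Path

    # The repo root is one level above this hopper/ directory.
    REPO_ROOT = Path(__file__).resolve().parent.parent
    if str(REPO_ROOT) not in sys.path:
        sys.path.insert(0, str(REPO_ROOT))

With the repo root on sys.path, tests should be importable as a namespace package even without __init__.py. Running python -m pytest hopper/test_flash_attn.py from the repo root may have a similar effect, since python -m adds the current directory to sys.path.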

Best wishes,
Han-Cheol

I will check whether this solves the issue on my side too, thanks @hancheolcho!

seems to work, thanks!

Maybe create a PR, @hancheolcho?