Flash attention not working on A100 GPU
jpfeil opened this issue · 2 comments
I'm trying to train the model on ImageNet, but I'm running into issues getting the model and data to fit in GPU memory. I'm using A100 GPUs, but when the trainer runs I get this error:
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
return forward_call(*args, **kwargs)
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 385, in forward
x = super().forward(x, *args, **kwargs)
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 375, in forward
out = self.attend(q, k, v)
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
return forward_call(*args, **kwargs)
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/attend.py", line 235, in forward
return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/attend.py", line 191, in flash_attn
out = F.scaled_dot_product_attention(
RuntimeError: No available kernel. Aborting execution
I think this is related to this issue: lucidrains/x-transformers#143
Is there a workaround for this issue?
Thank you!
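If it helps narrow things down, a minimal sketch like the following (my own repro attempt, not code from the repo) should hit the same error when only the flash backend is allowed and the inputs are fp32, since PyTorch's flash kernel only accepts fp16/bf16:

```python
# Hypothetical repro sketch, assuming PyTorch 2.0/2.1 where
# torch.backends.cuda.sdp_kernel is a context manager.
import torch
import torch.nn.functional as F

# fp32 inputs on purpose -- the flash kernel requires fp16/bf16
q = torch.randn(1, 8, 1024, 64, device='cuda')
k = torch.randn(1, 8, 1024, 64, device='cuda')
v = torch.randn(1, 8, 1024, 64, device='cuda')

# allow only the flash backend, similar to what Config(True, False, False) requests
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)  # RuntimeError: No available kernel
```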
I also ran into this issue using A100 GPUs. My workaround was to bypass flash attention by commenting out the following lines in `attend.py`:
```python
if device_properties.major == 8 and device_properties.minor == 0:
    print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
    self.cuda_config = Config(True, False, False)
```
Without that check, it falls back to "math or mem efficient attention" (per the print statement a few lines further down). Training works with those lines commented out!
I'm investigating further but figured I'd share this in case it's helpful for anyone in the meantime 🫡
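If you'd rather not edit the library, a rough alternative (assuming PyTorch 2.0/2.1, and that `attend.py` ultimately calls `F.scaled_dot_product_attention`) is to disable the flash backend from the outside with the `torch.backends.cuda.sdp_kernel` context manager:

```python
# Sketch of an external workaround: force the math / mem-efficient kernels,
# which accept fp32, instead of editing attend.py. Wrap the attention call
# (or the whole training step) in this context manager.
import torch
import torch.nn.functional as F

def attention_without_flash(q, k, v):
    with torch.backends.cuda.sdp_kernel(enable_flash=False,
                                        enable_math=True,
                                        enable_mem_efficient=True):
        return F.scaled_dot_product_attention(q, k, v)
```

Another option, if memory allows, might be to train in fp16/bf16 (e.g. via mixed precision) so the flash kernel's dtype requirement is met, though I haven't verified that with this repo.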
@jacobpfeil @timlenardo yeah, i'm going to remove all the manual checks
researchers are telling me that pytorch 2.1 flash attention works much more seamlessly
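For anyone landing here later, a rough sketch of what that simplification could look like (my assumption, not the actual patch): on PyTorch 2.1 you can just call `F.scaled_dot_product_attention` and let it dispatch, provided the inputs are fp16/bf16 if you want the flash kernel on an A100:

```python
# Sanity-check sketch: no per-GPU Config(...) bookkeeping, PyTorch picks the
# best available kernel for the device and dtype on its own.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.bfloat16)
k = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.bfloat16)
v = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.bfloat16)

out = F.scaled_dot_product_attention(q, k, v)  # dispatches automatically
print(out.shape, out.dtype)
```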