lucidrains/magvit2-pytorch

Flash attention not working on A100 GPU

jpfeil opened this issue · 2 comments

jpfeil commented

I'm trying to train the model on ImageNet, but I'm running into issues getting the model and data to fit in GPU memory. I'm trying to use A100 GPUs, but when the trainer runs I get this error:

File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
  return forward_call(*args, **kwargs)
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 385, in forward
  x = super().forward(x, *args, **kwargs)
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 375, in forward
  out = self.attend(q, k, v)
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
  return forward_call(*args, **kwargs)
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/attend.py", line 235, in forward
  return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/attend.py", line 191, in flash_attn
  out = F.scaled_dot_product_attention(
RuntimeError: No available kernel.  Aborting execution

I think this is related to this issue: lucidrains/x-transformers#143

Is there a workaround for this issue?

Thank you!
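For reference, this error can be reproduced in isolation: PyTorch's flash kernel only supports fp16/bf16 inputs (and no explicit attention mask), so restricting scaled_dot_product_attention to the flash backend on fp32 tensors raises the same "No available kernel" failure. A minimal sketch, with arbitrary shapes, not taken from the repo:

import torch
import torch.nn.functional as F

# Arbitrary illustrative shapes: (batch, heads, seq_len, head_dim), fp32 on purpose
q = torch.randn(1, 8, 1024, 64, device='cuda')
k = torch.randn(1, 8, 1024, 64, device='cuda')
v = torch.randn(1, 8, 1024, 64, device='cuda')

# Restrict SDPA to the flash kernel only (PyTorch 2.0-2.2 context manager).
# Flash does not support fp32 inputs, so this raises
# "RuntimeError: No available kernel. Aborting execution".
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)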

timlenardo commented

I also ran into this issue using A100 GPUs. My workaround was to bypass flash attention by commenting out the following lines in "attend.py":

if device_properties.major == 8 and device_properties.minor == 0:
    print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
    self.cuda_config = Config(True, False, False)

Without that branch, it defaults to "math or mem efficient attention", based on the print statement on the lines that follow. Training works with those lines commented out!

I'm investigating further but figured I'd share this in case it's helpful for anyone in the meantime 🫡
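As an alternative to editing attend.py, the same fallback can be forced from the caller side with PyTorch's sdp_kernel context manager. A minimal sketch, assuming PyTorch 2.0-2.2 and that the model's attention ultimately goes through F.scaled_dot_product_attention:

import torch
import torch.nn.functional as F

# The same fp32 tensors that fail with flash-only (see the earlier sketch)...
q = torch.randn(1, 8, 1024, 64, device='cuda')
k = torch.randn(1, 8, 1024, 64, device='cuda')
v = torch.randn(1, 8, 1024, 64, device='cuda')

# ...run fine when the flash backend is disabled and SDPA is allowed to
# fall back to the math / memory-efficient kernels. The same context
# manager could be wrapped around a training step instead.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v)

Unlike editing attend.py, this affects every SDPA call inside the with-block, so it is a blunter but less invasive switch.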

lucidrains commented

@jacobpfeil @timlenardo yeah, i'm going to remove all the manual checks

researchers are telling me that pytorch 2.1 flash attention works much more seamlessly
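For reference, a minimal sketch of what dropping the manual checks can look like on PyTorch 2.1+: call F.scaled_dot_product_attention directly and let the dispatcher pick flash, memory-efficient, or math attention based on device, dtype, and arguments. The attend function and shapes below are illustrative, not the repo's actual code:

import torch
import torch.nn.functional as F

def attend(q, k, v, mask=None):
    # PyTorch 2.x selects the fastest available SDPA backend (flash,
    # memory-efficient, or math) for the given device/dtype/arguments,
    # so no per-GPU backend configuration is needed.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Illustrative half-precision inputs: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)
out = attend(q, k, v)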