No support for attn_mask != None in flash_attention_n
PeiqinSun commented
Thanks for your nice work! When I use flash_attention_n, I hit a bug as soon as I change attn_mask from None to attention_mask. How can I fix it? Looking forward to your reply, thanks!
from flash_attention_softmax_n import flash_attention_n

attn_output_flash = flash_attention_n(
    query=query_states,
    key=key_states,
    value=value_states,
    softmax_n_param=self.softmax_n,
    scale=None,
    dropout_p=0.,
    attn_mask=attention_mask,
    attn_bias=None,
    is_causal=False
)
The error message is:
File "/home/sunpeiqin/projects/qkwaii_pretrain/training/modeling_llama.py", line 295, in forward
attn_output_flash = flash_attention_n(
File "/usr/local/lib/python3.10/dist-packages/flash_attention_softmax_n/core/flash_attn.py", line 117, in flash_attention_n
return scaled_dot_product_attention(
RuntimeError: No available kernel. Aborting execution.
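
From some digging, I suspect the cause is that PyTorch's flash SDPA kernel does not accept an explicit attn_mask (it only supports is_causal), so if dispatch is restricted to the flash backend while a mask is passed, SDPA has no kernel left to run. Here is a minimal standalone sketch (my own experiment, not code from this repo; the tensor shapes are arbitrary) that reproduces the same error with plain torch:

import torch
import torch.nn.functional as F

# Arbitrary (batch, heads, seq_len, head_dim) fp16 tensors on the GPU.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
# Broadcastable additive mask in the same dtype as the queries.
mask = torch.zeros(1, 1, 128, 128, device="cuda", dtype=torch.float16)

# With only the flash backend enabled, passing attn_mask raises
# "RuntimeError: No available kernel. Aborting execution."
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask)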
The version of torch is '2.2.0+cu121' and the device is an A100-80G.
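
One workaround I'm considering (untested; it assumes flash_attention_n simply forwards its arguments to torch's scaled_dot_product_attention, as the traceback suggests, and does not pin the backend choice inside the library) is to allow the memory-efficient and math backends, which do accept an attn_mask, around the call:

import torch
from flash_attention_softmax_n import flash_attention_n

# Untested sketch: let SDPA fall back to backends that support attn_mask.
# The tensors and parameters are the same as in the failing call above.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
    attn_output_flash = flash_attention_n(
        query=query_states,
        key=key_states,
        value=value_states,
        softmax_n_param=self.softmax_n,
        scale=None,
        dropout_p=0.,
        attn_mask=attention_mask,
        attn_bias=None,
        is_causal=False
    )

Alternatively, if attn_bias in this library is added to the attention scores before the softmax, converting attention_mask to an additive float bias and passing it via attn_bias instead of attn_mask might also avoid the flash-kernel restriction, but I have not verified either approach.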