softmax1/Flash-Attention-Softmax-N

No support for attn_mask != None in flash_attention_n


Thanks for your nice work first! When I use flash_attention_n, I hit a bug as soon as I change attn_mask from None to my attention_mask; the same call works fine with attn_mask=None. How can I fix it? Looking forward to your reply, thanks!

from flash_attention_softmax_n import flash_attention_n

attn_output_flash = flash_attention_n(
    query=query_states,
    key=key_states,
    value=value_states,
    softmax_n_param=self.softmax_n,
    scale=None,
    dropout_p=0.,
    attn_mask=attention_mask,  # passing a non-None mask here triggers the error below
    attn_bias=None,
    is_causal=False
)

The error message is:

 File "/home/sunpeiqin/projects/qkwaii_pretrain/training/modeling_llama.py", line 295, in forward
    attn_output_flash = flash_attention_n(
  File "/usr/local/lib/python3.10/dist-packages/flash_attention_softmax_n/core/flash_attn.py", line 117, in flash_attention_n
    return scaled_dot_product_attention(
RuntimeError: No available kernel. Aborting execution.
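
I suspect the root cause is that PyTorch's flash SDPA backend does not accept an explicit attn_mask, so once scaled_dot_product_attention is restricted to the flash kernel there is no kernel left to dispatch to. The snippet below is a minimal sketch that is independent of flash_attention_n (the tensor shapes are made up for illustration); as far as I can tell it hits the same constraint and should raise the same error:

import torch
import torch.nn.functional as F

# Made-up shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
# Additive float mask, broadcastable to (batch, heads, seq_len, seq_len)
mask = torch.zeros(1, 1, 128, 128, device="cuda", dtype=torch.float16)

# torch 2.2 API: restrict SDPA to the flash kernel only.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    # The flash kernel does not support an explicit attn_mask,
    # so this should fail with "No available kernel. Aborting execution."
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
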

The torch version is 2.2.0+cu121 and the device is an A100-80G.
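
In case it is useful while this is open, here is a sketch of the stopgap I am considering for the masked case. sdpa_with_mask is just a hypothetical helper name, and this loses the softmax_n behaviour because it falls back to plain SDPA (i.e. the standard softmax), so it is only a workaround, not a fix:

import torch
import torch.nn.functional as F

def sdpa_with_mask(query, key, value, attention_mask, dropout_p=0.0, scale=None):
    # The memory-efficient and math SDPA backends do accept an explicit attn_mask,
    # unlike the flash backend, so allow them and disable flash for this call.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
        return F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask,
            dropout_p=dropout_p,
            scale=scale,
        )

Is there a supported way to apply attention_mask without losing softmax_n, for example by folding the mask into attn_bias?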