berlino/gated_linear_attention

Question about masking


Hi, I am very new to Triton code, and I am curious how the causal mask is implemented. Is it implicitly assumed in the Triton code because you use the cumulative-sum form? In particular, how do this line and the line below implement the causal masking?

For inter-chunk ops, no causal mask is needed: two consecutive chunks do not overlap, so every position in an earlier chunk precedes every position in the current chunk.
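A minimal NumPy sketch of why the inter-chunk part needs no mask. It assumes plain (ungated) linear attention to keep things short, and the function name `inter_chunk_contribution` is hypothetical, not taken from the repo's Triton kernels. The block of the full causal mask that pairs current-chunk queries with earlier-chunk keys is all ones, so the state form and the masked quadratic form agree.

```python
import numpy as np

def inter_chunk_contribution(q_chunk, k_prev, v_prev):
    """Contribution of all earlier chunks to the current chunk's queries.
    Hypothetical helper; gates are omitted (simplifying assumption)."""
    # Summarize every earlier key/value pair in one d_k x d_v state: S = K_prev^T V_prev.
    state = k_prev.T @ v_prev
    # No mask: every key in k_prev sits strictly before every query in q_chunk.
    return q_chunk @ state

rng = np.random.default_rng(0)
T, C, d_k, d_v = 8, 4, 3, 2  # two chunks of length C
q, k = rng.standard_normal((T, d_k)), rng.standard_normal((T, d_k))
v = rng.standard_normal((T, d_v))

causal = np.tril(np.ones((T, T)))
# Block of the causal mask with queries from chunk 1 and keys from chunk 0.
block = causal[C:, :C]
print(block.all())  # the block contains no masked-out entries

# Masked quadratic attention restricted to that block agrees with the state form.
direct = ((q[C:] @ k[:C].T) * block) @ v[:C]
print(np.allclose(inter_chunk_contribution(q[C:], k[:C], v[:C]), direct))
```

Only the diagonal blocks of the causal mask (query and key in the same chunk) contain zeros, which is why masking is confined to the intra-chunk computation.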

For intra-chunk ops, the causal mask is applied explicitly here: https://github.com/berlino/gated_linear_attention/blob/main/kernels/intra_chunk_contribution/fn_only_gk.py#L205C1-L206C1
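To see both parts together, here is a NumPy sketch of chunked gated linear attention with a per-key-dimension gate, checked against the step-by-step recurrence S_t = diag(g_t) S_{t-1} + k_t^T v_t, o_t = q_t S_t. It is an illustration under assumptions, not the repo's kernel: the names `gla_recurrent`, `gla_chunked`, and `apply_mask` are hypothetical, and the within-chunk cumulative gate products are computed as exp of a cumulative sum of log-gates. Inside a chunk, the term b_i / b_j is only a valid decay for j <= i, so the lower-triangular mask is required there; without it, future keys leak in.

```python
import numpy as np

def gla_recurrent(q, k, v, g):
    """Reference recurrence: S_t = diag(g_t) S_{t-1} + k_t^T v_t, o_t = q_t S_t."""
    T, d_k = q.shape
    S = np.zeros((d_k, v.shape[1]))
    out = np.zeros((T, v.shape[1]))
    for t in range(T):
        S = g[t][:, None] * S + np.outer(k[t], v[t])
        out[t] = q[t] @ S
    return out

def gla_chunked(q, k, v, g, chunk, apply_mask=True):
    """Chunked form: unmasked inter-chunk term plus masked intra-chunk term."""
    T, d_k = q.shape
    S = np.zeros((d_k, v.shape[1]))  # state carried from all earlier chunks
    out = np.zeros((T, v.shape[1]))
    mask = np.tril(np.ones((chunk, chunk), dtype=bool))
    for start in range(0, T, chunk):
        sl = slice(start, start + chunk)
        qc, kc, vc, gc = q[sl], k[sl], v[sl], g[sl]
        n = len(qc)  # the last chunk may be shorter
        # b[i] = prod_{r <= i} g_r within this chunk, via a log-space cumulative sum.
        b = np.exp(np.cumsum(np.log(gc), axis=0))
        # Inter-chunk: earlier chunks are entirely in the past, so no mask.
        o_inter = (qc * b) @ S
        # Intra-chunk: A[i, j] = sum_d q_i[d] b_i[d] k_j[d] / b_j[d], valid only for j <= i.
        A = (qc * b) @ (kc / b).T
        if apply_mask:
            A = np.where(mask[:n, :n], A, 0.0)
        out[sl] = o_inter + A @ vc
        # Carry the state: decay it by the whole chunk, then add this chunk's keys/values.
        b_last = b[-1]
        S = b_last[:, None] * S + (kc * (b_last / b)).T @ vc
    return out

rng = np.random.default_rng(0)
T, d_k, d_v = 10, 3, 2
q, k = rng.standard_normal((T, d_k)), rng.standard_normal((T, d_k))
v = rng.standard_normal((T, d_v))
g = rng.uniform(0.8, 1.0, size=(T, d_k))  # gates in (0, 1)
ref = gla_recurrent(q, k, v, g)
print(np.allclose(gla_chunked(q, k, v, g, chunk=4), ref))
print(np.allclose(gla_chunked(q, k, v, g, chunk=4, apply_mask=False), ref))
```

Dividing by the cumulative products b_j is numerically fragile when chunks are long and gates are small, so this direct form is only reasonable for short chunks.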

Thanks a lot for the extremely prompt reply :)