lucidrains/long-short-transformer

Shape mismatch when number of heads > 1 and causal = False

Closed this issue · 1 comment

Hi,

I am hitting a shape mismatch in long_short_transformer.py, class LongShortAttention, line 183, when the number of heads is > 1 and causal = False:
pkv.masked_fill_(~mask[..., None], mask_value)

It looks like the mask is not repeated across the heads before it is applied.
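A minimal standalone sketch of what I think is happening (the toy shapes b, h, n, d and the mask_value definition are my assumptions, based on the heads being folded into the batch dimension of pkv while the padding mask has one row per batch element):

```python
import torch

# assumed toy shapes: batch, heads, sequence length, head dim
b, h, n, d = 2, 4, 8, 16

# heads are folded into the batch dimension, as in the attention class
pkv = torch.randn(b * h, n, d)

# padding mask has one row per *batch* element, not per (batch, head) pair
mask = torch.ones(b, n, dtype = torch.bool)
mask_value = -torch.finfo(pkv.dtype).max

try:
    # ~mask[..., None] is (b, n, 1); its leading dim (b = 2) cannot
    # broadcast against pkv's merged leading dim (b * h = 8), so this raises
    pkv.masked_fill_(~mask[..., None], mask_value)
except RuntimeError as e:
    print(e)
```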

Replacing that line with the following fixed the bug:

pmask = repeat(mask, 'b n -> (b h) n', h = h)
pkv.masked_fill_(~pmask[..., None], mask_value)
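Continuing the toy sketch above (repeat here is einops.repeat, as used elsewhere in the repo), the repeated mask now lines up with the merged (b * h) leading dimension and the in-place fill goes through:

```python
from einops import repeat

# expand the per-batch mask across heads to match pkv's (b * h, n, d) layout
pmask = repeat(mask, 'b n -> (b h) n', h = h)    # shape: (b * h, n)
pkv.masked_fill_(~pmask[..., None], mask_value)  # broadcasts cleanly now
```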

Does that look correct to you?