Shape mismatch when number of heads > 1 and causal = False
Closed this issue · 1 comment
dmasmont commented
Hi,
I am hitting a shape mismatch in `long_short_transformer.py`, class `LongShortAttention`, line 183, when the number of heads > 1 and causal = False:
pkv.masked_fill_(~mask[..., None], mask_value)
The mask does not seem to be repeated across the number of heads.
Replacing the line with the following one fixed the bug:
pmask = repeat(mask, 'b n -> (b h) n', h = h)
pkv.masked_fill_(~pmask[..., None], mask_value)
Does that look correct to you?
lucidrains commented
@dmasmont https://github.com/lucidrains/long-short-transformer/releases/tag/0.0.5 thank you for catching this bug!