Applying decoder input mask?
maxmax1992 opened this issue · 2 comments
maxmax1992 commented
Hi,
I'm trying to implement the basic transformer architecture from "Attention Is All You Need", but with MultiHeadAttention replaced by performer_pytorch.SelfAttention. However, the mask expected for the decoder input is apparently not of shape n x n?
I've tried different setups without success. Any tips or ideas? I've only glanced through the paper.
lucidrains commented
@maxmax1992 Hi Maxim! You do not have to worry about passing in the NxN triangular mask for the decoder. Simply set the `causal` keyword argument to `True` and it will all be taken care of!
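For reference, a minimal sketch of what this looks like (the dimensions and hyperparameters here are illustrative, not from the issue; see the repo README for the full argument list):

```python
import torch
from performer_pytorch import SelfAttention

# causal=True makes the attention autoregressive internally,
# so no explicit n x n triangular mask needs to be passed in.
attn = SelfAttention(
    dim = 512,      # model dimension (illustrative)
    heads = 8,      # number of attention heads (illustrative)
    causal = True   # replaces the usual triangular decoder mask
)

x = torch.randn(1, 1024, 512)  # (batch, seq_len, dim)
out = attn(x)                  # (1, 1024, 512), causally masked
```

(The `mask` keyword on the forward pass, if used, appears to expect a per-token boolean padding mask of shape (batch, seq_len) rather than an n x n attention mask, which would explain the shape mismatch described above.)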
maxmax1992 commented
Closing this, as the solution is to pass `causal=True` to the `SelfAttention` class.