tensorflow/text

Transformer example masking is incorrect when using Keras MultiHeadAttention

JRice15 opened this issue · 1 comment

Referring to the Transformer example.

The script probably works correctly as is. However, in one place you mention that to use the tf.keras MultiHeadAttention layer instead of the custom one created here, the look-ahead mask has to be inverted (because the Keras layer attends to positions where the mask is 1, not 0). What is missing is that the padding mask also has to be inverted, since it is used for attention too, e.g. in the DecoderLayer. That change would also require changes to the loss function* and Transformer.create_masks() (and anywhere else the padding mask is used). Maintaining two versions of all these functions might be too much of a hassle, but at least adding comments where each change should be made would be useful. Even better would be to update the custom multi-head attention so that its mask convention is consistent with Keras's.

* I realize now that the loss mask is computed directly from the target and is not passed to the function, so that part is still fine.
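A minimal sketch of the inversion described above. It uses NumPy in place of TensorFlow ops so the arithmetic is easy to check. The two mask helpers are simplified stand-ins for the tutorial's functions of the same name, and to_keras_convention is a hypothetical helper, not part of the example. Reshaping or broadcasting the result to what the Keras layer expects for attention_mask is not shown.

```python
import numpy as np

def create_padding_mask(seq):
    # Tutorial convention: 1.0 marks a padding token (id 0) that must be masked OUT.
    mask = (seq == 0).astype(np.float32)
    return mask[:, np.newaxis, np.newaxis, :]  # (batch, 1, 1, seq_len)

def create_look_ahead_mask(size):
    # Tutorial convention: 1.0 marks a future position that must be masked OUT.
    return 1.0 - np.tril(np.ones((size, size), dtype=np.float32))

def to_keras_convention(mask):
    # Hypothetical helper: Keras MultiHeadAttention treats 1 as "attend"
    # and 0 as "ignore", so every attention mask has to be flipped.
    return 1.0 - mask

seq = np.array([[7, 6, 0, 0]])  # two real tokens followed by two padding tokens
padding = create_padding_mask(seq)
look_ahead = create_look_ahead_mask(seq.shape[1])

# Combine in the tutorial convention first (masked if either mask says so),
# then flip once, so both the padding and look-ahead parts are inverted.
combined = np.maximum(padding, look_ahead)   # (batch, 1, seq_len, seq_len)
keras_mask = to_keras_convention(combined)

# The encoder's padding-only mask needs the same flip.
keras_padding = to_keras_convention(padding)
```

Flipping only the look-ahead mask would leave the padding positions marked 1, which the Keras layer would then attend to.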

Thanks for pointing this out. Would you be willing to update the example and send a PR that explains this clearly for future readers?