keitakurita/Practical_NLP_in_PyTorch

Transformer - decoder block does not use encoder output for keys and values in attention mechanism

coxy1989 opened this issue · 1 comment

From the transformer implementation here:

class DecoderBlock(nn.Module):
    level = TensorLoggingLevels.enc_dec_block
    def __init__(self, d_model=512, d_feature=64,
                 d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.masked_attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

        self.layer_norm1 = LayerNorm(d_model)
        self.layer_norm2 = LayerNorm(d_model)
        self.layer_norm3 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_out, 
                src_mask=None, tgt_mask=None):
        # Apply attention to inputs
        att = self.masked_attn_head(x, x, x, mask=src_mask)
        x = x + self.dropout(self.layer_norm1(att))
        # Apply attention to the encoder outputs and outputs of the previous layer
        att = self.attn_head(queries=att, keys=x, values=x, mask=tgt_mask)
        x = x + self.dropout(self.layer_norm2(att))
        # Apply position-wise feedforward network
        pos = self.position_wise_feed_forward(x)
        x = x + self.dropout(self.layer_norm2(pos))
        return x

In the forward method, shouldn't:

att = self.attn_head(queries=att, keys=x, values=x, mask=tgt_mask)

be:

att = self.attn_head(queries=att, keys=enc_out, values=enc_out, mask=tgt_mask)
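For reference, in standard encoder-decoder attention the decoder state supplies the queries while the encoder output supplies both the keys and the values. A minimal, runnable sketch of that pattern using PyTorch's built-in nn.MultiheadAttention (not the notebook's MultiHeadAttention class; the tensor shapes here are illustrative):

import torch
import torch.nn as nn

d_model, n_heads = 512, 8
cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

dec_state = torch.randn(2, 10, d_model)  # decoder-side activations (queries)
enc_out = torch.randn(2, 20, d_model)    # encoder output (keys and values)

# Queries come from the decoder; keys and values come from the encoder.
out, _ = cross_attn(query=dec_state, key=enc_out, value=enc_out)
print(out.shape)  # torch.Size([2, 10, 512])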

Nice catch, thanks for filing the issue! Fixed this in the most recent commit.
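For anyone reading this before pulling the fix, here is a sketch of what the corrected forward method of the DecoderBlock above presumably looks like, with the one change confirmed in this thread (keys and values taken from enc_out). As a further, unconfirmed guess, layer_norm3 (defined in __init__ but never used in the original) is applied to the feed-forward sublayer in place of the repeated layer_norm2; the mask arguments are left exactly as in the original:

    def forward(self, x, enc_out,
                src_mask=None, tgt_mask=None):
        # Masked self-attention over the decoder inputs
        att = self.masked_attn_head(x, x, x, mask=src_mask)
        x = x + self.dropout(self.layer_norm1(att))
        # Encoder-decoder attention: queries from the decoder,
        # keys and values from the encoder output
        att = self.attn_head(queries=att, keys=enc_out, values=enc_out, mask=tgt_mask)
        x = x + self.dropout(self.layer_norm2(att))
        # Position-wise feed-forward network
        pos = self.position_wise_feed_forward(x)
        x = x + self.dropout(self.layer_norm3(pos))
        return x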