Transformer - decoder block does not use encoder output for keys and values in attention mechanism
coxy1989 opened this issue · 1 comment
coxy1989 commented
From the transformer implementation here:
class DecoderBlock(nn.Module):
    level = TensorLoggingLevels.enc_dec_block

    def __init__(self, d_model=512, d_feature=64,
                 d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.masked_attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.layer_norm1 = LayerNorm(d_model)
        self.layer_norm2 = LayerNorm(d_model)
        self.layer_norm3 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out,
                src_mask=None, tgt_mask=None):
        # Apply attention to inputs
        att = self.masked_attn_head(x, x, x, mask=src_mask)
        x = x + self.dropout(self.layer_norm1(att))
        # Apply attention to the encoder outputs and outputs of the previous layer
        att = self.attn_head(queries=att, keys=x, values=x, mask=tgt_mask)
        x = x + self.dropout(self.layer_norm2(att))
        # Apply position-wise feedforward network
        pos = self.position_wise_feed_forward(x)
        x = x + self.dropout(self.layer_norm2(pos))
        return x
In the forward method, should:
att = self.attn_head(queries=att, keys=x, values=x, mask=tgt_mask)
Not be:
att = self.attn_head(queries=att, keys=enc_out, values=enc_out, mask=tgt_mask)
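For reference, here is a minimal sketch of the forward method with only that change applied, so queries come from the decoder while keys and values come from the encoder output. Everything else (mask arguments, layer-norm usage) is left exactly as in the quoted code:

    def forward(self, x, enc_out,
                src_mask=None, tgt_mask=None):
        # Masked self-attention over the decoder inputs
        att = self.masked_attn_head(x, x, x, mask=src_mask)
        x = x + self.dropout(self.layer_norm1(att))
        # Encoder-decoder attention: queries from the decoder,
        # keys and values from the encoder output
        att = self.attn_head(queries=att, keys=enc_out, values=enc_out, mask=tgt_mask)
        x = x + self.dropout(self.layer_norm2(att))
        # Position-wise feedforward network
        pos = self.position_wise_feed_forward(x)
        x = x + self.dropout(self.layer_norm2(pos))
        return x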
keitakurita commented
Nice catch, thanks for filing the issue! Fixed this in the most recent commit.