Faster attention calculation in 4-2.Seq2Seq?
shouldsee opened this issue · 1 comment
shouldsee commented
Thanks for sharing! I just noticed that Attention.get_att_weight
computes the attention scores in a Python for-loop; this looks rather slow, doesn't it?
4-2.Seq2Seq(Attention)/Seq2Seq(Attention).ipynb
def get_att_weight(self, dec_output, enc_outputs):  # get attention weight one 'dec_output' with 'enc_outputs'
    n_step = len(enc_outputs)
    attn_scores = torch.zeros(n_step)  # attn_scores : [n_step]

    for i in range(n_step):
        attn_scores[i] = self.get_att_score(dec_output, enc_outputs[i])

    # Normalize scores to weights in range 0 to 1
    return F.softmax(attn_scores).view(1, 1, -1)

def get_att_score(self, dec_output, enc_output):  # enc_outputs : [batch_size, num_directions(=1) * n_hidden]
    score = self.attn(enc_output)  # score : [batch_size, n_hidden]
    return torch.dot(dec_output.view(-1), score.view(-1))  # inner product makes a scalar value
Suggested parallel version:

def get_att_weight(self, dec_output, enc_outputs):  # get attention weight of one 'dec_output' with 'enc_outputs'
    # dec_output : [1, batch_size, n_hidden], enc_outputs : [n_step, batch_size, n_hidden]
    enc_t = self.attn(enc_outputs)                   # enc_t : [n_step, batch_size, n_hidden]
    score = dec_output.transpose(1, 0).bmm(          # [batch_size, 1, n_hidden]
        enc_t.transpose(1, 0).transpose(2, 1))       # [batch_size, n_hidden, n_step]
    return score.softmax(-1)                         # attention weights : [batch_size, 1, n_step]
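For reference, a minimal standalone check that the loop-based and batched computations agree. This is not from the notebook; the shapes (n_step=5, n_hidden=128, batch_size=1) and the local names attn, dec_output, enc_outputs are assumptions chosen to mirror the notebook's setup:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed shapes mirroring the notebook: batch_size=1, single decoder step
n_step, n_hidden = 5, 128
attn = nn.Linear(n_hidden, n_hidden)

dec_output = torch.randn(1, 1, n_hidden)        # [1, batch_size, n_hidden]
enc_outputs = torch.randn(n_step, 1, n_hidden)  # [n_step, batch_size, n_hidden]

# Loop version (as in the notebook): one dot product per encoder step
scores = torch.zeros(n_step)
for i in range(n_step):
    scores[i] = torch.dot(dec_output.view(-1), attn(enc_outputs[i]).view(-1))
loop_weights = F.softmax(scores, dim=0).view(1, 1, -1)      # [1, 1, n_step]

# Batched version (as suggested above): one bmm over all encoder steps
enc_t = attn(enc_outputs)                                   # [n_step, batch, n_hidden]
batched_weights = dec_output.transpose(1, 0).bmm(           # [batch, 1, n_hidden]
    enc_t.transpose(1, 0).transpose(2, 1)                   # [batch, n_hidden, n_step]
).softmax(-1)                                               # [batch, 1, n_step]

print(torch.allclose(loop_weights, batched_weights, atol=1e-6))  # expect True

With batch_size=1 the batched version returns the same [1, 1, n_step] tensor as the original while replacing n_step Python-level calls with a single bmm.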
Ekundayo39283 commented
You can create a pull request to update the code.