graykode/nlp-tutorial

Faster attention calculation in 4-2.Seq2Seq?

shouldsee opened this issue · 1 comment

Thanks for sharing! I just noticed that Attention.get_att_weight computes the attention scores in a Python for-loop; this looks rather slow, doesn't it?

4-2.Seq2Seq(Attention)/Seq2Seq(Attention).ipynb

    def get_att_weight(self, dec_output, enc_outputs):  # get attention weight one 'dec_output' with 'enc_outputs'
        n_step = len(enc_outputs)
        attn_scores = torch.zeros(n_step)  # attn_scores : [n_step]

        for i in range(n_step):
            attn_scores[i] = self.get_att_score(dec_output, enc_outputs[i])

        # Normalize scores to weights in range 0 to 1
        return F.softmax(attn_scores).view(1, 1, -1)

    def get_att_score(self, dec_output, enc_output):  # enc_outputs [batch_size, num_directions(=1) * n_hidden]
        score = self.attn(enc_output)  # score : [batch_size, n_hidden]
        return torch.dot(dec_output.view(-1), score.view(-1))  # inner product make scalar value

Suggested parallel version

    def get_att_weight(self, dec_output, enc_outputs):  # get attention weight one 'dec_output' with 'enc_outputs'
        # Project all encoder states at once, then score them against the
        # decoder state with a single batched matrix multiplication.
        enc_t = self.attn(enc_outputs)  # [n_step, batch_size, n_hidden]
        score = dec_output.transpose(1, 0).bmm(enc_t.transpose(1, 0).transpose(2, 1))  # [batch_size, 1, n_step]
        return score.softmax(-1)  # attention weights : [batch_size, 1, n_step]
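
For reference, here is a self-contained sketch checking that the batched bmm version gives the same weights as the loop. The sizes n_step=5, batch_size=1, n_hidden=128 are made up for illustration and are not taken from the notebook:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Illustrative sizes only (not taken from the notebook).
    n_step, batch_size, n_hidden = 5, 1, 128
    attn = nn.Linear(n_hidden, n_hidden)

    enc_outputs = torch.randn(n_step, batch_size, n_hidden)  # [n_step, batch_size, n_hidden]
    dec_output = torch.randn(1, batch_size, n_hidden)        # [1, batch_size, n_hidden]

    # Loop version, as in the notebook: one dot product per encoder step.
    attn_scores = torch.zeros(n_step)
    for i in range(n_step):
        attn_scores[i] = torch.dot(dec_output.view(-1), attn(enc_outputs[i]).view(-1))
    weights_loop = F.softmax(attn_scores, dim=0).view(1, 1, -1)

    # Batched version: one bmm over all encoder steps at once.
    enc_t = attn(enc_outputs)                                                       # [n_step, batch_size, n_hidden]
    score = dec_output.transpose(1, 0).bmm(enc_t.transpose(1, 0).transpose(2, 1))   # [batch_size, 1, n_step]
    weights_bmm = score.softmax(-1)

    print(torch.allclose(weights_loop, weights_bmm, atol=1e-6))  # should print True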

You can create a pull request to update the code.