graykode/nlp-tutorial

Question about multi-sample (batched) training in seq2seq(attention)-Torch

wmathor opened this issue · 0 comments

Hello, and first of all thank you for your code. If batch_size is more than 1, how should I modify the code below? Thank you!

    def get_att_weight(self, output, enc_output):  # score one decoder 'output' against every step of 'enc_output'
        '''
        output: [1, batch_size, num_directions(=1) * n_hidden]
        enc_output: [n_step+1, batch_size, num_directions(=1) * n_hidden]
        '''
        length = len(enc_output)  # length = n_step+1
        attn_scores = torch.zeros(length)  # attn_scores : [n_step+1]; one score per encoder step, so this only works when batch_size == 1
        for i in range(length):  # score the decoder state against each encoder step in turn
            attn_scores[i] = self.get_att_score(output, enc_output[i])

        # Normalize scores to weights that sum to 1
        # return : [1, 1, n_step+1], i.e. [batch_size, 1, n_step+1] only when batch_size == 1
        return F.softmax(attn_scores, dim=0).view(1, 1, -1)

    def get_att_score(self, output, enc_output):
        '''
        output: [batch_size, num_directions(=1) * n_hidden]
        enc_output: [batch_size, num_directions(=1) * n_hidden]
        '''
        score = self.attn(enc_output)  # score : [batch_size, n_hidden]
        return torch.dot(output.view(-1), score.view(-1))  # inner product gives a single scalar; flattening like this is only valid when batch_size == 1
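
For reference, the per-step loop and torch.dot can be replaced by one batched matrix multiplication, which handles batch_size > 1 directly. This is a minimal sketch, not the tutorial's official fix: it assumes self.attn is an nn.Linear(n_hidden, n_hidden) as in the tutorial, and the standalone name get_att_weight_batched is made up for illustration.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def get_att_weight_batched(attn, output, enc_output):
        '''
        attn: the nn.Linear(n_hidden, n_hidden) layer (self.attn in the snippet above)
        output: [1, batch_size, n_hidden]            current decoder hidden state
        enc_output: [n_step+1, batch_size, n_hidden]
        returns: attention weights of shape [batch_size, 1, n_step+1]
        '''
        score = attn(enc_output)                 # [n_step+1, batch_size, n_hidden]
        score = score.transpose(0, 1)            # [batch_size, n_step+1, n_hidden]
        query = output.squeeze(0).unsqueeze(2)   # [batch_size, n_hidden, 1]
        # batched dot product of every encoder step with the decoder state
        attn_scores = torch.bmm(score, query).transpose(1, 2)  # [batch_size, 1, n_step+1]
        return F.softmax(attn_scores, dim=2)     # normalize over the encoder steps

    # quick shape check
    n_step, batch_size, n_hidden = 5, 3, 128
    attn = nn.Linear(n_hidden, n_hidden)
    output = torch.randn(1, batch_size, n_hidden)
    enc_output = torch.randn(n_step + 1, batch_size, n_hidden)
    print(get_att_weight_batched(attn, output, enc_output).shape)  # torch.Size([3, 1, 6])

With batch_size == 1 this should produce the same weights as the loop version above, since torch.bmm computes exactly the per-step inner products that get_att_score computes one at a time.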