Question about multi-sample (batched) training in seq2seq(attention)-Torch
wmathor opened this issue · 0 comments
wmathor commented
Hello, first of all thank you for your code. If batch_size is greater than 1, how should I modify the code? Thank you. The relevant methods are:
```python
def get_att_weight(self, output, enc_output):  # get attention weights of one decoder 'output' against all 'enc_output' steps
    '''
    output: [1, batch_size, num_directions(=1) * n_hidden]
    enc_output: [n_step+1, batch_size, num_directions(=1) * n_hidden]
    '''
    length = len(enc_output)  # n_step+1
    attn_scores = torch.zeros(length)  # attn_scores : [n_step+1], one scalar score per encoder step (assumes batch_size == 1)
    for i in range(length):
        attn_scores[i] = self.get_att_score(output, enc_output[i])
    # Normalize scores to weights in range 0 to 1
    # return [batch_size, 1, n_step+1]
    return F.softmax(attn_scores, dim=0).view(batch_size, 1, -1)

def get_att_score(self, output, enc_output):
    '''
    output: [batch_size, num_directions(=1) * n_hidden]
    enc_output: [batch_size, num_directions(=1) * n_hidden]
    '''
    score = self.attn(enc_output)  # score : [batch_size, n_hidden]
    return torch.dot(output.view(-1), score.view(-1))  # inner product -> a single scalar, which only makes sense when batch_size == 1
```
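For reference, here is a minimal sketch of how these two methods could be vectorized so they also work when batch_size > 1. It assumes the tensor shapes given in the docstrings above and that `self.attn` is the tutorial's `nn.Linear(n_hidden, n_hidden)` scoring layer; treat it as a sketch under those assumptions, not an official fix:

```python
import torch
import torch.nn.functional as F

def get_att_weight(self, output, enc_output):
    '''
    Batched replacement for the loop above (sketch; assumes self.attn is Linear(n_hidden, n_hidden)).
    output:     [1, batch_size, num_directions(=1) * n_hidden]
    enc_output: [n_step+1, batch_size, num_directions(=1) * n_hidden]
    returns:    [batch_size, 1, n_step+1]
    '''
    score = self.attn(enc_output)                     # [n_step+1, batch_size, n_hidden]
    attn_scores = torch.bmm(output.transpose(0, 1),   # [batch_size, 1, n_hidden]
                            score.permute(1, 2, 0))   # [batch_size, n_hidden, n_step+1]
    return F.softmax(attn_scores, dim=-1)             # one weight distribution per batch element
```

The `torch.bmm` call computes one dot product per batch element in a single operation, so `get_att_score` (whose `torch.dot` collapses the batch dimension into one scalar) is no longer needed. The context vector for the decoder can then be obtained as `attn_weights.bmm(enc_output.transpose(0, 1))`, giving a `[batch_size, 1, n_hidden]` tensor.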