bentrevett/pytorch-seq2seq

Tutorial 4: Decoder - the calculation of prediction

actforjason opened this issue · 2 comments

Why use torch.cat((output, weighted, embedded), dim=1)?
Isn't just using output enough?

        # drop the sequence-length dimension of size 1: [1, batch size, ...] -> [batch size, ...]
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        
        # predict from the concatenation of decoder output, attention-weighted context and embedded input
        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))

We could just use output, but the notebook is replicating this paper, which calculates the prediction using the decoder hidden state (output), the attention-weighted context (weighted), and the current input word (embedded) - see appendix section 2.2.

Maybe output is enough in this case. Feel free to try it and let me know if the results are any different.
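For illustration, here is a minimal sketch comparing the two output layers, with made-up dimensions and assuming a bidirectional encoder (hence the enc_hid_dim * 2 context size, as in the tutorial). It only shows how the input size of fc_out changes between the two variants, not the full decoder.

```python
import torch
import torch.nn as nn

# Hypothetical dimensions, for illustration only.
emb_dim, enc_hid_dim, dec_hid_dim, output_dim = 256, 512, 512, 10_000
batch_size = 32

# Paper's formulation: the output layer sees the decoder state,
# the attention-weighted context and the embedded input token.
fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)

# Simplified alternative from the question: only the decoder state.
fc_out_simple = nn.Linear(dec_hid_dim, output_dim)

# Dummy tensors with the shapes they have after .squeeze(0) in the decoder.
embedded = torch.randn(batch_size, emb_dim)           # current input word
output = torch.randn(batch_size, dec_hid_dim)         # decoder hidden state
weighted = torch.randn(batch_size, enc_hid_dim * 2)   # attention-weighted context

prediction = fc_out(torch.cat((output, weighted, embedded), dim=1))
prediction_simple = fc_out_simple(output)

print(prediction.shape)         # torch.Size([32, 10000])
print(prediction_simple.shape)  # torch.Size([32, 10000])
```

Both variants produce a [batch size, output dim] prediction; the difference is only in how much information the final linear layer gets to condition on.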
Thank you, I got it.