RUCAIBox/TextBox

Duplicate code in RNNVAE.py


```python
if self.rnn_type == "lstm":
    self.hidden_to_mean = nn.Linear(self.num_directions * self.hidden_size, self.latent_size)
    self.hidden_to_logvar = nn.Linear(self.num_directions * self.hidden_size, self.latent_size)
    self.latent_to_hidden = nn.Linear(self.latent_size, 2 * self.hidden_size)
elif self.rnn_type == 'gru' or self.rnn_type == 'rnn':
    self.hidden_to_mean = nn.Linear(self.num_directions * self.hidden_size, self.latent_size)
    self.hidden_to_logvar = nn.Linear(self.num_directions * self.hidden_size, self.latent_size)
    self.latent_to_hidden = nn.Linear(self.latent_size, 2 * self.hidden_size)
```
Both branches are identical. Is there meant to be any difference between the LSTM branch and the GRU/RNN branch?

Thanks for the correction. For GRU and RNN, `latent_to_hidden` should be `nn.Linear(self.latent_size, self.hidden_size)`.
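A minimal sketch of the corrected branching, assuming PyTorch and a single-layer, unidirectional decoder. The class `LatentBridge` and its method `init_decoder_state` are hypothetical names for illustration, not TextBox's actual API. The sketch also assumes the reason for the `2 *` factor: an `nn.LSTM` decoder is initialized with a `(h_0, c_0)` pair, so the latent projection is split in two, while `nn.GRU`/`nn.RNN` take only `h_0`. The encoder-side projections are shared, so they are hoisted out of the branch.

```python
import torch
import torch.nn as nn


class LatentBridge(nn.Module):
    """Hypothetical helper: encoder state -> (mean, logvar), and z -> decoder initial state."""

    def __init__(self, rnn_type, hidden_size, latent_size, num_directions=1):
        super().__init__()
        self.rnn_type = rnn_type
        self.hidden_size = hidden_size
        # Encoder-side projections are identical for every cell type, so no branch is needed.
        self.hidden_to_mean = nn.Linear(num_directions * hidden_size, latent_size)
        self.hidden_to_logvar = nn.Linear(num_directions * hidden_size, latent_size)
        if rnn_type == "lstm":
            # Assumed: LSTM decoder needs both h_0 and c_0, hence 2 * hidden_size.
            self.latent_to_hidden = nn.Linear(latent_size, 2 * hidden_size)
        elif rnn_type in ("gru", "rnn"):
            # GRU / vanilla RNN decoders only take h_0.
            self.latent_to_hidden = nn.Linear(latent_size, hidden_size)
        else:
            raise ValueError(f"unsupported rnn_type: {rnn_type}")

    def init_decoder_state(self, z):
        # z: (batch, latent_size) -> state of shape (num_layers=1, batch, hidden_size)
        hidden = self.latent_to_hidden(z).unsqueeze(0)
        if self.rnn_type == "lstm":
            # Split the 2 * hidden_size projection into h_0 and c_0.
            h_0, c_0 = torch.chunk(hidden, 2, dim=-1)
            return (h_0.contiguous(), c_0.contiguous())
        return hidden


# Usage: initialize a GRU decoder from a sampled latent vector.
bridge = LatentBridge("gru", hidden_size=16, latent_size=8)
decoder = nn.GRU(input_size=4, hidden_size=16, batch_first=True)
z = torch.randn(3, 8)
out, h_n = decoder(torch.randn(3, 5, 4), bridge.init_decoder_state(z))  # out: (3, 5, 16)
```

With the GRU/RNN projection at `hidden_size`, the state returned for those cell types already has the `(num_layers, batch, hidden_size)` shape the decoder expects, so no splitting step is needed.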