3-3-bilstm-torch comment error
Tonybb9089 opened this issue · 1 comment
Tonybb9089 commented
class BiLSTM(nn.Module):
    """Bidirectional LSTM that scores classes from the last time step.

    Relies on module-level globals defined elsewhere in the file:
    ``n_class`` (LSTM input size and number of output classes),
    ``n_hidden`` (hidden size of each direction) and ``dtype`` (tensor
    type applied to the projection parameters ``W`` and ``b``).
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden, bidirectional=True)
        # Forward and backward outputs are concatenated by nn.LSTM, so the
        # projection consumes 2 * n_hidden features per example.
        self.W = nn.Parameter(torch.randn([n_hidden * 2, n_class]).type(dtype))
        self.b = nn.Parameter(torch.randn([n_class]).type(dtype))

    def forward(self, X):
        """Return class scores of shape [batch_size, n_class].

        X is expected batch-first as [batch_size, n_step, n_class]; it is
        transposed to the time-major layout that nn.LSTM uses by default.
        """
        x = X.transpose(0, 1)  # x : [n_step, batch_size, n_class]
        batch_size = len(X)
        # Zero initial states, created on the same device as the input.
        hidden_state = torch.zeros(1 * 2, batch_size, n_hidden, device=X.device)  # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
        cell_state = torch.zeros(1 * 2, batch_size, n_hidden, device=X.device)  # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
        outputs, (_, _) = self.lstm(x, (hidden_state, cell_state))
        # outputs : [n_step, batch_size, num_directions(=2) * n_hidden]
        # NOTE(review): at step -1 the backward direction has read only the
        # final input; concatenating forward[-1] with backward[0] (or using
        # h_n) would summarize the whole sequence. Kept as-is to preserve
        # the model's behavior.
        outputs = outputs[-1]  # [batch_size, 2 * n_hidden]
        model = torch.mm(outputs, self.W) + self.b  # model : [batch_size, n_class]
        return model
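Error in the comment on `outputs = outputs[-1]  # [batch_size, n_hidden]`:
because the LSTM is bidirectional, the forward and backward outputs are concatenated, so the shape should be [batch_size, 2*n_hidden].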
wmathor commented
Hey bro, I found this error too.
I think you are right.