Does the stateful implementation work similarly to LSTM statefulness?
giridhar13 opened this issue · 1 comment
Thank you for the implementation. Can we use it to train video classifiers where the input sequences have variable length? Also, can we use the statefulness of the network to run inference on single frames (reusing the states from the previous input), similar to the LSTM implementation? Currently I use the following definition of ConvLSTM, where the input sequence length needs to be defined. If testing happens on a different sequence length, the metrics suffer (classification gets worse):
```python
import torch
from torch import nn


class ConvLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, kernel_size=3, stride=1, padding=1):
        super(ConvLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # A single convolution computes all four gates at once
        self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size,
                               kernel_size=kernel_size, stride=stride, padding=padding)
        torch.nn.init.xavier_normal_(self.Gates.weight)
        torch.nn.init.constant_(self.Gates.bias, 0)

    def forward(self, input_, prev_state):
        batch_size = input_.size(0)
        spatial_size = input_.size()[2:]

        # Zero initial state on the same device and dtype as the input
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            prev_state = (input_.new_zeros(state_size), input_.new_zeros(state_size))

        prev_hidden, prev_cell = prev_state
        stacked_inputs = torch.cat((input_, prev_hidden), 1)
        gates = self.Gates(stacked_inputs)

        in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)
        in_gate = torch.sigmoid(in_gate)
        remember_gate = torch.sigmoid(remember_gate)
        out_gate = torch.sigmoid(out_gate)
        cell_gate = torch.tanh(cell_gate)

        cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
        hidden = out_gate * torch.tanh(cell)
        return hidden, cell
```
and instantiate the layer with:

```python
conv_lstm = ConvLSTMCell(input_size, hidden_mem_size)
```
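The cell above consumes one timestep per call, so the sequence length is decided by the loop that drives it rather than by the layer itself. A minimal sketch, assuming a `(batch, time, channels, height, width)` layout and zero-padding (the helper `unroll` is a hypothetical name, and the compact cell reproduces the gating of the cell above): clips of different lengths can share one batch if padded steps are masked, so each sample keeps the state from its last real frame.

```python
import torch
from torch import nn


class ConvLSTMCell(nn.Module):
    """Compact single-timestep ConvLSTM cell (same gate order as the cell above)."""

    def __init__(self, input_size, hidden_size, kernel_size=3, padding=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size,
                               kernel_size=kernel_size, padding=padding)

    def forward(self, x, prev_state=None):
        if prev_state is None:
            zeros = x.new_zeros((x.size(0), self.hidden_size) + tuple(x.shape[2:]))
            prev_state = (zeros, zeros)
        h, c = prev_state
        i, f, o, g = self.gates(torch.cat((x, h), 1)).chunk(4, 1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c


def unroll(cell, frames, lengths):
    """Hypothetical helper: run `cell` over frames shaped (B, T, C, H, W).

    lengths[b] is the number of real (non-padding) frames of sample b.
    Padded steps leave that sample's state unchanged, so the returned
    (h, c) is each sample's state after its last real frame.
    """
    b, t_max, _, height, width = frames.shape
    h = frames.new_zeros((b, cell.hidden_size, height, width))
    c = torch.zeros_like(h)
    lengths = lengths.to(frames.device)
    for t in range(t_max):
        h_new, c_new = cell(frames[:, t], (h, c))
        # True where frame t is a real frame for that sample, False for padding
        valid = (lengths > t).view(b, 1, 1, 1)
        h = torch.where(valid, h_new, h)
        c = torch.where(valid, c_new, c)
    return h, c


# Usage: two clips with 3 and 5 frames in one batch
torch.manual_seed(0)
cell = ConvLSTMCell(input_size=3, hidden_size=8)
clip_a = torch.randn(3, 3, 16, 16)
clip_b = torch.randn(5, 3, 16, 16)
padded_a = torch.cat([clip_a, torch.zeros(2, 3, 16, 16)])  # zero-pad to 5 frames
batch = torch.stack([padded_a, clip_b])                     # (2, 5, 3, 16, 16)

with torch.no_grad():
    h, c = unroll(cell, batch, torch.tensor([3, 5]))
    # The same cell also handles clip_a alone, without any padding
    h_a, _ = unroll(cell, clip_a.unsqueeze(0), torch.tensor([3]))

print(torch.allclose(h[0], h_a[0], atol=1e-5))
```

Masking with `torch.where` keeps the loop simple; the trade-off is that padded steps are still computed and then discarded.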
But I am stuck with fixed input sequence sizes and statelessness. Does your implementation overcome these problems?
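For single-frame inference, one way to get LSTM-style statefulness is to keep `(h, c)` inside a module between calls. The sketch below is hypothetical (`StatefulConvLSTM` and `reset_state` are illustrative names, not an existing API, and the compact cell reproduces the gating of the cell above). It checks that feeding a clip one frame at a time gives the same final hidden state as unrolling the whole clip in one loop.

```python
import torch
from torch import nn


class ConvLSTMCell(nn.Module):
    """Compact single-timestep ConvLSTM cell (same gate order as the cell above)."""

    def __init__(self, input_size, hidden_size, kernel_size=3, padding=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size,
                               kernel_size=kernel_size, padding=padding)

    def forward(self, x, prev_state=None):
        if prev_state is None:
            zeros = x.new_zeros((x.size(0), self.hidden_size) + tuple(x.shape[2:]))
            prev_state = (zeros, zeros)
        h, c = prev_state
        i, f, o, g = self.gates(torch.cat((x, h), 1)).chunk(4, 1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c


class StatefulConvLSTM(nn.Module):
    """Hypothetical wrapper that keeps (h, c) between forward() calls."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.cell = ConvLSTMCell(input_size, hidden_size)
        self.state = None

    def reset_state(self):
        # Call at the start of every new video
        self.state = None

    def forward(self, frame):
        # frame: (B, C, H, W), a single timestep; the stored state is reused
        h, c = self.cell(frame, self.state)
        # detach() cuts the autograd graph between calls, so training with this
        # wrapper amounts to truncated backpropagation through time
        self.state = (h.detach(), c.detach())
        return h


torch.manual_seed(0)
model = StatefulConvLSTM(input_size=3, hidden_size=8).eval()
clip = torch.randn(6, 1, 3, 16, 16)  # 6 frames, batch of 1

with torch.no_grad():
    # Streaming: one frame per call, state carried inside the module
    model.reset_state()
    for frame in clip:
        h_stream = model(frame)

    # Reference: unroll the same cell over the whole clip in one loop
    state = None
    for frame in clip:
        state = model.cell(frame, state)
    h_full = state[0]

print(torch.allclose(h_stream, h_full))
```

`reset_state()` has to be called between unrelated videos; otherwise the state from one clip leaks into the next.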