
PyTorch implementations of ConvLSTM and ConvGRU modules, with examples.

Primary language: Python. License: MIT.


ConvLSTM and ConvGRU | PyTorch

Implementation of Convolutional LSTM and Convolutional GRU in PyTorch.

Inspired by this repository, but refactored and extended with new features, such as a peephole option, and with usage examples in the form of sequence-to-sequence video prediction models on the Moving MNIST dataset.
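For reference, these are the ConvLSTM equations with peephole connections as formulated by Shi et al. (2015), where $*$ is convolution and $\circ$ is the element-wise (Hadamard) product. Whether convlstm.py uses exactly this form (element-wise peephole weights $W_{c\cdot}$, with $W_{co}$ applied to $C_t$) is an assumption to check against the code; without peepholes, the $W_{c\cdot} \circ C$ terms are simply dropped.

```latex
\begin{aligned}
i_t &= \sigma\left(W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i\right) \\
f_t &= \sigma\left(W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f\right) \\
C_t &= f_t \circ C_{t-1} + i_t \circ \tanh\left(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c\right) \\
o_t &= \sigma\left(W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o\right) \\
H_t &= o_t \circ \tanh\left(C_t\right)
\end{aligned}
```

The two $\tanh$ calls are presumably what the activation argument of ConvLSTM controls, but that too is an assumption.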

How to Use

The ConvLSTM and ConvGRU modules inherit from torch.nn.Module.
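Because the modules are ordinary nn.Module subclasses, their convolution weights are registered as parameters, and they can be moved with .to(device) or nested inside larger models. The sketch below illustrates this with a minimal ConvLSTM cell. MinimalConvLSTMCell is a hypothetical class written for illustration: it omits the peephole and batchnorm options and is not the ConvLSTMCell from convlstm.py.

```python
import torch
import torch.nn as nn


class MinimalConvLSTMCell(nn.Module):
    """Illustrative ConvLSTM cell without peepholes or batchnorm.

    Hypothetical class, not the ConvLSTMCell from convlstm.py.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size=3):
        super().__init__()
        self.hidden_dim = hidden_dim
        # One convolution computes all four gates from [x, h] at once;
        # padding keeps the spatial size unchanged (odd kernel assumed).
        self.conv = nn.Conv2d(input_dim + hidden_dim, 4 * hidden_dim,
                              kernel_size, padding=kernel_size // 2)

    def forward(self, x, state):
        h, c = state  # each of shape (batch, hidden_dim, height, width)
        gates = self.conv(torch.cat([x, h], dim=1))
        i, f, o, g = torch.chunk(gates, 4, dim=1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        c_next = f * c + i * torch.tanh(g)  # update the cell state
        h_next = o * torch.tanh(c_next)     # expose part of it as output
        return h_next, c_next

    def init_state(self, batch_size, height, width):
        zeros = torch.zeros(batch_size, self.hidden_dim, height, width)
        return zeros, zeros.clone()


cell = MinimalConvLSTMCell(input_dim=1, hidden_dim=8)
state = cell.init_state(batch_size=2, height=16, width=16)
for t in range(5):  # unroll the cell over a 5-frame sequence
    state = cell(torch.randn(2, 1, 16, 16), state)
h, c = state
print(h.shape)  # torch.Size([2, 8, 16, 16])
```

Computing all four gates with a single convolution over the concatenated input and hidden state is a common design choice; it is equivalent to separate per-gate convolutions but needs only one kernel launch.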

ConvLSTM and ConvGRU support any number of layers. The hidden dimension (that is, the number of channels) and the kernel size can be specified for each layer; if there are several layers but only a single value is provided, that value is replicated for all of them. For example, in the following snippet each of the three layers has a different hidden dimension but the same kernel size.

Short usage snippet:

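import torch
from convlstm import ConvLSTM  # ConvLSTM class from convlstm.py in this repository

# Example values (illustrative only): CNN-encoded frame features with
# 64 channels at 16x16 resolution, 4 sequences of 10 time steps each
batch_size, seq_len = 4, 10
hidden_dim, hidden_spt = 64, 16
lstm_dims = [32, 64, 128]  # one hidden dimension per layer

conv_lstm_encoder = ConvLSTM(
    input_size=(hidden_spt, hidden_spt),
    input_dim=hidden_dim,
    hidden_dim=lstm_dims,
    kernel_size=(3, 3),     # single value, replicated for all three layers
    num_layers=3,
    peephole=True,
    batchnorm=False,
    batch_first=True,       # assumed layout: (batch, time, channels, height, width)
    activation=torch.tanh,  # preferred over the deprecated F.tanh
)

frames = torch.randn(batch_size, seq_len, hidden_dim, hidden_spt, hidden_spt)
hidden = conv_lstm_encoder.get_init_states(batch_size)
output, encoder_state = conv_lstm_encoder(frames, hidden)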
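The rule for replicating a single hyperparameter value across layers can be sketched as follows. The helper name extend_for_multilayer is hypothetical, and treating a list as per-layer values while any other value (including a (3, 3) tuple) counts as one shared value is an assumption about how convlstm.py tells the two cases apart.

```python
def extend_for_multilayer(param, num_layers):
    """Return one value per layer, replicating a single value if needed.

    Hypothetical helper illustrating the behaviour described above;
    a list is taken as per-layer values, anything else is shared.
    """
    if isinstance(param, list):
        if len(param) != num_layers:
            raise ValueError(f"expected {num_layers} values, got {len(param)}")
        return param
    return [param] * num_layers


num_layers = 3
hidden_dims = extend_for_multilayer([32, 64, 128], num_layers)  # per layer
kernel_sizes = extend_for_multilayer((3, 3), num_layers)        # shared
print(hidden_dims)   # [32, 64, 128]
print(kernel_sizes)  # [(3, 3), (3, 3), (3, 3)]
```

Using a list rather than a tuple to mark per-layer values avoids ambiguity, since a kernel size is itself naturally written as a tuple.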

Project Structure

Main Files

  • convlstm.py: contains the main classes, ConvLSTMCell (represents one "layer") and ConvLSTM
  • convgru.py: the same structure for ConvGRU (a cell class representing one layer, plus the ConvGRU module)
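For comparison with the ConvLSTM cell, here is a minimal sketch of one ConvGRU layer. MinimalConvGRUCell is a hypothetical class, not the one in convgru.py, and its gating follows the torch.nn.GRU convention, which may differ from the repository's. Unlike ConvLSTM, a GRU carries only a hidden state and no separate cell state.

```python
import torch
import torch.nn as nn


class MinimalConvGRUCell(nn.Module):
    """Illustrative ConvGRU cell; hypothetical, not the class in convgru.py."""

    def __init__(self, input_dim, hidden_dim, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2  # "same" padding for odd kernel sizes
        # Reset and update gates are computed together from [x, h].
        self.gates = nn.Conv2d(input_dim + hidden_dim, 2 * hidden_dim,
                               kernel_size, padding=padding)
        # The candidate state is computed from [x, r * h].
        self.candidate = nn.Conv2d(input_dim + hidden_dim, hidden_dim,
                                   kernel_size, padding=padding)

    def forward(self, x, h):
        gates = torch.sigmoid(self.gates(torch.cat([x, h], dim=1)))
        r, z = torch.chunk(gates, 2, dim=1)  # reset gate, update gate
        n = torch.tanh(self.candidate(torch.cat([x, r * h], dim=1)))
        # Interpolate between the candidate and the previous state,
        # using the same convention as torch.nn.GRU.
        return (1 - z) * n + z * h


cell = MinimalConvGRUCell(input_dim=1, hidden_dim=8)
h = torch.zeros(2, 8, 16, 16)  # only a hidden state, no cell state
for t in range(5):  # unroll over a 5-frame sequence
    h = cell(torch.randn(2, 1, 16, 16), h)
print(h.shape)  # torch.Size([2, 8, 16, 16])
```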

Other

  • train_gru_predictor.py and train_lstm_predictor.py: train video prediction models based on ConvGRU and ConvLSTM, respectively
  • cnn.py: contains simple convolutional networks that encode frames into feature representations and decode them back
  • bouncing_mnist.py: contains a dataloader that generates the Moving MNIST dataset from plain MNIST on the fly; use this raw MNIST dataset to reproduce the experiments
  • generate_test_set.py: used to generate test data for trained models
  • test.py: contains the tester for trained models
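As a rough illustration of generating Moving MNIST on the fly, each digit can follow a straight-line path that reflects off the canvas borders; each frame then pastes the digit at the current position. The function bouncing_trajectory and its defaults (64x64 canvas, 28x28 digits, speed 3) are assumptions for illustration, not taken from bouncing_mnist.py.

```python
import random


def bouncing_trajectory(num_frames, canvas=64, digit=28, speed=3, seed=None):
    """Top-left (x, y) positions of a digit bouncing inside a square canvas.

    Hypothetical sketch; sizes and speed are assumed defaults, not
    necessarily those used by bouncing_mnist.py.
    """
    rng = random.Random(seed)
    limit = canvas - digit  # largest top-left coordinate keeping the digit inside
    x, y = rng.randint(0, limit), rng.randint(0, limit)
    dx, dy = rng.choice((-speed, speed)), rng.choice((-speed, speed))
    positions = []
    for _ in range(num_frames):
        positions.append((x, y))
        x, y = x + dx, y + dy
        # Reflect off the borders (one reflection suffices while speed <= limit).
        if x < 0:
            x, dx = -x, -dx
        elif x > limit:
            x, dx = 2 * limit - x, -dx
        if y < 0:
            y, dy = -y, -dy
        elif y > limit:
            y, dy = 2 * limit - y, -dy
    return positions


path = bouncing_trajectory(20, seed=0)
print(len(path))  # 20
```

Seeding a private random.Random instance keeps sequences reproducible without touching the global random state.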

Prediction examples

In each group of three rows, the first row shows the previous frames fed to the model, the second the predicted frames, and the third the ground-truth (GT) future frames:

[Figure: Predictions]
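A three-row comparison like the one described above can be assembled with plain tensor concatenation. The helper comparison_grid is hypothetical, and the sketch assumes single-channel frames stored as (time, height, width) tensors, with the model seeing 10 frames and predicting 10; random noise stands in for real model predictions.

```python
import torch


def comparison_grid(previous, predicted, target):
    """Stack three rows of frames: inputs, predictions, ground truth.

    Each argument has shape (time, height, width) with the same time
    length; the result has shape (3 * height, time * width).
    Hypothetical helper for illustration.
    """
    rows = [torch.cat(frames.unbind(0), dim=1)  # frames side by side
            for frames in (previous, predicted, target)]
    return torch.cat(rows, dim=0)  # rows stacked vertically


clip = torch.rand(20, 64, 64)           # hypothetical 20-frame clip
previous, target = clip[:10], clip[10:]
predicted = torch.rand_like(target)     # stand-in for a model's output
grid = comparison_grid(previous, predicted, target)
print(grid.shape)  # torch.Size([192, 640])
```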