The predictive learning of spatiotemporal sequences aims to generate future images by learning from the historical context, where the visual dynamics are believed to have modular structures that can be learned with compositional subsystems.
This repo contains a PyTorch implementation of PredRNN (2017) [paper], a recurrent network with a pair of memory cells that operate in nearly independent transition manners and together form unified representations of the complex environment.
Concretely, in addition to the original LSTM memory cell, the network features a zigzag memory flow that propagates in both bottom-up and top-down directions across all layers, enabling the visual dynamics learned at different levels of the RNN to communicate.
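The zigzag route can be illustrated with a small sketch. It assumes the shared memory is handed upward through the layers within one time step and then passed from the top layer down to the bottom layer of the next time step. The function name `zigzag_route` is hypothetical and is not part of this repo.

```python
def zigzag_route(num_layers, num_steps):
    """Return the (time step, layer) cells visited by the shared memory, in order."""
    route = []
    for t in range(num_steps):
        # Bottom-up: within one time step the memory moves from layer 0 to the top layer.
        for layer in range(num_layers):
            route.append((t, layer))
        # Top-down: the next iteration restarts at layer 0 of time step t + 1,
        # so the memory leaving the top layer feeds the bottom layer.
    return route


route = zigzag_route(num_layers=3, num_steps=2)
print(route)  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```

The original LSTM memory cell, by contrast, stays within its own layer and only moves forward in time.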
This repo also includes the implementation of PredRNN-V2 (2021) [paper], which improves PredRNN in the following two aspects.
We find that the two memory cells in PredRNN learn undesirable, redundant features, so we introduce a memory decoupling loss that encourages them to learn modular structures of visual dynamics.
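A minimal sketch of such a decoupling penalty follows. It assumes the loss is the absolute cosine similarity between the increments of the two memory cells, averaged over channels; the exact form used in this repo may differ. The function name `decoupling_loss` and its arguments are hypothetical, and plain Python lists stand in for tensors.

```python
import math


def decoupling_loss(delta_c, delta_m, eps=1e-12):
    """Mean absolute cosine similarity between per-channel memory increments.

    delta_c, delta_m: lists of channels, each channel a flat list of floats
    (assumed stand-ins for the increments of the two memory cells).
    """
    total = 0.0
    for dc, dm in zip(delta_c, delta_m):
        dot = sum(a * b for a, b in zip(dc, dm))
        norm_c = math.sqrt(sum(a * a for a in dc))
        norm_m = math.sqrt(sum(b * b for b in dm))
        # eps guards against division by zero for all-zero increments.
        total += abs(dot) / max(norm_c * norm_m, eps)
    return total / len(delta_c)


# Orthogonal increments incur no penalty; parallel increments incur the maximum.
orthogonal = decoupling_loss([[1.0, 0.0]], [[0.0, 1.0]])
aligned = decoupling_loss([[1.0, 2.0]], [[2.0, 4.0]])
print(orthogonal, aligned)
```

Minimizing this quantity pushes the two increments toward orthogonality, which is one way to discourage the cells from encoding the same features.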
Reverse scheduled sampling is a new curriculum learning strategy for seq-to-seq RNNs. As opposed to scheduled sampling, it gradually changes the training process of the PredRNN encoder from using the previously generated frame to using the previous ground-truth frame. Benefits: (1) It speeds up convergence by reducing the training gap between the encoder and the forecaster. (2) It forces the model to learn more from the long-term input context.
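The schedule described above can be sketched as follows. The linear ramp and the names `ground_truth_prob` and `choose_encoder_inputs` are assumptions for illustration only; the repo may use a different schedule. The generated frames are passed in as a ready-made list for simplicity, whereas a real training loop would produce them step by step.

```python
import random


def ground_truth_prob(step, total_steps, start=0.0, end=1.0):
    """Probability of feeding the ground-truth frame to the encoder (linear ramp, assumed)."""
    frac = min(max(step / total_steps, 0.0), 1.0)
    return start + (end - start) * frac


def choose_encoder_inputs(true_frames, generated_frames, prob_true, rng):
    """For each time step, pick the ground truth with probability prob_true,
    otherwise fall back to the previously generated frame."""
    return [
        true if rng.random() < prob_true else gen
        for true, gen in zip(true_frames, generated_frames)
    ]


rng = random.Random(0)
true_frames = ["gt0", "gt1", "gt2"]
generated_frames = ["gen0", "gen1", "gen2"]

# Early in training the encoder sees its own predictions; late in training it sees ground truth.
early = choose_encoder_inputs(true_frames, generated_frames, ground_truth_prob(0, 1000), rng)
late = choose_encoder_inputs(true_frames, generated_frames, ground_truth_prob(1000, 1000), rng)
print(early)  # ['gen0', 'gen1', 'gen2']
print(late)   # ['gt0', 'gt1', 'gt2']
```

Standard scheduled sampling moves in the opposite direction, starting from ground truth and shifting toward generated frames.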
LPIPS is more sensitive to human perceptual judgments; lower values are better.
| | Moving MNIST | KTH action |
|---|---|---|
| PredRNN | 0.109 | 0.204 |
| PredRNN-V2 | 0.071 | 0.139 |
- Install Python 3.7, PyTorch 1.3, and OpenCV 3.4.
- Download data. This repo contains code for two datasets: the Moving MNIST dataset and the KTH action dataset.
- Train the model. You can use the following bash scripts to train the model. The learned model will be saved in the `--save_dir` folder. The generated future frames will be saved in the `--gen_frm_dir` folder.
- You can get pretrained models from here.
cd mnist_script/
sh predrnn_mnist_train.sh
sh predrnn_v2_mnist_train.sh
cd kth_script/
sh predrnn_kth_train.sh
sh predrnn_v2_kth_train.sh
If you find this repo useful, please cite the following papers.
@inproceedings{wang2017predrnn,
title={{PredRNN}: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal {LSTM}s},
author={Wang, Yunbo and Long, Mingsheng and Wang, Jianmin and Gao, Zhifeng and Yu, Philip S},
booktitle={Advances in Neural Information Processing Systems},
pages={879--888},
year={2017}
}
@misc{wang2021predrnn,
title={{PredRNN}: A Recurrent Neural Network for Spatiotemporal Predictive Learning},
author={Wang, Yunbo and Wu, Haixu and Zhang, Jianjin and Gao, Zhifeng and Wang, Jianmin and Yu, Philip S and Long, Mingsheng},
year={2021},
eprint={2103.09504},
archivePrefix={arXiv},
}