
[AAAI'24] "STDiff: Spatio-temporal Diffusion for Continuous Stochastic Video Prediction". Xi Ye, Guillaume-Alexandre Bilodeau


STDiff: Spatio-temporal diffusion for continuous stochastic video prediction

arXiv | code

STDiff_BAIR_15

Overview

[Figure: STDiff architecture]

Installation

  1. Install the custom diffusers library
git clone https://github.com/XiYe20/CustomDiffusers.git
cd CustomDiffusers
pip install -e .
  2. Install the requirements of STDiff (run this from the root of the STDiff repository, not from inside CustomDiffusers)
pip install -r requirements.txt
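Because CustomDiffusers is installed in editable mode, it is worth confirming that Python actually picks up the fork rather than a PyPI copy of diffusers. The following is a minimal sketch using only the standard library; it assumes the fork keeps the `diffusers` import name, and `check_packages` is a hypothetical helper, not part of STDiff.

```python
import importlib.util
from importlib.metadata import PackageNotFoundError, version


def check_packages(names):
    """Map each top-level package name to (version, import path), or None if missing.

    Uses find_spec so the packages are located without being imported.
    """
    report = {}
    for name in names:
        spec = importlib.util.find_spec(name)
        if spec is None:
            report[name] = None  # not importable in this environment
            continue
        try:
            ver = version(name)
        except PackageNotFoundError:
            ver = "unknown"  # importable, but no distribution metadata under this name
        report[name] = (ver, spec.origin)
    return report


if __name__ == "__main__":
    for pkg, info in check_packages(["diffusers", "accelerate", "torch"]).items():
        if info is None:
            print(f"{pkg}: NOT INSTALLED")
        else:
            ver, origin = info
            print(f"{pkg}: {ver} ({origin})")
```

Running it as a script prints each package's version and the file it would be imported from; for diffusers, that path should point inside your CustomDiffusers clone.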

Datasets

Processed KTH dataset: https://drive.google.com/file/d/1RbJyGrYdIp4ROy8r0M-lLAbAMxTRQ-sd/view?usp=sharing
SM-MNIST: https://drive.google.com/file/d/1eSpXRojBjvE4WoIgeplUznFyRyI3X64w/view?usp=drive_link

For the other datasets, please download them from their official websites. The expected folder structure for each dataset is shown below.

BAIR

Please download the original BAIR dataset and use the "/utils/read_BAIR_tfrecords.py" script to convert it into frames organized as follows:

/BAIR
     test/
         example_0/
            0000.png
            0001.png
            ...
         example_1/
            0000.png
            0001.png
            ...
         example_...
     train/
         example_0/
            0000.png
            0001.png
            ...
         example_...
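Before training, it can help to confirm that the converted frames match the layout above. This is a minimal sketch that assumes only the naming shown in the tree (four-digit, zero-based, contiguous frame numbers inside `example_*` folders); `check_bair_layout` is a hypothetical helper, not something shipped with STDiff.

```python
import re
from pathlib import Path

# Frame files look like 0000.png, 0001.png, ... as in the tree above.
FRAME_RE = re.compile(r"^\d{4}\.png$")


def check_bair_layout(root):
    """Return a list of problems found under root/{train,test}/example_*/."""
    root = Path(root)
    problems = []
    for split in ("train", "test"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing split folder: {split_dir}")
            continue
        examples = [
            d for d in split_dir.iterdir()
            if d.is_dir() and d.name.startswith("example_")
        ]
        if not examples:
            problems.append(f"no example_* folders in {split_dir}")
        for ex in sorted(examples):
            frames = sorted(p.name for p in ex.iterdir() if FRAME_RE.match(p.name))
            # Expect 0000.png .. N-1 with no gaps.
            expected = [f"{i:04d}.png" for i in range(len(frames))]
            if not frames:
                problems.append(f"{ex}: no frames")
            elif frames != expected:
                problems.append(f"{ex}: frame numbers are not contiguous from 0000")
    return problems
```

An empty list means every `example_*` folder in both splits contains a gap-free sequence starting at `0000.png`.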

Cityscapes

Please download "leftImg8bit_sequence_trainvaltest.zip" from the official website, center crop and resize all frames to 128x128, and save them as follows:

/Cityscapes
     test/
         berlin/
            berlin_000000_000000_leftImg8bit.png
            berlin_000000_000001_leftImg8bit.png
            ...
         bielefeld/
            bielefeld_000000_000302_leftImg8bit.png
            bielefeld_000000_000303_leftImg8bit.png
            ...
         ...
     train/
         aachen/
            ...
         bochum/
            ...
         ...
     val/
            ...
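The center crop and resize step can be sketched with Pillow as follows. This is a sketch under stated assumptions, not the authors' preprocessing code: it crops the largest centered square, resizes with bicubic resampling (the filter the authors used is not stated), and assumes `raw_root` is the extracted `leftImg8bit_sequence` folder so that the split/city/file structure above is preserved. The function names are hypothetical.

```python
from pathlib import Path


def center_crop_box(width, height):
    """Return the (left, upper, right, lower) box of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    upper = (height - side) // 2
    return (left, upper, left + side, upper + side)


def crop_and_resize(src_path, dst_path, size=128):
    # Pillow is imported here so center_crop_box can be used without it.
    from PIL import Image

    with Image.open(src_path) as im:
        square = im.convert("RGB").crop(center_crop_box(*im.size))
        # Assumption: bicubic resampling; the original filter is not documented.
        square.resize((size, size), Image.BICUBIC).save(dst_path)


def convert_split(raw_root, out_root, size=128):
    """Mirror raw_root/<split>/<city>/*.png into out_root at size x size."""
    raw_root, out_root = Path(raw_root), Path(out_root)
    for src in sorted(raw_root.rglob("*_leftImg8bit.png")):
        dst = out_root / src.relative_to(raw_root)  # keeps split/city/filename
        dst.parent.mkdir(parents=True, exist_ok=True)
        crop_and_resize(src, dst, size)
```

For a 2048x1024 Cityscapes frame, `center_crop_box` returns `(512, 0, 1536, 1024)`, the central 1024x1024 square, which is then downsampled to 128x128.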

KITTI

Please download the raw data (synced+rectified) from the official KITTI website. Center crop and resize all frames to 128x128, and save them as follows:

/KITTI
     2011_09_26_drive_0001_sync/
            0000000000.png
            0000000001.png
            ...
     2011_09_26_drive_0002_sync/
            ...
      ...
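The official raw KITTI download nests frames as `<date>/<drive>_sync/<camera>/data/<frame>.png`, whereas the layout above keeps only the drive folder. A sketch of that mapping follows. Which camera STDiff uses is not stated here, so `camera="image_02"` (the left color camera) is an assumption, and the helper names are hypothetical. Each frame still needs the center crop and resize described above before it is saved to the returned path.

```python
from pathlib import Path


def kitti_target_path(raw_frame, out_root, camera="image_02"):
    """Map .../<date>/<drive>_sync/<camera>/data/<frame>.png to out_root/<drive>_sync/<frame>.png."""
    raw_frame = Path(raw_frame)
    data_dir = raw_frame.parent
    cam_dir = data_dir.parent
    drive_dir = cam_dir.parent
    # Reject paths that do not follow the raw KITTI layout for the chosen camera.
    if data_dir.name != "data" or cam_dir.name != camera or not drive_dir.name.endswith("_sync"):
        raise ValueError(f"unexpected KITTI raw path: {raw_frame}")
    return Path(out_root) / drive_dir.name / raw_frame.name


def list_kitti_frames(raw_root, camera="image_02"):
    """List all frames of one camera under raw_root/<date>/<drive>_sync/."""
    return sorted(Path(raw_root).glob(f"*/*_sync/{camera}/data/*.png"))
```

Keeping the original frame file names (for example `0000000000.png`) preserves temporal order when the frames are later read back in sorted order.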

Training and Evaluation

STDiff uses Hugging Face Accelerate for training. The training and test configuration files for each dataset are located in stdiff/configs.
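When changing the number of processes, keep in mind that with Accelerate's default data loader preparation each process draws its own batch, so the number of samples per optimizer step grows with the number of processes. A minimal sketch of the arithmetic, assuming the batch size in the STDiff config is a per-device value (check your config file) and allowing for optional gradient accumulation; `effective_batch_size` is a hypothetical helper.

```python
def effective_batch_size(per_device_batch, num_processes, grad_accum_steps=1):
    """Samples consumed per optimizer step across all processes."""
    if min(per_device_batch, num_processes, grad_accum_steps) < 1:
        raise ValueError("all arguments must be >= 1")
    return per_device_batch * num_processes * grad_accum_steps
```

For example, a per-device batch of 16 on 2 GPUs without accumulation gives 32 samples per step, so halving the GPU count halves the effective batch unless the per-device value or accumulation steps are raised.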

Training

  1. Open train_script.sh, set the visible GPUs and num_process, and select the training config file for your dataset.
  2. Launch training:
. ./train_script.sh
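A common mistake when editing train_script.sh is letting the number of visible GPUs and num_process disagree. The sketch below assumes one process per visible GPU and builds an `accelerate launch` command with the real `--num_processes` and `--multi_gpu` flags. The script name and its arguments are placeholders, since the exact contents of train_script.sh are not reproduced here, and `launch_command` is a hypothetical helper.

```python
import shlex


def launch_command(visible_gpus, num_processes, script, *script_args):
    """Build a shell command after checking one process per visible GPU."""
    gpu_ids = [g.strip() for g in visible_gpus.split(",") if g.strip()]
    if num_processes != len(gpu_ids):
        raise ValueError(
            f"num_processes={num_processes} but {len(gpu_ids)} GPU(s) visible: {gpu_ids}"
        )
    cmd = ["accelerate", "launch", "--num_processes", str(num_processes)]
    if num_processes > 1:
        cmd.append("--multi_gpu")
    cmd += [script, *script_args]
    # shlex.join quotes any argument that needs it for a POSIX shell.
    return f"CUDA_VISIBLE_DEVICES={','.join(gpu_ids)} " + shlex.join(cmd)
```

With placeholder arguments, `launch_command("0,1", 2, "train.py", "--config", "my_config.yaml")` returns `CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes 2 --multi_gpu train.py --config my_config.yaml`.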

Test

  1. Open test_script.sh and select the test config file for your dataset.
  2. Run the test:
. ./test_script.sh

Citation

@inproceedings{ye2024stdiff,
  title={STDiff: Spatio-Temporal Diffusion for Continuous Stochastic Video Prediction},
  author={Ye, Xi and Bilodeau, Guillaume-Alexandre},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={38},
  number={7},
  pages={6666--6674},
  year={2024}
}

Uncurated prediction examples of STDiff for multiple datasets.

The temporal coordinate of each frame is shown in its top-left corner. Frames with red temporal coordinates are future frames predicted by our model.
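The labeling rule can be expressed as a tiny helper: a frame's coordinate is drawn in red when its time lies after the last observed frame. This is an illustrative sketch only. Coordinates may be fractional because STDiff predicts in continuous time, and the color used for observed frames ("white" here) is an assumption; `coordinate_colors` is a hypothetical helper.

```python
def coordinate_colors(times, last_observed_t, observed_color="white"):
    """Pair each frame's temporal coordinate label with the color used to draw it.

    Frames at or before last_observed_t are observed context; later frames are
    predictions and are labeled in red.
    """
    labels = []
    for t in times:
        color = observed_color if t <= last_observed_t else "red"
        labels.append((f"{t:g}", color))  # :g drops a trailing ".0"
    return labels
```

For observed frames at times 0 and 1 followed by predictions at 1.5 and 2, the helper labels the first two in the observed color and the last two in red.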

BAIR

STDiff_BAIR_0

STDiff_BAIR_15

SMMNIST

STDiff_SMMNIST_7

STDiff_SMMNIST_10

KITTI

STDiff_KITTI_0

STDiff_KITTI_22

Cityscapes

STDiff_City_110

STDiff_City_120