PyTorch implementation of "Neural Discrete Representation Learning" (van den Oord et al., 2017)

VQ-VAE

VQ-VAE implementation based on PyTorch, PyTorch Lightning, anaconda-project and Hydra.
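At the heart of a VQ-VAE is the quantization step: each encoder output vector is replaced by its nearest entry in a learned codebook. The following pure-Python sketch shows only that nearest-neighbour lookup. It is not the code in model/, and the names `quantize` and `squared_distance` are hypothetical. It also leaves out the straight-through gradient estimator and the codebook/commitment losses that the paper uses during training.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def squared_distance(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(
    latents: Sequence[Vector], codebook: Sequence[Vector]
) -> Tuple[List[int], List[List[float]]]:
    """Map each latent vector to its nearest codebook entry.

    Returns the chosen code indices and the matching codebook vectors.
    Ties are resolved in favour of the lowest index.
    """
    indices: List[int] = []
    quantized: List[List[float]] = []
    for z in latents:
        # k = argmin_j ||z - e_j||^2 over all codebook entries e_j
        k = min(range(len(codebook)), key=lambda j: squared_distance(z, codebook[j]))
        indices.append(k)
        quantized.append(list(codebook[k]))
    return indices, quantized


codebook = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
latents = [[0.2, -0.1], [0.9, 1.3], [4.0, 6.0]]
indices, vectors = quantize(latents, codebook)
print(indices)  # [0, 1, 2]
```

In the real model the same lookup runs on batched tensors, and the decoder receives the quantized vectors in place of the raw encoder outputs.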

Install pip package

pip install vqvae

Note that the pip package contains only the model/ folder.
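To check what the pip package actually installed, you can inspect its distribution metadata with the standard library. This is a small sketch assuming the distribution name is `vqvae` (as in the pip command above). The helper names `installed_version` and `top_level_entries` are hypothetical and not part of the package.

```python
from importlib import metadata
from typing import List, Optional


def installed_version(dist_name: str = "vqvae") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def top_level_entries(dist_name: str = "vqvae") -> List[str]:
    """List the top-level folders and files recorded for an installed distribution."""
    try:
        files = metadata.files(dist_name) or []
    except metadata.PackageNotFoundError:
        return []
    # Each recorded file is a relative path; keep only its first component.
    return sorted({path.parts[0] for path in files})


print(installed_version(), top_level_entries())
```

If the package is installed, the second list shows which top-level folders it placed in site-packages, alongside its `.dist-info` metadata folder.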

Anaconda-project

  1. Clone the repository
git clone https://github.com/Michedev/VQ-VAE
  2. Install Anaconda if you don't already have it

Train

Train your model

anaconda-project run train-gpu

Note: the first run will download and install all dependencies.

You can also pass additional arguments, following the options defined in config/train.yaml. To train on CPU instead of GPU, use the train-cpu command:

anaconda-project run train-cpu  # train on cpu
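Because the project uses Hydra, extra arguments take the form of `key=value` overrides applied on top of config/train.yaml, with dots selecting nested keys. The sketch below is a simplified, pure-Python picture of that idea; it is not Hydra's implementation, and the keys `batch_size` and `model.lr` are hypothetical placeholders rather than the real contents of config/train.yaml. Like Hydra's default behaviour, it rejects keys that are not already in the config (Hydra requires a `+` prefix to add a new key, which this sketch does not support).

```python
import copy
from typing import Any, Dict, List


def parse_value(text: str) -> Any:
    """Interpret an override value as bool, int, float or plain string."""
    if text.lower() in ("true", "false"):
        return text.lower() == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Return a copy of `config` with dotted `key=value` overrides applied."""
    result = copy.deepcopy(config)  # leave the original config untouched
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = result
        for part in parents:
            node = node[part]  # raises KeyError for an unknown group
        if leaf not in node:
            raise KeyError(f"unknown config key {key!r}")
        node[leaf] = parse_value(raw)
    return result


# Hypothetical config standing in for config/train.yaml.
config = {"batch_size": 32, "model": {"lr": 0.001}}
print(apply_overrides(config, ["batch_size=64", "model.lr=3e-4"]))
# {'batch_size': 64, 'model': {'lr': 0.0003}}
```

Rejecting unknown keys is what turns a typo such as `bacth_size=64` into an immediate error instead of a silently ignored setting.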

Project structure

├── data  # Data storage folder
├── callbacks  # train/test callbacks
├── config
│   ├── dataset  # Dataset config
│   ├── model  # Model config
│   ├── model_dataset  # model and dataset specific config
│   ├── test.yaml   # testing configuration
│   └── train.yaml  # training configuration
├── dataset  # Dataset definition
├── model  # Model definition
│   └── callbacks  # model callbacks
├── utils
│   ├── experiment_tools.py # Iterate over experiments
│   └── paths.py  # common paths
├── train.py  # Entry point for training
├── test.py  # Entry point for testing
├── anaconda-project.yml  # Project configuration
├── saved_models  # where models are saved
└── readme.md  # This file
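The tree lists utils/paths.py as the home of common paths. As an illustration of that pattern only (the actual contents of utils/paths.py are not shown here), a paths helper can derive every project folder from the repository root so that no other module hard-codes locations. The function name `project_paths` and the dictionary keys are hypothetical.

```python
from pathlib import Path
from typing import Dict


def project_paths(root: Path) -> Dict[str, Path]:
    """Build the common project folders from the repository root."""
    return {
        "root": root,
        "config": root / "config",
        "data": root / "data",
        "saved_models": root / "saved_models",
    }


# Inside a module located at utils/paths.py, the repository root is two levels up:
# ROOT = Path(__file__).resolve().parent.parent
paths = project_paths(Path("/tmp/VQ-VAE"))
print(paths["saved_models"])
```

Keeping the root computation in one place means moving the repository, or running an entry point from another working directory, does not break any of the derived paths.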

Design keypoints

  • The root folder should contain only entry points and folders
  • Add tasks to anaconda-project.yml with the anaconda-project add-command command

Anaconda-project FAQ

How to add a new command?

Example:

anaconda-project add-command generate "python ddpm_pytorch/generate.py"

macOS support in the lock file

[Short] Run these commands:

anaconda-project remove-packages cudatoolkit;
anaconda-project add-platforms osx-64;

[Long]

  1. Remove the cudatoolkit dependency from anaconda-project.yml
anaconda-project remove-packages cudatoolkit
  2. Add the macOS platform to anaconda-project-lock.yml:
anaconda-project add-platforms osx-64