VQ-VAE implementation based on PyTorch, PyTorch Lightning, anaconda-project and Hydra.
Install the model with pip:
pip install vqvae
Note that the pip package contains only the model/ folder. To use the full training pipeline:
- Clone the repository
git clone https://github.com/Michedev/VQ-VAE
- Install Anaconda if you don't have it already
Train your model
anaconda-project run train-gpu
Note: the first run will download and install all dependencies.
You can also pass additional arguments that override the values defined in config/train.yaml.
To train on the CPU instead of the GPU, run:
anaconda-project run train-cpu # train on cpu
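Hydra, which this project uses for configuration, accepts command-line overrides of the form key=value, where dots address nested entries (for example model.lr=3e-4). The pure-Python sketch below mimics only that basic behaviour so it runs without Hydra installed. The helpers apply_overrides and parse_value are hypothetical, and the keys batch_size, max_epochs and model.lr are made-up examples; check config/train.yaml for the real names.

```python
import ast
import copy
from typing import Any


def parse_value(raw: str) -> Any:
    """Turn an override value into a Python object when it looks like a literal."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"  # Hydra-style lowercase booleans
    if lowered == "null":
        return None
    try:
        return ast.literal_eval(raw)  # ints, floats, lists, quoted strings...
    except (ValueError, SyntaxError):
        return raw  # fall back to a plain string, e.g. a dataset name


def apply_overrides(cfg: dict, overrides: list[str]) -> dict:
    """Return a copy of cfg with dotted key=value overrides applied (simplified, not Hydra itself)."""
    result = copy.deepcopy(cfg)
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"override must look like key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = result
        for name in parents:
            node = node[name]  # unknown parent group -> KeyError
        if leaf not in node:
            # Hydra also rejects unknown keys unless they are prefixed with '+'
            raise KeyError(f"unknown config key: {key}")
        node[leaf] = parse_value(raw)
    return result


defaults = {"batch_size": 32, "max_epochs": 100, "model": {"lr": 1e-3}}
print(apply_overrides(defaults, ["batch_size=64", "model.lr=3e-4"]))
# {'batch_size': 64, 'max_epochs': 100, 'model': {'lr': 0.0003}}
```

The copy keeps the defaults untouched, so the same base configuration can be reused for several runs with different overrides.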
Project structure:
├── data # Data storage folder
├── callbacks # train/test callbacks
├── config
│ ├── dataset # Dataset config
│ ├── model # Model config
│ ├── model_dataset # model- and dataset-specific config
│ ├── test.yaml # testing configuration
│ └── train.yaml # training configuration
├── dataset # Dataset definition
├── model # Model definition
│ └── callbacks # model callbacks
├── utils
│ ├── experiment_tools.py # Iterate over experiments
│ └── paths.py # common paths
├── train.py # Entry point for training
├── test.py # Entry point for testing
├── anaconda-project.yml # Project configuration
├── saved_models # where models are saved
└── readme.md # This file
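The tree lists utils/paths.py as the home of common paths and utils/experiment_tools.py as the place for iterating over experiments. The sketch below is one hypothetical way such helpers could look, assuming every path is resolved relative to the repository root and that saved_models/ holds one subfolder per run. The names ProjectPaths, iter_experiments and default_paths are illustrative assumptions, not the repository's actual code.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator


@dataclass(frozen=True)
class ProjectPaths:
    """Hypothetical central registry of the project folders."""

    root: Path

    @property
    def data(self) -> Path:
        return self.root / "data"

    @property
    def config(self) -> Path:
        return self.root / "config"

    @property
    def saved_models(self) -> Path:
        return self.root / "saved_models"

    def iter_experiments(self) -> Iterator[Path]:
        """Yield each run folder inside saved_models/, sorted by name (assumed layout)."""
        if not self.saved_models.is_dir():
            return  # nothing has been trained yet
        yield from sorted(p for p in self.saved_models.iterdir() if p.is_dir())


def default_paths() -> ProjectPaths:
    # Assumes this module lives in utils/, so the repository root is the parent of utils/.
    return ProjectPaths(Path(__file__).resolve().parent.parent)
```

Keeping every location in one place means that entry points such as train.py and test.py would not need to hard-code folder names.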
- The root folder should contain only entry points and folders
- Add tasks to anaconda-project.yml with the command
anaconda-project add-command
Example:
anaconda-project add-command generate "python ddpm_pytorch/generate.py"
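add-command takes the task name and the command to run as two separate arguments, so the whole command must be one quoted string with its closing quote in place. Python's standard shlex module, which splits text using POSIX shell-like quoting rules, gives a quick way to check how the example line is tokenized. This is only an illustration, not part of the project.

```python
import shlex

# Correctly quoted: the command to run arrives as a single argument.
good = 'anaconda-project add-command generate "python ddpm_pytorch/generate.py"'
argv = shlex.split(good)
print(argv)
# ['anaconda-project', 'add-command', 'generate', 'python ddpm_pytorch/generate.py']

# Missing closing quote: shlex rejects the line
# (an interactive shell would instead wait for more input).
bad = 'anaconda-project add-command generate "python ddpm_pytorch/generate.py'
error = None
try:
    shlex.split(bad)
except ValueError as err:
    error = str(err)
print(error)  # No closing quotation
```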
macOS setup
[Short] Run these commands:
anaconda-project remove-packages cudatoolkit
anaconda-project add-platforms osx-64
[Long]
- Remove the cudatoolkit dependency from anaconda-project.yml:
anaconda-project remove-packages cudatoolkit
- Add the macOS platform to anaconda-project-lock.yml:
anaconda-project add-platforms osx-64