A collection of my PyTorch implementations of several deep generative models.
This repository is a work in progress; feel free to open an issue if you find a bug.
Requirements:
- PyTorch >= 1.0 (this code was developed on 1.3.1, but it should also work on other versions)
- tensorboard (tb-nightly)
- numpy, scipy (ndarray support)
- matplotlib, moviepy (visualizing results)
- tqdm (progress bar)
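Before installing, a small script like the one below can report which of the listed requirements are missing and whether the installed PyTorch meets the 1.0 minimum. It is a hypothetical helper, not part of degmo; the module names in `REQUIRED_MODULES` are assumed import names for the packages above (tb-nightly, for example, provides the `tensorboard` module).

```python
import importlib.util
import re

# Assumed import names for the requirements listed above.
REQUIRED_MODULES = ["torch", "tensorboard", "numpy", "scipy", "matplotlib", "moviepy", "tqdm"]


def missing_modules(names):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def version_at_least(version, minimum=(1, 0)):
    """Check that a version string such as '1.3.1+cu92' is >= minimum (major, minor)."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        return False  # unparseable version strings are treated as not meeting the minimum
    return (int(match.group(1)), int(match.group(2))) >= tuple(minimum)


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("missing modules:", ", ".join(missing) if missing else "none")
    if "torch" not in missing:
        import torch

        status = "ok" if version_at_least(torch.__version__) else "older than 1.0"
        print("torch", torch.__version__, status)
```

`find_spec` only locates a module without importing it, so the check stays fast even when heavy packages such as torch are installed.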
Installation (we recommend setting up the environment with Anaconda):
git clone https://github.com/IcarusWizard/Deep-Generative-Models
cd Deep-Generative-Models
pip install -e .
Implemented model families:
- Auto-regressive
- FLOW
- VAE
- GAN
To add a new dataset, define a creator function that returns three torch.utils.data.Dataset objects (for training, validation and testing) and a dict holding the dataset configuration (c, h, w). Then register your custom loader with:
import degmo

def custom_creater():
    # ...... build the training, validation and testing datasets and the config dict here
    return training_set, validation_set, testing_set, config

degmo.add_dataset('name', custom_creater)
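As a minimal sketch of such a creator function, the example below builds random placeholder images with `TensorDataset` and splits them with `random_split`. The 80/10/10 split, the default image size, the single-element samples returned by `TensorDataset`, and the config key names `"c"`, `"h"`, `"w"` are assumptions for illustration; check an existing creator under degmo/data for the exact format the trainers expect.

```python
import torch
from torch.utils.data import TensorDataset, random_split


def custom_creater(num_samples=1000, c=1, h=28, w=28):
    """Return (training_set, validation_set, testing_set, config).

    Placeholder data only: replace the random tensor with your own loading code.
    """
    torch.manual_seed(0)  # make the random data and the split reproducible
    images = torch.rand(num_samples, c, h, w)  # values in [0, 1), shape (N, c, h, w)
    full_set = TensorDataset(images)  # each item is a 1-tuple: (image,)

    # Assumed 80/10/10 split between training, validation and testing.
    n_val = num_samples // 10
    n_test = num_samples // 10
    n_train = num_samples - n_val - n_test
    training_set, validation_set, testing_set = random_split(
        full_set, [n_train, n_val, n_test]
    )

    # Assumed key names for the (c, h, w) configuration dict.
    config = {"c": c, "h": h, "w": w}
    return training_set, validation_set, testing_set, config


if __name__ == "__main__":
    train, val, test, config = custom_creater()
    print(len(train), len(val), len(test), config)
    print(train[0][0].shape)  # shape of the first training image
```

Because every argument has a default, the function can still be called with no arguments, which matches how it is passed to degmo.add_dataset in the snippet above.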
Project structure:
logs/                   # default tensorboard log folder
checkpoints/            # default checkpoints folder
degmo/                  # main folder
    data/               # dataset functions
    config/             # default configurations
    utils.py            # shared utility functions
    modules.py          # shared utility modules
    <method>/
        trainer.py      # training procedure
        run_utils.py    # runtime utility functions
        utils.py        # method's utility functions
        modules.py      # method's utility modules
        <model.py>      # model class
        ......
Run python -m degmo.train_<method> --dataset <dataset> --model <model> to train with the default configuration.
To see the default configurations we provide, run python -m degmo.check_default_config <method>, or look inside the degmo/config folder.
To tune the parameters yourself, pass --custom to the training script; run python -m degmo.train_<method> -h to list all the parameters you can tune.
Note: All default configurations were tested on a single RTX 2080 Ti GPU with 11 GB of memory. If some default configurations (e.g. Glow) do not fit on your GPU, consider reducing the batch size or the number of features, either in the config file or in custom mode.
During and after training, you can run tensorboard --logdir=logs to monitor progress.
Run python -m degmo.test_<model> -h to see the options of the corresponding test script.