GAN-Tree: An Incrementally Learned Hierarchical Generative Framework for Multi-Modal Data Distributions
This repository contains code for the paper GAN-Tree: An Incrementally Learned Hierarchical Generative Framework for Multi-Modal Data Distributions, published in ICCV 2019.
The full paper can be found here. If you find our research work helpful, please consider citing:
@article{kundu2019gan,
title={GAN-Tree: An Incrementally Learned Hierarchical Generative Framework for Multi-Modal Data Distributions},
author={Kundu, Jogendra Nath and Gor, Maharshi and Agrawal, Dakshit and Babu, R Venkatesh},
journal={arXiv preprint arXiv:1908.03919},
year={2019}
}
- Overview of the Model
- Setup Instructions and Dependencies
- Training GAN Tree from Scratch
- Repository Overview
- Experiments
- Results Obtained
- Guidelines for Contributors
- License
The overall GAN Tree architecture is given in the above figure. For further details about the architecture and training algorithm, please go through the paper.
You may set up the repository on your local machine by either downloading it or running the following command in a terminal:
git clone https://github.com/val-iisc/GANTree.git
To install all dependencies required by this repo, create a virtual or conda environment with Python 2.7 and run
pip install -r requirements.txt
The LSUN Bedroom Scene and CelebA datasets required for training can be found in the Google Drive link given inside the data/datasets.txt file.
- Make sure to have the proper CUDA version installed for PyTorch v0.4.1.
- The code will not run on Windows, since PyTorch v0.4.1 with Python 2.7 is not supported there.
To train your own GAN Tree from scratch, run
python GANTree.py -hp path/to/hyperparams -en exp_name
- The hyperparameters for the experiment should be set in the hyperparams.py file (check src/hyperparams for examples).
- The training script will create a folder experiments/exp_name as specified in the hyperparams.py file or argument passed in the command line to the -en flag.
- This folder will contain all data related to the experiment such as generated images, logs, plots, and weights. It will also contain a dump of the hyperparameters.
- Training will require a large amount of RAM.
- Saving GNodes requires ample disk space (~500 MB per node).
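To get a rough idea of the disk space needed before starting a run, the per-node figure above can be turned into a quick estimate. The sketch below is only illustrative: the function name `estimate_gnode_storage_mb` is hypothetical and not part of this repository, and it assumes that every node in the tree is saved and that each split produces exactly two children, so that a tree with k leaf nodes has 2k - 1 nodes in total.

```python
def estimate_gnode_storage_mb(num_leaves, mb_per_node=500):
    """Rough disk estimate (in MB) for saving every GNode of a tree.

    Assumptions (not taken from the repository code):
      * each split creates exactly two children, so a tree with
        `num_leaves` leaves has 2 * num_leaves - 1 nodes in total;
      * every node is saved and takes about `mb_per_node` MB,
        the approximate figure quoted in this README.
    """
    if num_leaves < 1:
        raise ValueError("a tree needs at least one leaf node")
    num_nodes = 2 * num_leaves - 1
    return num_nodes * mb_per_node


if __name__ == "__main__":
    for leaves in (1, 2, 4, 8):
        print("{} leaves: about {} MB".format(
            leaves, estimate_gnode_storage_mb(leaves)))
```

Actual usage will depend on the model size set in the hyperparameters and on which weight types are kept, so treat the result as an order-of-magnitude guide.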
The following argument flags are available for training:
- -hp, --hyperparams: path to the hyperparam.py file to be used for training.
- -en, --exp_name: experiment name.
- -g, --gpu: index of the gpu to be used. The default value is 0.
- -t, --tensorboard: if true, start Tensorboard with the experiment. The default value is false.
- -r, --resume: if true, the training resumes from the latest step. The default value is false.
- -d, --delete: delete the entities from the experiment file. The default value is []. The choices are ['logs', 'weights', 'results', 'all'].
- -w, --weights: the weight type to load if resume flag is provided. The default value is iter. The choices are ['iter', 'best_gen', 'best_pred'].
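For readers who want to see how the flags listed above fit together, here is a minimal argparse sketch that mirrors their names, defaults, and choices. It is not the parser used in GANTree.py: the function name `build_parser` is hypothetical, and treating -t and -r as on/off switches and -d as accepting one or more values are assumptions made for illustration.

```python
import argparse


def build_parser():
    """Sketch of a parser mirroring the flags documented in this README."""
    parser = argparse.ArgumentParser(
        description="Illustrative GAN Tree training flags (sketch only)")
    parser.add_argument("-hp", "--hyperparams",
                        help="path to the hyperparams file used for training")
    parser.add_argument("-en", "--exp_name", help="experiment name")
    parser.add_argument("-g", "--gpu", type=int, default=0,
                        help="index of the gpu to be used")
    # Assumption: boolean flags are plain switches that default to False.
    parser.add_argument("-t", "--tensorboard", action="store_true",
                        help="start Tensorboard with the experiment")
    parser.add_argument("-r", "--resume", action="store_true",
                        help="resume training from the latest step")
    # Assumption: several entities may be listed after a single -d.
    parser.add_argument("-d", "--delete", nargs="+", default=[],
                        choices=["logs", "weights", "results", "all"],
                        help="entities to delete from the experiment")
    parser.add_argument("-w", "--weights", default="iter",
                        choices=["iter", "best_gen", "best_pred"],
                        help="weight type to load when resuming")
    return parser


if __name__ == "__main__":
    print(build_parser().parse_args())
```

With this sketch, resuming an experiment from the best generator weights would correspond to passing `-en exp_name -r -w best_gen`; check the actual script for the exact syntax it accepts.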
This repository contains the following folders:
- data: Contains the various datasets.
- experiments: Contains data for different runs.
- resources: Contains resources for the README.
- src: Contains all the source code.
  i. base: Contains the code for all base classes.
  ii. dataloaders: Contains various dataloaders.
  iii. hyperparams: Contains different hyperparam.py files for running various experiments.
  iv. models: Contains code for constructing AAE models, GNode and GAN-Tree.
  v. modules and utils: Contains code for various functions used frequently.
  vi. trainers: Contains code for the trainers of a particular GNode.
We train GAN Tree on the MNIST dataset, which is a single-channel dataset consisting of handwritten digits. To run the experiment, the following command can be executed:
python GANTree_MNIST.py
We train GAN Tree on the MNIST and Fashion MNIST datasets mixed together to test its performance on data with a clearly discontinuous manifold. To run the experiment, the following command can be executed:
python GANTree_MNIST_Fashion_Mixed.py
We train GAN Tree on the LSUN Bedroom Scene and CelebA datasets mixed together to test its robustness in the multi-channel setting. To run the experiment, the following command can be executed:
python GANTree_FaceBed.py
GAN Tree has the unique feature of being able to learn new related data without needing the previous data, i.e. it can learn incrementally. To run the experiment mentioned in the paper, the following command can be executed:
python GANTree_MNIST_0to4.py
After creating a GAN Tree trained on the digits 0-4 from the MNIST dataset, we add the digit 5. To incrementally learn a GAN Tree for the same, run the following command:
python iGANTree_add5_dsigma4.py
or
python iGANTree_add5_dsigma9.py
If you'd like to report a bug or open an issue then please:
Check if there is an existing issue. If there is, then please add any more information that you have, or give it a 👍.
When submitting an issue please describe the issue as clearly as possible, including how to reproduce the bug. If you can include a screenshot of the issues, that would be helpful.
Please first discuss the change you wish to make via an issue.
We don't have a set format for Pull Requests, but expect you to list changes, bugs generated and other relevant things in the PR message.
This repository is licensed under the MIT license.