Collaborative Sampling in Generative Adversarial Networks

This repository provides a basic PyTorch implementation of Collaborative Sampling in Generative Adversarial Networks.

Requirements

To install requirements:

pip install -r requirements.txt

Training

To train the GANs, run this command:

python main.py --mode="train" --niter=10000
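For readers who want to see how flags like these are typically handled, here is a minimal sketch using Python's standard `argparse` module. The flag names are taken from the commands in this README; the defaults, types, allowed choices, and help strings are assumptions for illustration and may not match the repository's `main.py`.

```python
import argparse


def build_parser():
    # Hypothetical parser: only the flag names come from this README.
    # Types, defaults, and choices below are assumptions.
    parser = argparse.ArgumentParser(description="Toy CLI mirroring the README flags")
    parser.add_argument("--mode", type=str, default="train",
                        choices=["train", "collab"])
    parser.add_argument("--niter", type=int, default=10000,
                        help="number of iterations (assumed meaning)")
    parser.add_argument("--ckpt_num", type=int, default=None,
                        help="checkpoint to load (assumed meaning)")
    parser.add_argument("--lrd", type=float, default=None,
                        help="a learning rate; its exact role is not stated in this README")
    return parser


# The shell strips the quotes in --mode="train", so the program sees --mode=train.
train_args = build_parser().parse_args(["--mode=train", "--niter=10000"])
collab_args = build_parser().parse_args(
    ["--mode=collab", "--ckpt_num=3000", "--niter=3000", "--lrd=5e-2"]
)
```

Parsing the two argument lists mirrors the training and sampling commands shown in this README; `argparse` converts `5e-2` to a float because of `type=float`.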

Sampling

To collaboratively sample from the trained GANs, run this command:

python main.py --mode="collab" --ckpt_num=3000 --niter=3000 --lrd=5e-2
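The cited paper describes collaborative sampling as refining generated samples with gradient-based updates guided by the discriminator. The sketch below is a dependency-free toy illustration of that idea, not the repository's code: it assumes a one-dimensional sample, a made-up logistic discriminator with hypothetical parameters `w` and `b`, and a hand-derived gradient instead of PyTorch autograd. The names `discriminator`, `refine`, `steps`, and `step_size` are all illustrative.

```python
import math


def discriminator(x, w=2.0, b=-1.0):
    # Toy logistic discriminator D(x) = sigmoid(w * x + b).
    # w and b are made-up values, not learned parameters.
    return 1.0 / (1.0 + math.exp(-(w * x + b)))


def refine(x, steps=20, step_size=0.1, w=2.0, b=-1.0):
    # Gradient descent on the non-saturating generator loss L(x) = -log D(x).
    # For a logistic D, the derivative is dL/dx = -(1 - D(x)) * w.
    for _ in range(steps):
        grad = -(1.0 - discriminator(x, w, b)) * w
        x = x - step_size * grad
    return x


x0 = -1.0          # an initial "generated" sample the toy discriminator scores low
x1 = refine(x0)    # the refined sample after a few gradient steps
```

Each step moves the sample in the direction that raises the toy discriminator's score. In a real model the same update would likely be computed with autograd and, per the paper, may be applied at an intermediate layer of the generator rather than directly in sample space.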

Result

Figure grid: rows show Samples and KDE; columns show Real, Vanilla Sampling (GAN 1K Iter), Vanilla Sampling (GAN 3K Iter), Vanilla Sampling (GAN 10K Iter), and Collab Sampling (GAN 3K Iter).

Citation

If you use this code for your research, please cite our paper.

@inproceedings{liu2019collaborative,
  title={Collaborative Sampling in Generative Adversarial Networks},
  author={Liu, Yuejiang and Kothari, Parth and Alahi, Alexandre},
  booktitle={Thirty-Fourth AAAI Conference on Artificial Intelligence},
  year={2020}
}