This repository offers TensorFlow implementations for many components related to Generative Adversarial Networks:
- losses (such as the non-saturating GAN, least-squares GAN, and WGAN losses),
- penalties (such as the gradient penalty),
- normalization techniques (such as spectral normalization, batch normalization, and layer normalization),
- neural architectures (BigGAN, ResNet, DCGAN), and
- evaluation metrics (FID score, Inception Score, precision-recall, and KID score).
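The loss functions listed above can be made concrete on raw discriminator outputs. The following is a minimal pure-Python sketch, not the repository's TensorFlow implementation. It assumes `d_real` and `d_fake` are lists of discriminator logits for real and generated samples and uses the common textbook formulations, plus the hinge loss. The function names are hypothetical.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mean(values):
    return sum(values) / len(values)


def non_saturating_loss(d_real, d_fake):
    # D minimizes -log sigmoid(D(x)) - log(1 - sigmoid(D(G(z)))),
    # using -log sigmoid(t) = softplus(-t) and -log(1 - sigmoid(t)) = softplus(t).
    d_loss = mean([softplus(-t) for t in d_real]) + mean([softplus(t) for t in d_fake])
    # Non-saturating generator loss: minimize -log sigmoid(D(G(z))).
    g_loss = mean([softplus(-t) for t in d_fake])
    return d_loss, g_loss


def least_squares_loss(d_real, d_fake):
    # LSGAN with targets 1 (real) and 0 (fake), applied to the raw outputs.
    # Assumption: the repository may transform the outputs differently.
    d_loss = 0.5 * mean([(t - 1.0) ** 2 for t in d_real]) + 0.5 * mean([t ** 2 for t in d_fake])
    g_loss = 0.5 * mean([(t - 1.0) ** 2 for t in d_fake])
    return d_loss, g_loss


def wasserstein_loss(d_real, d_fake):
    # WGAN critic loss; in practice combined with a gradient penalty (not shown).
    d_loss = mean(d_fake) - mean(d_real)
    g_loss = -mean(d_fake)
    return d_loss, g_loss


def hinge_loss(d_real, d_fake):
    # Hinge loss for D; the generator maximizes the critic output on fakes.
    d_loss = mean([max(0.0, 1.0 - t) for t in d_real]) + mean([max(0.0, 1.0 + t) for t in d_fake])
    g_loss = -mean(d_fake)
    return d_loss, g_loss
```

Real training code computes these on batched tensors and relies on automatic differentiation; the pure-Python version only makes the formulas explicit.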
The code is configurable via Gin and runs on GPUs, TPUs, and CPUs. Several research papers make use of this repository, including:
- Are GANs Created Equal? A Large-Scale Study
  Mario Lucic*, Karol Kurach*, Marcin Michalski, Sylvain Gelly, Olivier Bousquet [NeurIPS 2018]
- The GAN Landscape: Losses, Architectures, Regularization, and Normalization
  Karol Kurach*, Mario Lucic*, Xiaohua Zhai, Marcin Michalski, Sylvain Gelly [2018]
- Assessing Generative Models via Precision and Recall
  Mehdi S. M. Sajjadi, Olivier Bachem, Mario Lucic, Olivier Bousquet, Sylvain Gelly [NeurIPS 2018]
- GILBO: One Metric to Measure Them All
  Alexander A. Alemi, Ian Fischer [NeurIPS 2018]
- A Case for Object Compositionality in Deep Generative Models of Images
  Sjoerd van Steenkiste, Karol Kurach, Sylvain Gelly [2018]
- On Self Modulation for Generative Adversarial Networks
  Ting Chen, Mario Lucic, Neil Houlsby, Sylvain Gelly [ICLR 2019]
- Self-Supervised Generative Adversarial Networks
  Ting Chen, Xiaohua Zhai, Marvin Ritter, Mario Lucic, Neil Houlsby [CVPR 2019]
- High-Fidelity Image Generation With Fewer Labels
  Mario Lucic*, Michael Tschannen*, Marvin Ritter*, Xiaohua Zhai, Olivier Bachem, Sylvain Gelly [2019]
To install the library and all necessary dependencies, run `pip install -e .` from the `compare_gan/` folder.
Run `main.py`, passing `--model_dir` (the directory where checkpoints are stored) and `--gin_config` (the Gin configuration that defines which model is trained on which dataset, along with other training options). Several example configurations are provided in the `example_configs/` folder:
- dcgan_celeba64: DCGAN architecture with non-saturating loss on CelebA 64x64px
- resnet_cifar10: ResNet architecture with non-saturating loss and spectral normalization on CIFAR-10
- resnet_lsun-bedroom128: ResNet architecture with WGAN loss and gradient penalty on LSUN-bedrooms 128x128px
- sndcgan_celebahq128: SN-DCGAN architecture with non-saturating loss and spectral normalization on CelebA-HQ 128x128px
- biggan_imagenet128: BigGAN architecture with hinge loss and spectral normalization on ImageNet 128x128px
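As an illustration, the following hypothetical Python helper assembles the command line for one of the example configurations. It uses only the `--model_dir` and `--gin_config` flags described above; the `example_configs/<name>.gin` path pattern and the model directory are assumptions, so check the actual file names in the folder.

```python
import shlex


def train_command(config_name, model_dir, extra_flags=()):
    """Build an argv list for main.py (hypothetical helper, not part of compare_gan)."""
    # Assumption: example configs are stored as example_configs/<config_name>.gin.
    gin_config = "example_configs/%s.gin" % config_name
    argv = [
        "python", "main.py",
        "--model_dir=" + model_dir,    # checkpoints are written here
        "--gin_config=" + gin_config,  # model, dataset, and training options
    ]
    argv.extend(extra_flags)
    return argv


def to_shell(argv):
    # Quote each argument so the command can be pasted into a shell.
    return " ".join(shlex.quote(arg) for arg in argv)
```

To launch the run from Python instead of a shell, pass the list to `subprocess.run(argv, check=True)` from inside the `compare_gan/` folder.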
To see all available options, run `python main.py --help`. The main options are:
- To train the model, use `--schedule=train` (the default). Training resumes from the last saved checkpoint.
- To evaluate all checkpoints, use `--schedule=continuous_eval --eval_every_steps=0`. To evaluate only checkpoints whose step number is divisible by 5000, use `--schedule=continuous_eval --eval_every_steps=5000`. By default, 3 averaging runs are used to estimate the Inception Score and the FID score. Keep in mind that when running locally on a single GPU, it may not be possible to run training and evaluation simultaneously due to memory constraints.
- To train and evaluate the model, use `--schedule=eval_after_train --eval_every_steps=0`.
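The checkpoint selection controlled by `--eval_every_steps` can be sketched as follows. This is a hypothetical illustration of the rule stated above (0 evaluates every checkpoint, a positive value keeps only steps divisible by it), not compare_gan's own code. Combining the 3 averaging runs with a plain mean is likewise an assumption.

```python
import statistics


def checkpoints_to_evaluate(checkpoint_steps, eval_every_steps):
    """Return the checkpoint steps that would be evaluated, in ascending order."""
    if eval_every_steps < 0:
        raise ValueError("eval_every_steps must be >= 0")
    if eval_every_steps == 0:
        return sorted(checkpoint_steps)  # evaluate every checkpoint
    return sorted(s for s in checkpoint_steps if s % eval_every_steps == 0)


def average_over_runs(run_scores):
    # One score (e.g. FID or Inception Score) per averaging run.
    return statistics.mean(run_scores)
```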
We recommend using the ctpu tool to create a Cloud TPU and the corresponding Compute Engine VM. We use a v3-128 slice of a Cloud TPU v3 Pod to train models on ImageNet at 128x128 resolution. You can use smaller slices if you reduce the batch size (`options.batch_size` in the Gin config) or the number of model parameters. Keep in mind that this may change the model quality. Before training, make sure that the environment variable `TPU_NAME` is set. Running evaluation on TPUs is currently not supported; use a VM with a single GPU instead.
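A minimal pre-flight sketch for the TPU setup above: it checks that `TPU_NAME` is set and, as one possible heuristic (an assumption, not a rule stated by compare_gan), scales the global batch size with the number of TPU cores so that the per-core batch stays constant on a smaller slice. The helper names are hypothetical.

```python
import os


def require_tpu_name(environ=os.environ):
    """Return TPU_NAME or fail early with a clear message."""
    name = environ.get("TPU_NAME")
    if not name:
        raise RuntimeError("TPU_NAME is not set; export it before starting training.")
    return name


def scaled_batch_size(batch_size, reference_cores, target_cores):
    """Keep the per-core batch constant when changing the slice size.

    Heuristic assumption: linear scaling with the number of TPU cores.
    The result is the value you would put into options.batch_size.
    """
    if batch_size % reference_cores != 0:
        raise ValueError("batch_size must split evenly across reference_cores")
    per_core = batch_size // reference_cores
    return per_core * target_cores
```

For example, a hypothetical global batch of 2048 on a v3-128 slice (128 cores) is 16 examples per core; keeping that per-core batch on a v3-8 slice (8 cores) gives `scaled_batch_size(2048, 128, 8) == 128`.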
Compare GAN uses TensorFlow Datasets, which automatically downloads and prepares the data. For ImageNet, you need to download the archive yourself. For CelebA-HQ, you need to download and prepare the images on your own. If you are using TPUs, make sure to point the training script to your Google Cloud Storage bucket (`--tfds_data_dir`).
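If you drive the training script from Python, a small check like the following (a hypothetical helper, not part of compare_gan) can catch a local `--tfds_data_dir` before a TPU job starts. It assumes the bucket path uses the standard `gs://` scheme; the bucket name in the test is made up.

```python
def tfds_data_dir_flag(data_dir, use_tpu):
    """Return the --tfds_data_dir flag, enforcing a GCS path when on TPUs."""
    # TPUs read input data from Google Cloud Storage, so a local path will not work there.
    if use_tpu and not data_dir.startswith("gs://"):
        raise ValueError("On TPUs, --tfds_data_dir must point to a gs:// bucket, got %r" % data_dir)
    return "--tfds_data_dir=" + data_dir
```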