Training Generative Adversarial Networks (GANs) on high-fidelity images usually requires a vast number of training images. Recent research on GAN tickets reveals that dense GANs models contain sparse sub-networks or "lottery tickets" that, when trained separately, yield better results under limited data. However, finding GANs tickets requires an expensive process of train-prune-retrain. In this paper, we propose Re-GAN, a data-efficient GANs training that dynamically reconfigures GANs architecture during training to explore different sub-network structures in training time. Our method repeatedly prunes unimportant connections to regularize GANs network and regrows them to reduce the risk of prematurely pruning important connections. Re-GAN stabilizes the GANs models with less data and offers an alternative to the existing GANs tickets and progressive growing methods. We demonstrate that Re-GAN is a generic training methodology which achieves stability on datasets of varying sizes, domains, and resolutions (CIFAR-10, Tiny-ImageNet, and multiple few-shot generation datasets) as well as different GANs architectures (SNGAN, ProGAN, StyleGAN2 and AutoGAN). Re-GAN also improves performance when combined with the recent augmentation approaches. Moreover, Re-GAN requires fewer floating-point operations (FLOPs) and less training time by removing the unimportant connections during GANs training while maintaining comparable or even generating higher-quality samples. When compared to state-of-the-art StyleGAN2, our method outperforms without requiring any additional fine-tuning step.
Our codes were implemented by Pytorch, we list the libraries and their version used in our experiments, but other versions should also be worked.
- Linux (Ubuntu)
- Python (3.8.0)
- Pytorch (1.13.0+cu116)
- torchvision (0.14.0)
- numpy (1.23.4)
Should you have any questions about using this repo, feel free to contact Jiahao Xu @ jiahxu@polyu.edu.hk
Argument | Type | Description |
---|---|---|
epoch |
int | Number of total training epochs |
batch_size |
int | Batch size of per iteration, choose a proper value by yourselves |
regan |
store_true | Enable ReGAN training or not |
sparsity |
float | Target sparsity k, e.g. sparsity=0.3 means 30% of weights will be pruned |
g |
int | The update interval |
warmup_epoch |
int | Warmup training epochs |
data_ratio |
float | To simulate a training data limited scenario |
Argument | Type | Description |
---|---|---|
regan |
store_true | Enable ReGAN training or not |
sparsity |
float | Target sparsity k, e.g. sparsity=0.3 means 30% of weights will be pruned |
g |
int | The update interval |
warmup_epoch |
int | Warmup training epochs |
data_ratio |
float | To simulate a training data limited scenario |
For batch size of training epochs at each stage of ProGAN, you may define them on the main.py.
Argument | Type | Description |
---|---|---|
iter |
int | Number of total training iterations |
batch_size |
int | Batch size of per iteration, choose a proper value by yourselves |
regan |
store_true | Enable ReGAN training or not |
sparsity |
float | Target sparsity k, e.g. sparsity=0.3 means 30% of weights will be pruned |
g |
int | The update interval |
warmup_iter |
int | Warmup training iterations |
diffaug |
store_true | Enable DiffAug or not |
eva_iter |
int | Evaluation frequency |
eva_size |
int | Evaluation size, for few-shot dataset, we use eva_size=5000 |
dataset |
str | which dataset to use, please make sure you place the dataset dictionary in right place |
size |
int | Size of training image, for few-shot dataset, size=256 |
ckpt |
str | If you want to resume your training, you can use this argument |
Please refer to AutoGAN project website for getting detailed explanation.
Pytorch will download the CIFAR-10 dataset automatically if the dataset is not detected, therefore there is no need to prepare CIFAR-10 dataset.
For Tiny-ImageNet dataset, you may get it from https://www.kaggle.com/datasets/akash2sharma/tiny-imagenet.
For FFHQ dataset, you may get it from https://github.com/NVlabs/stylegan.
For Few-shot dataset, you may get it from https://github.com/odegeasslbc/FastGAN-pytorch. Besides, we provide the lmdb data of Few-shot dataset for your convenience, you may get it from here.
To run a Re-SNGAN or Re-ProGAN model, you may follow:
- Clone this repo to your local environment.
git clone https://github.com/IntellicentAI-Lab/Re-GAN.git
-
Prepare all the required libraries and datasets.
-
Run your model! One example can be:
# For SNGAN
python main.py --epoch 1000 --data_ratio 0.1 \
--regan --warmup_epoch 200 --g 100 --sparse 0.3
# For ProGAN
python main.py --data_ratio 0.1 \
--regan --warmup_epoch 10 --g 10 --sparse 0.3
# Please define epoch and batch size in the main.py
To run a Re-StyleGAN2 model on few-shot dataset, you may follow:
- Clone this repo to your local environment.
git clone https://github.com/IntellicentAI-Lab/Re-GAN.git
-
Prepare all the required libraries and datasets. For dataset, please place them like following panda example:
- ReGAN/dataset/panda
- ReGAN/dataset/pandalmdb
-
Run your model! One example can be:
python train.py --size 256 --batch 32 --iter 40000 --dataset panda --eva_iter 2000 --eva_size 100 \
--regan --warmup_iter 10000 --g 5000 --sparsity 0.3 \
--diffaug # Indicate it if you want to use DiffAugmentation
We conclude how we calculate some used metrics shown in our paper in this section.
Metrics | Description |
---|---|
#Real images |
The number of real images that are shown to the discriminator. You can calculate it by training epochs and batch size. |
FLOPs |
Floating Point Operations. You can calculate it by using thop library. |
Training time |
The total training time. Note that the training time reported in our paper is not included the evaluation time. |
MSE difference |
The MSE difference. We periodically (per epoch) save an image that contains 64 samples which are generated with fixed noise, then calculated the pixel difference. |
We provide some codes on quantitative evaluation, you can check the StyleGAN folder and see lines 347 to 379 on the train.py to learn how to evaluate FID. Calculating IS is a very similar process.
If you use this code for your research, please cite our papers.
@InProceedings{Saxena_2023_CVPR,
author = {Saxena, Divya and Cao, Jiannong and Xu, Jiahao and Kulshrestha, Tarun},
title = {Re-GAN: Data-Efficient GANs Training via Architectural Reconfiguration},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2023},
pages = {16230-16240}
}
We would like to thank the work that helps our paper:
- FID score: https://github.com/bioinf-jku/TTUR.
- Inception score: https://github.com/w86763777/pytorch-gan-metrics.
- DiffAugmentation: https://github.com/VITA-Group/Ultra-Data-Efficient-GAN-Training.
- AutoGAN: https://github.com/VITA-Group/AutoGAN.
- APA: https://github.com/endlesssora/deceived.
- SNGAN: https://github.com/w86763777/pytorch-gan-collections.
- ProGAN: https://github.com/BakingBrains/Progressive_GAN-ProGAN-_implementation.
- StyleGAN2: https://github.com/rosinality/stylegan2-pytorch.