This repository contains code for the paper "StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets"
by Axel Sauer, Katja Schwarz, and Andreas Geiger.
If you find our code or paper useful, please cite
@InProceedings{Sauer2021ARXIV,
author = {Axel Sauer and Katja Schwarz and Andreas Geiger},
title = {StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets},
journal = {arXiv.org},
volume = {abs/2201.00273},
year = {2022},
url = {https://arxiv.org/abs/2201.00273},
}
- Initial code release
- Add pretrained models (ImageNet{16,32,64,128}, FFHQ256, Pokemon256)
- Add higher resolution models (ImageNet{256,512,1024})
- Add PTI for inversion
- Add StyleMC for editing
The higher-resolution models are currently being retrained with improved settings; weights will be rolled out gradually. The expected release date of the ImageNet256 model is 14.04.2022.
- 64-bit Python 3.8 and PyTorch 1.9.0 (or later). See https://pytorch.org for PyTorch install instructions.
- CUDA toolkit 11.1 or later.
- GCC 7 or later compilers. The recommended GCC version depends on your CUDA version; see, for example, the CUDA 11.4 system requirements.
- If you run into problems when setting up the custom CUDA kernels, refer to the Troubleshooting docs of the original StyleGAN3 repo.
- Windows users who struggle to install the environment might find autonomousvision#10 helpful.
- Use the following commands with Miniconda3 to create and activate your Python environment:
conda env create -f environment.yml
conda activate sgxl
For a quick start, you can download the few-shot datasets provided by the authors of FastGAN. You can download them here. To prepare the dataset at the respective resolution, run
python dataset_tool.py --source=./data/pokemon --dest=./data/pokemon256.zip \
--resolution=256x256 --transform=center-crop
You need to follow our progressive growing scheme to get the best results. Therefore, you should prepare separate zips for each training resolution. You can get the datasets we used in our paper at their respective websites (FFHQ, ImageNet).
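As a rough illustration of preparing one zip per training resolution, the sketch below prints one dataset_tool.py command per resolution. The resolution list and the "<name><res>.zip" naming scheme are assumptions chosen to match the pokemon16, pokemon32, and pokemon256 examples in this README; the helper dataset_commands is hypothetical and not part of the repo.

```python
# Sketch: build one dataset_tool.py command per training resolution.
# The helper name, the resolution list, and the "<name><res>.zip" naming
# scheme are assumptions for illustration, not part of the repository.

def dataset_commands(source_dir, name, resolutions):
    """Return one dataset_tool.py command line per resolution."""
    commands = []
    for res in resolutions:
        commands.append(
            f"python dataset_tool.py --source={source_dir} "
            f"--dest=./data/{name}{res}.zip "
            f"--resolution={res}x{res} --transform=center-crop"
        )
    return commands


for cmd in dataset_commands("./data/pokemon", "pokemon", [16, 32, 64, 128, 256]):
    print(cmd)
```

Each printed line can be run as is; only the flags already shown above (--source, --dest, --resolution, --transform) are used.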
For progressive growing, we train a stem on low resolution, e.g., 16² pixels. When the stem is finished, i.e., FID is saturating, you can start training the upper stages; we refer to these as superresolution stages.
Training StyleGAN-XL on Pokemon using 8 GPUs:
python train.py --outdir=./training-runs/pokemon --cfg=stylegan3-t --data=./data/pokemon16.zip \
--gpus=8 --batch=64 --mirror=1 --snap 10 --batch-gpu 8 --kimg 10000 --syn_layers 10
--batch specifies the overall batch size; --batch-gpu specifies the batch size per GPU. If you use fewer GPUs, the training loop automatically accumulates gradients until the overall batch size is reached.
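The relationship between the three values can be sketched as simple arithmetic. This is a minimal illustration of the accumulation described above, not the repo's training loop; the helper name accumulation_rounds and the divisibility check are assumptions.

```python
# Sketch: how many gradient-accumulation rounds are needed to reach the
# overall batch size. The function name and the error check are
# illustrative assumptions, not code from train.py.

def accumulation_rounds(batch, batch_gpu, gpus):
    """Number of forward/backward passes each GPU runs per optimizer step."""
    per_pass = batch_gpu * gpus  # images processed by all GPUs in one pass
    if batch % per_pass != 0:
        raise ValueError("batch must be a multiple of batch_gpu * gpus")
    return batch // per_pass


print(accumulation_rounds(batch=64, batch_gpu=8, gpus=8))  # 1
print(accumulation_rounds(batch=64, batch_gpu=8, gpus=2))  # 4
```

With the 8-GPU command above, each step needs a single pass; running the same settings on 2 GPUs would take four passes per step.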
Samples and metrics are saved in outdir. If you don't want to track metrics, set --metrics=none. You can inspect fid50k_full.json or run tensorboard in training-runs/ to monitor the training progress.
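For a quick look at the tracked FID without TensorBoard, a small script can scan the metric file. The sketch below assumes the file holds one JSON record per line with a "results" dictionary and a "snapshot_pkl" field, as in StyleGAN3-style metric logs; that layout is an assumption, so check your own file first. The records in the demo are made up for illustration only.

```python
import json

# Sketch: find the snapshot with the lowest FID in a JSON-lines metric log.
# Assumed record layout (verify against your own file):
#   {"results": {"fid50k_full": <float>}, "snapshot_pkl": "<name>", ...}

def best_snapshot(lines, metric="fid50k_full"):
    """Return (value, snapshot_pkl) for the lowest metric value, or None."""
    best = None
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        value = record["results"][metric]
        if best is None or value < best[0]:
            best = (value, record.get("snapshot_pkl"))
    return best


# Made-up records, for illustration only.
demo = [
    '{"results": {"fid50k_full": 40.0}, "snapshot_pkl": "snapshot-a.pkl"}',
    '{"results": {"fid50k_full": 25.0}, "snapshot_pkl": "snapshot-b.pkl"}',
]
print(best_snapshot(demo))
```

To use it on a real run, pass an open file handle instead of the demo list.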
For a class-conditional dataset (ImageNet, CIFAR-10), add the flag --cond True. The dataset needs to contain the class labels; see the StyleGAN2-ADA repo on how to prepare class-conditional datasets.
Continuing with pretrained stem:
python train.py --outdir=./training-runs/pokemon --cfg=stylegan3-t --data=./data/pokemon32.zip \
--gpus=8 --batch=64 --mirror=1 --snap 10 --batch-gpu 8 --kimg 10000 --syn_layers 10 \
--superres --up_factor 2 --head_layers 7 \
--path_stem training-runs/pokemon/00000-stylegan3-t-pokemon16-gpus8-batch64/best_model.pkl
--up_factor lets you train several stages at once; e.g., with --up_factor=4 and a 16² stem you can directly train at resolution 64².
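The resolution arithmetic behind the stages can be sketched as follows. The helper names are hypothetical, and the schedule assumes every stage uses the same factor, which the repo does not require.

```python
# Sketch: resolution reached by a superresolution stage, and the list of
# resolutions visited when growing from a stem to a target.
# Both helpers are illustrative assumptions, not part of train.py.

def stage_resolution(stem_res, up_factor):
    """Resolution of a stage trained on top of a stem at stem_res."""
    return stem_res * up_factor


def growing_schedule(stem_res, target_res, up_factor=2):
    """Resolutions from stem to target when every stage uses up_factor."""
    schedule = [stem_res]
    while schedule[-1] < target_res:
        schedule.append(stage_resolution(schedule[-1], up_factor))
    if schedule[-1] != target_res:
        raise ValueError("target_res is not reachable with this up_factor")
    return schedule


print(stage_resolution(16, 4))    # 64
print(growing_schedule(16, 256))  # [16, 32, 64, 128, 256]
```

Each entry in the schedule corresponds to one dataset zip and one training run, with the previous run's best_model.pkl passed as --path_stem.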
For unimodal datasets, we recommend using fewer layers, e.g., --head_layers 4.
If you have enough compute, a good tactic is to train several stages in parallel and then restart the superresolution stage training once in a while. The current stage will then reload its previous stem's best_model.pkl. Performance can sometimes drop at first because of domain shift, but the superresolution stage quickly recovers and improves further.
To generate samples and interpolation videos, run
python gen_images.py --outdir=out --trunc=0.7 --seeds=10-15 --batch-sz 1 \
--network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl
and
python gen_video.py --output=lerp.mp4 --trunc=0.7 --seeds=0-31 --grid=4x2 \
--network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl
For class-conditional models, you can pass the class index via --class. An index-to-label dictionary for ImageNet can be found here.
To generate a conditional sample sheet, run
python gen_samplesheet.py --outdir=sample_sheets --trunc=1.0 \
--network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl \
--samples-per-class 4 --classes 0-32 --grid-width 32
For the ImageNet models, we enable class-wise multi-modal truncation (a fast and class-conditional version of the truncation method by Self-Distilled GAN). We generate 60k class-conditional latents and find 30 cluster centroids via k-means. For a given sample, multi-modal truncation finds the closest centroid and interpolates towards it. To switch from uni-modal to multi-modal truncation, pass
--centroids-path=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet_centroids.npy
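The idea of multi-modal truncation can be sketched in a few lines of plain Python. This is a minimal illustration, not the repo's implementation: it assumes Euclidean distance for the nearest-centroid lookup and the usual truncation convention in which psi=1 leaves the latent unchanged and psi=0 returns the centroid. The function name multimodal_truncate is hypothetical.

```python
# Sketch: move a latent toward its nearest centroid.
# Assumptions (not confirmed by the repo): Euclidean distance, and the
# interpolation w' = c + psi * (w - c) with psi as the truncation value.

def multimodal_truncate(w, centroids, psi):
    """Interpolate latent w toward its closest centroid by factor psi."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))

    nearest = min(centroids, key=lambda c: sq_dist(w, c))
    return [c + psi * (x - c) for x, c in zip(w, nearest)]


centroids = [[0.0, 0.0], [10.0, 10.0]]
print(multimodal_truncate([1.0, 0.0], centroids, psi=0.5))  # [0.5, 0.0]
```

Uni-modal truncation is the special case with a single centroid, the mean latent.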
[Figure: samples with no truncation, uni-modal truncation, and multi-modal truncation]
We provide the following pretrained models (pass the URL as PATH_TO_NETWORK_PKL):
Dataset | Res | FID | PATH |
---|---|---|---|
ImageNet | 16² | 0.73 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet16.pkl |
ImageNet | 32² | 1.11 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet32.pkl |
ImageNet | 64² | 1.52 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet64.pkl |
ImageNet | 128² | 1.77 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl |
CIFAR10 | 32² | 1.85 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/cifar10.pkl |
FFHQ | 256² | 2.19 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq256.pkl |
Pokemon | 256² | 23.97 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl |
The weights for the ImageNet models at 64² and higher are still being updated. If you cannot reproduce the reported FID via calc_metrics.py (see below), you are likely using an older cached network pkl. Delete the previous model weights in your cache folder at $HOME/.cache/dnnlib/downloads/.
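Clearing stale downloads can also be scripted. The sketch below deletes cached files whose name contains a given fragment; the helper remove_cached_downloads is hypothetical, and matching by filename substring is an assumption about how cached files are named, so list the folder first to confirm.

```python
from pathlib import Path

# Sketch: delete cached network pickles whose filename contains a fragment.
# The helper and the substring matching are illustrative assumptions;
# inspect the cache folder before deleting anything.

def remove_cached_downloads(cache_dir, name_fragment):
    """Delete matching files in cache_dir and return their names."""
    cache = Path(cache_dir).expanduser()
    removed = []
    if not cache.is_dir():
        return removed  # nothing cached yet
    for path in sorted(cache.iterdir()):
        if path.is_file() and name_fragment in path.name:
            path.unlink()
            removed.append(path.name)
    return removed
```

For example, remove_cached_downloads("~/.cache/dnnlib/downloads", "imagenet64") would remove cached copies of the ImageNet 64² model so that the next run downloads the current weights.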
Last update on 05.04.2022.
By default, train.py tracks FID50k during training. To calculate metrics for a specific network snapshot, run
python calc_metrics.py --metrics=fid50k_full --network=PATH_TO_NETWORK_PKL
To see the available metrics, run
python calc_metrics.py --help
We provide precomputed FID statistics for all pretrained models:
wget https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/gan-metrics.zip
unzip gan-metrics.zip -d dnnlib/
This repo builds on the codebase of StyleGAN3 and our previous project Projected GANs Converge Faster.