
🖼 Training StyleGAN2 on TPUs in JAX


StyleGAN2 Flax TPU

This implementation is adapted from the stylegan2 codebase by Matthias Wright.

The features we've added make StyleGAN2 training scale better on TPUs:

  • 🏭 Enable data-parallel training on TPU pods (tested on TPU v2 to v4 generations)
  • 💾 Google Cloud Storage (GCS) integration/dataset sharding between workers
  • 🏖 Quality-of-life improvements (e.g. improved W&B logging)
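
To illustrate what sharding a dataset between workers can look like, here is a minimal sketch in plain Python. It assigns TFRecord files to workers round-robin. The helper name `shard_files`, the bucket path, and the file names are hypothetical, and this is not necessarily how this repository distributes data.

```python
from typing import List


def shard_files(files: List[str], num_workers: int, worker_index: int) -> List[str]:
    """Return the subset of files that one worker should read.

    Hypothetical helper: files are sorted so every worker sees the same
    order, then assigned round-robin by position.
    """
    if not 0 <= worker_index < num_workers:
        raise ValueError("worker_index must be in [0, num_workers)")
    return sorted(files)[worker_index::num_workers]


# Example file names (made up for illustration).
files = [f"gs://my-bucket/tfrecords/part-{i}.tfrecord" for i in range(5)]
for worker in range(2):
    print(worker, shard_files(files, num_workers=2, worker_index=worker))
```

Each file ends up on exactly one worker, so no example is read twice within an epoch.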

This research is part of the technology underlying our AI-generated photography platform Nyx.gallery

This food does not exist! Click to see more samples 🍪🍰🍣🍹🍔


🏗 Changelog

v0.2
  • Better support for class-conditional training, adding per-class moving average statistics to generator
  • Training data can now be split into multiple TFRecord files, located either in --data_dir or in a tfrecords subdirectory. A dataset_info.json file is still required in --data_dir (containing width, height, num_examples, and the list of classes if class-conditional).
  • Renamed argument --load_from_pkl to --load_from_ckpt
  • Added --num_steps argument to specify a fixed number of steps to run
  • Added --early_stopping_after_steps argument to stop after n steps of no FID improvement
  • Removal of --bf16 flag and consolidation with --mixed_precision.
  • Allow layer freezing with --freeze_g and --freeze_d arguments
  • Add --fmap_max argument, in order to have better control over feature map dimensions
  • Allow disabling of generator and discriminator regularization
  • Change checkpointing behaviour from saving every 2k steps to saving every 10k steps and keeping 2 best checkpoints (see --save_every and --keep_n_checkpoints)
  • Add --metric_cache_location in order to cache dataset statistics (currently for FID only)
  • Log TPU memory usage, shoutout to ayaka14732 for help (see also https://github.com/ayaka14732/jax-smi)
  • Visualise model architecture & parameters on startup
  • Improve W&B logging (e.g. adding eval snapshots with fixed latents)
  • Experimental: Add jax profiling
v0.1
  • Enable training on TPUs
  • Google Cloud Storage (GCS) integration
  • Several quality-of-life improvements
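
The early-stopping and checkpoint-retention behaviour listed under v0.2 can be sketched roughly as follows. This is an illustration of the idea only, not the repository's code: the class `FidTracker` and its parameters are hypothetical, and it assumes "no FID improvement" means no new lowest FID within a given number of steps.

```python
import math
from typing import List, Tuple


class FidTracker:
    """Hypothetical helper: decides on early stopping and which checkpoints to keep."""

    def __init__(self, patience_steps: int, keep_n: int = 2):
        self.patience_steps = patience_steps
        self.keep_n = keep_n
        self.best_fid = math.inf
        self.best_step = 0
        self.kept: List[Tuple[float, int]] = []  # (fid, step), lowest FID first

    def update(self, step: int, fid: float) -> bool:
        """Record an evaluation; return True if training should stop."""
        if fid < self.best_fid:
            self.best_fid, self.best_step = fid, step
        # Keep only the keep_n lowest-FID checkpoints.
        self.kept = sorted(self.kept + [(fid, step)])[: self.keep_n]
        return step - self.best_step >= self.patience_steps


tracker = FidTracker(patience_steps=20_000, keep_n=2)
for step, fid in [(10_000, 50.0), (20_000, 40.0), (30_000, 42.0), (40_000, 41.0)]:
    stop = tracker.update(step, fid)
    print(step, fid, "stop" if stop else "continue", tracker.kept)
```

In this made-up run the best FID is reached at step 20,000, so with a patience of 20,000 steps the tracker signals a stop at step 40,000.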

🧑‍🔧 Install

  1. Clone the repository:
    git clone https://github.com/nyx-ai/stylegan2-flax-tpu.git
  2. Go into the directory:
    cd stylegan2-flax-tpu
  3. Install requirements:
    pip install -r requirements.txt

🖼 Generate Images

We have released four 256x256 models as well as 512x512 models. Download them from the latest release.

python generate_images.py \
   --checkpoint checkpoints/cookie-256.pkl \
   --seeds 0 42 420 666 \
   --truncation_psi 0.7 \
   --out_path generated_images
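
For context, `--truncation_psi` in StyleGAN-family models usually refers to the truncation trick, where each intermediate latent is pulled towards the average latent. The sketch below shows that standard formulation in plain Python. It is not taken from this repository's code, and the function name `truncate` is hypothetical.

```python
from typing import List


def truncate(w: List[float], w_avg: List[float], psi: float) -> List[float]:
    """Interpolate a latent vector towards the average latent.

    psi = 1.0 leaves w unchanged; psi = 0.0 collapses it onto w_avg.
    """
    return [avg + psi * (x - avg) for x, avg in zip(w, w_avg)]


# Toy two-dimensional latents (made up for illustration).
w_avg = [0.0, 1.0]
w = [2.0, -1.0]
print(truncate(w, w_avg, psi=0.5))  # halfway between w_avg and w
```

Lower values of psi generally trade sample diversity for fidelity, which is why a value below 1 such as 0.7 is a common choice.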

Check the Colab notebook for more examples.

⚙️ Train Custom Models

Add your images into a folder /path/to/image_dir:

/path/to/image_dir/
    0.jpg
    1.jpg
    2.jpg
    4.jpg
    ...

and create a TFRecord dataset:

python dataset_utils/images_to_tfrecords.py --image_dir /path/to/image_dir/ --data_dir /path/to/tfrecord

For more detailed instructions please refer to this README.
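
Since training expects a dataset_info.json file in the --data_dir location, a quick sanity check before starting a long run can be useful. The sketch below is a minimal example under stated assumptions: the key names are guessed from the fields listed in the changelog (width, height, num_examples), and the helper `check_dataset_info` is hypothetical, so adjust it to whatever the conversion script actually writes.

```python
import json
import os
import tempfile

# Assumed key names, based on the fields listed in the changelog;
# the actual file written by the conversion script may differ.
REQUIRED_KEYS = ("width", "height", "num_examples")


def check_dataset_info(data_dir: str) -> dict:
    """Load dataset_info.json from data_dir and check the expected fields exist."""
    path = os.path.join(data_dir, "dataset_info.json")
    with open(path) as f:
        info = json.load(f)
    missing = [key for key in REQUIRED_KEYS if key not in info]
    if missing:
        raise KeyError(f"dataset_info.json is missing: {missing}")
    return info


# Demo with a temporary directory standing in for --data_dir.
with tempfile.TemporaryDirectory() as data_dir:
    with open(os.path.join(data_dir, "dataset_info.json"), "w") as f:
        json.dump({"width": 128, "height": 128, "num_examples": 1000}, f)
    print(check_dataset_info(data_dir))
```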

The following command trains at a resolution of 128 with a batch size of 8:

python main.py --data_dir /path/to/tfrecord

Read more about suitable training parameters here.

Our experiments were run and tested on TPU VMs (generations v2 to v4). At the time of writing, Colab offers an older generation of TPUs, so training (and especially compilation) may be significantly slower. If you still wish to train on Colab, the training Colab notebook may get you started.

🙏 Acknowledgements