
🖼 Training StyleGAN2 on TPUs in JAX


StyleGAN2 Flax TPU

This implementation is adapted from the stylegan2 codebase by Matthias Wright.

Specifically, the features we've added allow for better scaling of StyleGAN2 training on TPUs:

  • 🏭 Enable data-parallel training on TPU pods (tested on TPU v2 to v4 generations)
  • 💾 Google Cloud Storage (GCS) integration and dataset sharding between workers
  • 🏖 Quality-of-life improvements (e.g. improved W&B logging)
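Dataset sharding between workers means that each host in a TPU pod reads a disjoint subset of the data. The snippet below is a minimal sketch of one common way to do that: a round-robin split of TFRecord file names. It assumes the dataset is stored as several TFRecord files. The function name and file names are hypothetical, and the repository's own mechanism may differ (for example, it may use `tf.data.Dataset.shard`).

```python
from typing import List, Sequence


def files_for_worker(files: Sequence[str], num_workers: int, worker_index: int) -> List[str]:
    """Return the files assigned to one worker, using a round-robin split.

    Hypothetical helper for illustration; not part of the repository.
    """
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")
    if not 0 <= worker_index < num_workers:
        raise ValueError("worker_index must be in [0, num_workers)")
    # Sort first so every worker slices the same global order and the
    # resulting subsets are disjoint and together cover all files.
    return sorted(files)[worker_index::num_workers]


if __name__ == "__main__":
    # Hypothetical shard names; on a pod you would typically pass
    # jax.process_count() and jax.process_index() as the last two arguments.
    shards = [f"train-{i:02d}.tfrecord" for i in range(5)]
    print(files_for_worker(shards, num_workers=2, worker_index=0))
```

Sorting before slicing matters: if workers listed the bucket in different orders, a plain slice could hand the same file to two workers and skip another.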

Web Demo

Integrated into Hugging Face Spaces 🤗 using Gradio. Try out the web demo on Hugging Face Spaces.

This food does not exist! Click to see more samples 🍪🍰🍣🍹

These Cookies Do Not Exist

🧑‍🔧 Install

  1. Clone the repository:
    git clone https://github.com/nyx-ai/stylegan2-flax-tpu.git
  2. Go into the directory:
    cd stylegan2-flax-tpu
  3. Install JAX for your platform (TPU, GPU or CPU).
  4. Install requirements:
    pip install -r requirements.txt

🖼 Generate Images

We released four 256x256 pretrained models: cookie, cheesecake, sushi and cocktail. Download them from the latest release.

python generate_images.py \
   --checkpoint checkpoints/cookie-256.pkl \
   --seeds 0 42 420 666 \
   --truncation_psi 0.7 \
   --out_path generated_images
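In StyleGAN2, `truncation_psi` is the ψ of the truncation trick: each mapped latent w is pulled toward the average latent w̄ as w' = w̄ + ψ(w − w̄). Lower values give more typical, higher-quality images at the cost of diversity, and ψ = 1 disables truncation. The sketch below shows the formula on a single latent vector in plain Python, assuming this repository follows the standard StyleGAN2 definition. The function name is hypothetical; the real generator applies this to JAX arrays.

```python
from typing import List, Sequence


def truncate(w: Sequence[float], w_avg: Sequence[float], psi: float = 0.7) -> List[float]:
    """Truncation trick: w' = w_avg + psi * (w - w_avg). Hypothetical helper."""
    if len(w) != len(w_avg):
        raise ValueError("w and w_avg must have the same length")
    # psi = 1 returns w unchanged; psi = 0 collapses every sample onto w_avg.
    return [a + psi * (x - a) for x, a in zip(w, w_avg)]


if __name__ == "__main__":
    print(truncate([1.0, -2.0], [0.0, 0.0], psi=0.5))
```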

See the Colab notebook for more examples.

⚙️ Train Custom Models

Place your images in a folder, e.g. /path/to/image_dir:

/path/to/image_dir/
    0.jpg
    1.jpg
    2.jpg
    4.jpg
    ...

and create a TFRecord dataset:

python dataset_utils/images_to_tfrecords.py --image_dir /path/to/image_dir/ --data_dir /path/to/tfrecord

For more detailed instructions, please refer to this README.

The following command trains at a resolution of 128×128 with a batch size of 8 (the default settings):

python main.py --data_dir /path/to/tfrecord
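Data-parallel training on TPUs follows a standard JAX pattern: parameters are replicated on every core, each core computes gradients on its own slice of the batch, and the gradients are averaged across cores before the update. The sketch below illustrates that pattern with `jax.pmap` and `jax.lax.pmean` on a toy linear-regression loss. It is not the repository's training step: the loss, parameter shape, and learning rate are placeholders standing in for the StyleGAN2 generator and discriminator updates.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # placeholder value


def loss_fn(w, x, y):
    # Toy mean-squared-error loss standing in for the GAN losses.
    return jnp.mean((x @ w - y) ** 2)


def train_step(w, x, y):
    grads = jax.grad(loss_fn)(w, x, y)
    # Average gradients over all devices taking part in the pmap,
    # so every replica applies the same update and stays in sync.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return w - LEARNING_RATE * grads


# pmap runs one copy of train_step per local device (8 on a TPU v3-8, 1 on CPU).
p_train_step = jax.pmap(train_step, axis_name="batch")

n_dev = jax.local_device_count()
params = jnp.zeros((3,))
# Replicate parameters: the leading axis indexes devices.
replicated = jnp.stack([params] * n_dev)
# Each device receives its own slice of 4 examples.
x = jnp.ones((n_dev, 4, 3))
y = jnp.ones((n_dev, 4))

new_params = p_train_step(replicated, x, y)
print(new_params)
```

With this layout the global batch is `n_dev × 4`; scaling to a pod adds devices along the same leading axis, while each host feeds only its local slice of data.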

Read more about suitable training parameters here.

🙏 Acknowledgements