/Self-Attention-GAN

Pytorch implementation of Self-Attention Generative Adversarial Networks (SAGAN)

Primary LanguagePython

Self-Attention GAN

Han Zhang, Ian Goodfellow, Dimitris Metaxas and Augustus Odena, "Self-Attention Generative Adversarial Networks." arXiv preprint arXiv:1805.08318 (2018).

Meta overview

This repository provides a PyTorch implementation of SAGAN. Both wgan-gp and wgan-hinge loss are ready, but note that wgan-gp is somehow not compatible with the spectral normalization. Remove all the spectral normalization at the model for the adoption of wgan-gp.

Self-attentions are applied to later two layers of both discriminator and generator.

Current update status

  • Supervised setting
  • Tensorboard loggings
  • [20180608] updated the self-attention module. Thanks to my colleague Cheonbok Park! see 'sagan_models.py' for the update. Should be efficient, and run on large sized images
  • Attention visualization (LSUN Church-outdoor)
  • Unsupervised setting (use no label yet)
  • Applied: Spectral Normalization, code from here
  • Implemented: self-attention module, two-timescale update rule (TTUR), wgan-hinge loss, wgan-gp loss

   

Results

Attention result on LSUN (epoch #8)

Per-pixel attention result of SAGAN on LSUN church-outdoor dataset. It shows that unsupervised training of self-attention module still works, although it is not interpretable with the attention map itself. Better results with regard to the generated images will be added. These are the visualization of self-attention in generator layer3 and layer4, which are in the size of 16 x 16 and 32 x 32 respectively, each for 64 images. To visualize the per-pixel attentions, only a number of pixels are chosen, as shown on the leftmost and the rightmost numbers indicate.

CelebA dataset (epoch on the left, still under training)

LSUN church-outdoor dataset (epoch on the left, still under training)

Prerequisites

 

Usage

1. Clone the repository

$ git clone https://github.com/heykeetae/Self-Attention-GAN.git
$ cd Self-Attention-GAN

2. Install datasets (CelebA or LSUN)

$ bash download.sh CelebA
or
$ bash download.sh LSUN

3. Train

(i) Train
$ python python main.py --batch_size 64 --imsize 64 --dataset celeb --adv_loss hinge --version sagan_celeb
or
$ python python main.py --batch_size 64 --imsize 64 --dataset lsun --adv_loss hinge --version sagan_lsun

4. Enjoy the results

$ cd samples/sagan_celeb
or
$ cd samples/sagan_lsun

Samples generated every 100 iterations are located. The rate of sampling could be controlled via --sample_step (ex, --sample_step 100).