
Multi-Domain Image-to-Image Translation using StarGAN with Max Sliced Wasserstein Distance.

Primary language: Python. License: MIT.

Max-Sliced StarGAN

This repository contains a PyTorch implementation of the Max-Sliced StarGAN proposed in my project, Multi-Domain Image-to-Image Translation using StarGAN with Max Sliced Wasserstein Distance.

The Max-Sliced StarGAN combines StarGAN with the max-sliced Wasserstein distance to improve training stability and reduce the sample complexity of the original StarGAN.
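To make the distance concrete, here is a minimal pure-Python sketch of the max-sliced 2-Wasserstein distance between two equal-size batches of vectors. It is not the repository's implementation, and all helper names are hypothetical. A trained model would typically optimize the projection direction (for example by gradient ascent). This sketch swaps in a simpler random search over candidate unit directions, so it returns a lower bound on the true maximum. Each direction reduces the problem to 1-D, where the Wasserstein distance only requires sorting.

```python
import math
import random


def random_unit_vector(dim, rng):
    """Draw a direction uniformly from the unit sphere by normalizing a Gaussian vector."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def project(samples, direction):
    """Project each d-dimensional sample onto a single direction (a 1-D slice)."""
    return [sum(s * d for s, d in zip(sample, direction)) for sample in samples]


def wasserstein_1d_sq(xs, ys):
    """Squared 2-Wasserstein distance between equal-size 1-D empirical samples.

    In 1-D the optimal coupling matches sorted values, so no transport solver is needed.
    """
    xs, ys = sorted(xs), sorted(ys)
    return sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def max_sliced_wasserstein(real, fake, num_candidates=256, seed=0):
    """Approximate max-sliced W2: the largest 1-D distance over random candidate directions.

    Assumption: random search stands in for optimizing the direction.
    """
    if len(real) != len(fake):
        raise ValueError("this sketch assumes equal batch sizes")
    rng = random.Random(seed)
    dim = len(real[0])
    best = 0.0
    for _ in range(num_candidates):
        theta = random_unit_vector(dim, rng)
        best = max(best, wasserstein_1d_sq(project(real, theta), project(fake, theta)))
    return math.sqrt(best)


# Toy usage: the "fake" batch is the "real" batch shifted by (3, 0).
real = [[0.0, 0.0], [1.0, 1.0]]
fake = [[3.0, 0.0], [4.0, 1.0]]
print(max_sliced_wasserstein(real, fake))  # close to 3.0, reached near the x-axis direction
```

Averaging the 1-D distances instead of taking the maximum gives the ordinary sliced Wasserstein distance. The max-sliced variant concentrates on the single most discriminative direction, which is the motivation for needing far fewer projections than the sliced version.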

Sample images generated by the Max-Sliced StarGAN are shown here.

Dependencies

Usage

Note: the current version only supports training on the CelebA dataset.

1. Clone the repository

git clone https://github.com/Ziyu0/max-sliced-stargan.git
cd max-sliced-stargan

2. Download the CelebA dataset

bash download.sh celeba

3. Training

Run

python main.py --help

to see all the configurable hyper-parameters.

Training the original StarGAN

cd scripts
bash train_celeba_original.sh

Training the Max-Sliced StarGAN

cd scripts
bash train_celeba_max_sliced.sh 

In the script train_celeba_max_sliced.sh, --use_max_sw_loss is set to True to enable the max-sliced Wasserstein distance.
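The scripts switch losses on through command-line flags. Below is a hedged sketch of how such boolean flags are commonly parsed with argparse. The flag names come from this README, but str2bool, build_parser, and select_loss are hypothetical helpers, and the repository's main.py may handle its options differently. A custom converter is used because type=bool would turn any non-empty string, including "False", into True.

```python
import argparse


def str2bool(value):
    """Convert 'True'/'False'-style strings to booleans (hypothetical helper)."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    """Hypothetical subset of the options; only the flag names come from the README."""
    parser = argparse.ArgumentParser(description="Max-Sliced StarGAN (sketch)")
    parser.add_argument("--use_max_sw_loss", type=str2bool, default=False)
    parser.add_argument("--use_sw_loss", type=str2bool, default=False)
    parser.add_argument("--use_d_feature", type=str2bool, default=False)
    return parser


def select_loss(config):
    """Hypothetical dispatch: report which adversarial loss a run would use."""
    if config.use_max_sw_loss:
        return "max_sliced_wasserstein"
    if config.use_sw_loss:
        return "sliced_wasserstein"
    return "original"


config = build_parser().parse_args(["--use_max_sw_loss", "True"])
print(select_loss(config))
```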

Training the other baseline models

Please refer to the report for an introduction to the baseline models.

  • Training the Sliced StarGAN
    cd scripts
    bash train_celeba_sliced.sh
    
    In the script train_celeba_sliced.sh, --use_sw_loss is set to True to enable the sliced Wasserstein distance.
  • Training the Sliced StarGAN with feature transformation
    cd scripts
    bash train_celeba_sliced_feat_trans.sh
    
    In the script train_celeba_sliced_feat_trans.sh, both --use_sw_loss and --use_d_feature are set to True so that we can compute the sliced Wasserstein distance based on the feature transformation.
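For the two baselines above, the sliced Wasserstein distance averages the 1-D distance over many random directions instead of taking the maximum, and the feature-transformation variant computes it on transformed samples rather than on raw inputs. The pure-Python sketch below illustrates both under stated assumptions: all names are hypothetical, and the feature transformation is a fixed toy linear map followed by ReLU. The flag name --use_d_feature suggests discriminator features, but that reading is an assumption.

```python
import math
import random


def random_unit_vector(dim, rng):
    """Uniform direction on the unit sphere via a normalized Gaussian vector."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def sliced_wasserstein(real, fake, num_projections=128, seed=0):
    """Monte Carlo sliced W2: average the squared 1-D W2 over random directions."""
    if len(real) != len(fake):
        raise ValueError("this sketch assumes equal batch sizes")
    rng = random.Random(seed)
    dim = len(real[0])
    total = 0.0
    for _ in range(num_projections):
        theta = random_unit_vector(dim, rng)
        # Project both batches onto theta, then match sorted values (optimal in 1-D).
        xs = sorted(sum(a * t for a, t in zip(x, theta)) for x in real)
        ys = sorted(sum(b * t for b, t in zip(y, theta)) for y in fake)
        total += sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)
    return math.sqrt(total / num_projections)


# Hypothetical stand-in for the feature transformation: a fixed linear map plus ReLU.
# In a real model this would be learned activations, not a hand-written matrix.
FEATURE_WEIGHTS = [[1.0, -1.0], [0.5, 2.0], [-1.0, 0.0]]


def toy_features(sample):
    """Map a 2-D sample to 3 non-negative features."""
    return [max(0.0, sum(w * s for w, s in zip(row, sample))) for row in FEATURE_WEIGHTS]


def sliced_wasserstein_on_features(real, fake, **kwargs):
    """Sliced W2 computed in the (toy) feature space instead of the input space."""
    return sliced_wasserstein(
        [toy_features(x) for x in real], [toy_features(y) for y in fake], **kwargs
    )


real = [[0.0, 0.0], [1.0, 1.0]]
fake = [[3.0, 0.0], [4.0, 1.0]]
print(sliced_wasserstein(real, fake), sliced_wasserstein_on_features(real, fake))
```

For the shifted toy batch, the sliced estimate lands near 3 divided by the square root of 2 rather than 3, because most random directions see only part of the shift. That averaging effect is what the max-sliced variant avoids.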

4. Testing

Testing on all images from the test dataset

cd scripts
bash test_celeba_general.sh 

Testing on a small subset of the images from the test dataset

cd scripts
bash test_celeba_small.sh

5. Plot

To generate loss plots for multiple training processes, run

cd scripts
bash create_plots.sh

Here is a sample plot:

Results

The following figure shows the facial attribute transfer results for the original StarGAN and the Max-Sliced StarGAN. Four sets of generated images are arranged from the top-left to the bottom-right section. In each section, the first column shows the input image, and the next three columns show single-attribute transfer results.
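Each single-attribute transfer column corresponds to one target label vector derived from the input image's attributes. The sketch below is modeled on how the original StarGAN code builds CelebA target labels: flip one binary attribute, or switch to one hair color while clearing the other hair colors. It assumes this fork keeps that behavior. It is not taken from this repository, and the function name and attribute list are hypothetical.

```python
# Hair colors are treated as mutually exclusive (assumption based on the original StarGAN).
HAIR_COLORS = {"Black_Hair", "Blond_Hair", "Brown_Hair", "Gray_Hair"}


def single_attribute_targets(original, attr_names):
    """Return one target label vector per attribute, changing only that attribute."""
    hair_idx = [i for i, name in enumerate(attr_names) if name in HAIR_COLORS]
    targets = []
    for i in range(len(attr_names)):
        target = list(original)
        if i in hair_idx:
            # Switch to hair color i and clear every other hair color.
            for j in hair_idx:
                target[j] = 1 if j == i else 0
        else:
            # Binary attribute (e.g. Male, Young): flip it.
            target[i] = 1 - target[i]
        targets.append(target)
    return targets


attrs = ["Black_Hair", "Blond_Hair", "Brown_Hair", "Male", "Young"]  # hypothetical selection
for name, target in zip(attrs, single_attribute_targets([1, 0, 0, 1, 1], attrs)):
    print(name, target)
```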

Future work

  • Provide support for training on other datasets, such as RaFD.
  • Further improve the performance of the Max-Sliced StarGAN.