This repository provides the code to reproduce the experimental results in the paper Augmented Sliced Wasserstein Distances.
To install the required python packages, run the following command:
pip install -r requirements.txt
Two datasets are used in this repository: the CIFAR10 dataset and the CelebA dataset.
- The CIFAR10 dataset (64x64 pixels) will be automatically downloaded from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz when running the CIFAR10 experiment.
- The CelebA dataset needs to be downloaded manually from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html. We use the cropped CelebA dataset with 64x64 pixels.
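The preprocessing below is a minimal sketch of producing 64x64 CelebA images, assuming a 140-pixel center crop followed by a nearest-neighbour resize; the repository's exact preprocessing pipeline may differ.

```python
import numpy as np

def center_crop_resize(img, crop=140, out=64):
    """Center-crop an HxWxC image array to crop x crop, then
    nearest-neighbour resize to out x out.
    Note: crop=140 is an assumption, not taken from the repository."""
    h, w = img.shape[:2]
    top, left = (h - crop) // 2, (w - crop) // 2
    patch = img[top:top + crop, left:left + crop]
    # Nearest-neighbour index map from the crop grid to the output grid.
    idx = (np.arange(out) * crop / out).astype(int)
    return patch[idx][:, idx]

img = np.zeros((218, 178, 3), dtype=np.uint8)  # CelebA native resolution
print(center_crop_resize(img).shape)  # (64, 64, 3)
```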
To calculate the Fréchet Inception Distance (FID) score, precalculated statistics for the datasets
- CIFAR 10 (calculated on all training samples)
- cropped CelebA (64x64, calculated on all samples)
are provided at: http://bioinf.jku.at/research/ttur/.
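Given the precalculated statistics (mean and covariance of Inception features), the FID score reduces to the Fréchet distance between two Gaussians. The sketch below shows that computation in isolation; the repository's own implementation lives in fid_score.py.

```python
import numpy as np
from scipy import linalg

def fid(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2):
    ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):  # discard tiny imaginary numerical noise
        covmean = covmean.real
    return diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean)

# Identical statistics give a distance of zero.
mu, sigma = np.zeros(4), np.eye(4)
print(round(float(fid(mu, sigma, mu, sigma)), 6))  # 0.0
```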
Two experiments are included in this repository, with benchmarks taken from the paper Generalized Sliced Wasserstein Distances and the paper Distributional Sliced-Wasserstein and Applications to Generative Modeling, respectively. The first is on the task of sliced Wasserstein flow, and the second is on generative modelling with GANs. For more details and setups, please refer to the original paper Augmented Sliced Wasserstein Distances.
- ./result/ASWD/CIFAR/ contains generated images trained with the ASWD on the CIFAR10 dataset.
- ./result/ASWD/CIFAR/fid/ stores FID scores of generated images trained with the ASWD on the CIFAR10 dataset.
- ./result/CIFAR/ stores model weights and losses from the CIFAR10 experiment.

Other setups follow the same naming rule.
The sliced Wasserstein flow example can be found in the Jupyter notebook.
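For orientation, the quantity driving both experiments can be estimated by projecting samples onto random directions and comparing the sorted projections. The sketch below implements the plain sliced Wasserstein distance (SWD), not the augmented variant introduced in the paper, which additionally maps samples through a neural network.

```python
import numpy as np

def sliced_wasserstein(X, Y, n_projections=100, seed=0):
    """Monte Carlo estimate of the sliced 2-Wasserstein distance between
    two empirical distributions X, Y of shape (n, d), equal sample sizes."""
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    # Random directions drawn uniformly on the unit sphere.
    theta = rng.normal(size=(n_projections, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)
    # Project onto each direction; sorting gives the 1-D optimal coupling.
    xp = np.sort(X @ theta.T, axis=0)
    yp = np.sort(Y @ theta.T, axis=0)
    return np.sqrt(np.mean((xp - yp) ** 2))

rng = np.random.default_rng(1)
X = rng.normal(size=(256, 2))
print(float(sliced_wasserstein(X, X)))  # 0.0 for identical samples
```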
The following scripts belong to the generative modelling example:
- main.py : run this file to conduct experiments.
- utils.py : contains implementations of the different sliced Wasserstein distance variants.
- TransformNet.py : edit this file to modify architectures of neural networks used to map samples.
- experiments.py : functions for generating and saving randomly generated images.
- DCGANAE.py : neural network architectures and optimization objective for training GANs.
- fid_score.py : functions for calculating statistics (mean & covariance matrix) of distributions of images and the FID score between two distributions of images.
- inception.py : download the pretrained InceptionV3 model and generate feature maps for FID evaluation.
The generative modelling experiment evaluates the performance of GANs trained with different sliced Wasserstein metrics. To train and evaluate the model, run the following command:
python main.py --model-type ASWD --dataset CIFAR --epochs 200 --num-projection 1000 --batch-size 512 --lr 0.0005
- --model-type : type of sliced Wasserstein metric used in the experiment; available options: ASWD, DSWD, SWD, MSWD, GSWD. Must be specified.
- --dataset : select from CIFAR, CELEBA; default CIFAR.
- --epochs : number of training epochs; default 200.
- --num-projection : number of projections used in the distance approximation; default 1000.
- --batch-size : batch size for one iteration; default 512.
- --lr : learning rate; default 0.0005.
- --niter : number of iterations; available for the ASWD, MSWD and DSWD; default 5.
- --lam : coefficient of the regularization term; available for the ASWD and DSWD; default 0.5.
- --r : parameter in the circular defining function; available for the GSWD; default 1000.
The code for the generative modelling example is based on the DSWD implementation by VinAI Research.
The PyTorch code for calculating the FID score is from https://github.com/mseitzer/pytorch-fid.