A PyTorch implementation of the Max Sliced Wasserstein Distance from the paper "Generalized Sliced Wasserstein Distances".
This repo is based on the implementation shared by Emmanuel Fuentes; the only modification here is the way theta is obtained.
To run this demo, please install the required packages by running: pip install -r requirements-dev.txt
You can train this model in either 'max' or 'normal' mode, which use the Maximum Sliced-Wasserstein distance and the normal Sliced-Wasserstein distance, respectively.
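The difference between the two modes can be illustrated with a minimal sketch. This is not the code used in this repo: the function names (`projected_distance`, `sliced_wasserstein`, `max_sliced_wasserstein`), the use of Adam to search for theta, and the default hyperparameters are all assumptions made for illustration, and both sample sets are assumed to have the same number of points. The 'normal' variant averages the 1D distance over many random directions, while the 'max' variant searches for a single direction theta that maximizes it.

```python
import torch


def projected_distance(x, y, theta):
    """Squared 1D Wasserstein-2 distance along each row of theta, averaged.

    x, y: (n, dim) sample sets of equal size (an assumption of this sketch).
    theta: (k, dim) unit-norm projection directions.
    """
    x_proj = x @ theta.t()  # (n, k)
    y_proj = y @ theta.t()  # (n, k)
    # In 1D, the optimal coupling of equal-size samples pairs sorted values.
    x_sorted, _ = torch.sort(x_proj, dim=0)
    y_sorted, _ = torch.sort(y_proj, dim=0)
    return ((x_sorted - y_sorted) ** 2).mean()


def sliced_wasserstein(x, y, num_projections=50):
    """'normal' mode sketch: average over random directions on the unit sphere."""
    theta = torch.randn(num_projections, x.size(1))
    theta = theta / theta.norm(dim=1, keepdim=True)
    return projected_distance(x, y, theta)


def max_sliced_wasserstein(x, y, steps=100, lr=0.1):
    """'max' mode sketch: gradient ascent on a single direction theta."""
    theta = torch.randn(1, x.size(1), requires_grad=True)
    optimizer = torch.optim.Adam([theta], lr=lr)  # hypothetical optimizer choice
    for _ in range(steps):
        direction = theta / theta.norm()  # keep the direction on the unit sphere
        loss = -projected_distance(x, y, direction)  # minimize the negative
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        return projected_distance(x, y, theta / theta.norm())
```

For example, with two sample sets that differ mainly along one axis, the 'max' estimate would be expected to exceed the averaged one, since most random directions only partly align with that axis:

```python
torch.manual_seed(0)
x = torch.randn(256, 5)
y = torch.randn(256, 5)
y[:, 0] += 3.0  # shift along the first axis only
print(sliced_wasserstein(x, y).item(), max_sliced_wasserstein(x, y).item())
```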
To train in 'max' mode, please run: python examples/mnist.py --mode 'max' --mode_test 'max'
For more information, please refer to this file: mnist.py