
An unofficial PyTorch implementation of the CVPR 2019 paper "Sliced Wasserstein Discrepancy for Unsupervised Domain Adaptation"


Sliced Wasserstein Discrepancy for Unsupervised Domain Adaptation in PyTorch

This is a PyTorch re-implementation of the CVPR 2019 paper "Sliced Wasserstein Discrepancy for Unsupervised Domain Adaptation" from Apple.

If you find this repository helpful, please consider citing the original paper.

Introduction

This repository aims to reproduce the results presented in the official repository. Thus, only a basic implementation on the intertwining moons 2D dataset is provided here.
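
For readers new to the method, the core quantity can be sketched in a few lines. The following is a minimal illustration of a sliced Wasserstein discrepancy between two batches of classifier outputs, assuming equal batch sizes; it is not taken from swd_pytorch.py. The function name `sliced_wasserstein_discrepancy` and the default of 128 random projections are assumptions made for this example.

```python
import torch


def sliced_wasserstein_discrepancy(p1, p2, num_projections=128):
    """Approximate the sliced Wasserstein discrepancy between two batches.

    p1, p2: tensors of shape (batch, dim), e.g. the outputs of two classifiers
    on the same target batch. Both must have the same shape.
    num_projections: number of random 1-D projections (illustrative default).
    """
    dim = p1.shape[1]
    # Draw random directions and normalise each column to unit length.
    directions = torch.randn(dim, num_projections, dtype=p1.dtype, device=p1.device)
    directions = directions / directions.norm(dim=0, keepdim=True)
    # Project both batches onto every direction: shape (batch, num_projections).
    proj1 = p1 @ directions
    proj2 = p2 @ directions
    # In 1-D, optimal transport between equal-size samples pairs sorted values.
    sorted1, _ = torch.sort(proj1, dim=0)
    sorted2, _ = torch.sort(proj2, dim=0)
    # Average squared distance over samples and projections.
    return ((sorted1 - sorted2) ** 2).mean()


torch.manual_seed(0)
a = torch.softmax(torch.randn(64, 2), dim=1)
b = torch.softmax(torch.randn(64, 2), dim=1)
print(sliced_wasserstein_discrepancy(a, a).item())  # identical batches: prints 0.0
print(sliced_wasserstein_discrepancy(a, b).item())  # non-negative scalar
```

Because the directions are resampled on every call, the value is a stochastic estimate; more projections reduce its variance at a higher compute cost.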

Requirements

  • Python 3.x
  • PyTorch
  • matplotlib

This code was tested on Ubuntu 16.04 with Python 3.6 and PyTorch 1.1.0. A GPU is NOT required to run this code.

Running the code

To run the demo with adaptation:

python swd_pytorch.py -mode adapt_swd

To run the demo without adaptation:

python swd_pytorch.py -mode source_only
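
As an illustration of how a `-mode` switch like the one above could be handled, here is a minimal `argparse` sketch. It is not copied from swd_pytorch.py: the helper name `build_parser`, the help text, and the default value are assumptions; only the flag name and its two values come from the commands above.

```python
import argparse


def build_parser():
    """Hypothetical parser accepting the two modes shown in this README."""
    parser = argparse.ArgumentParser(description="Moons demo (illustrative sketch)")
    parser.add_argument(
        "-mode",
        choices=["adapt_swd", "source_only"],
        default="adapt_swd",  # assumption: the real script may use another default
        help="adapt_swd: train with adaptation; source_only: train without it",
    )
    return parser


args = build_parser().parse_args(["-mode", "source_only"])
print(args.mode)  # source_only
```

Restricting the flag with `choices` makes a mistyped mode fail immediately with a usage message instead of silently running the wrong experiment.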

Interpreting Outputs

Outputs are saved as PNG and GIF files in the current folder for each mode. They show the source and target samples together with the current decision boundary. Blue and red points are source samples of classes 0 and 1, respectively; green points are target samples.
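
To make the colour convention concrete, the sketch below draws toy source and target samples with the same colours using matplotlib. It is only an illustration: the helpers `make_moons` and `plot_domains`, the noise level, the 30-degree rotation of the target domain, and the output file name are all assumptions, and the decision boundary drawn by the actual demo is omitted.

```python
import math
import random

import matplotlib
matplotlib.use("Agg")  # render to files without needing a display
import matplotlib.pyplot as plt


def make_moons(n_per_class, noise=0.05, angle_deg=0.0, seed=0):
    """Hypothetical helper: two interleaved half-circles, optionally rotated."""
    rng = random.Random(seed)
    theta = math.radians(angle_deg)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    points, labels = [], []
    for label in (0, 1):
        for _ in range(n_per_class):
            t = rng.uniform(0.0, math.pi)
            if label == 0:
                x, y = math.cos(t), math.sin(t)  # upper moon
            else:
                x, y = 1.0 - math.cos(t), 0.5 - math.sin(t)  # lower moon
            x += rng.gauss(0.0, noise)
            y += rng.gauss(0.0, noise)
            # Rotate about the origin to mimic a domain shift.
            points.append((x * cos_t - y * sin_t, x * sin_t + y * cos_t))
            labels.append(label)
    return points, labels


def plot_domains(path="moons_sketch.png"):
    """Scatter source classes in blue/red and target samples in green."""
    src, src_labels = make_moons(100, seed=0)
    tgt, _ = make_moons(100, angle_deg=30.0, seed=1)  # rotation is an assumption
    fig, ax = plt.subplots(figsize=(5, 5))
    for label, color in ((0, "blue"), (1, "red")):
        xs = [p[0] for p, l in zip(src, src_labels) if l == label]
        ys = [p[1] for p, l in zip(src, src_labels) if l == label]
        ax.scatter(xs, ys, c=color, s=10, label=f"source class {label}")
    ax.scatter([p[0] for p in tgt], [p[1] for p in tgt], c="green", s=10, label="target")
    ax.legend()
    fig.savefig(path)
    plt.close(fig)
    return path
```

Calling `plot_domains()` writes `moons_sketch.png` to the current folder; target labels are discarded on purpose, since they are not available in unsupervised domain adaptation.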

Acknowledgement

ml-cvpr2019-swd (Official implementation in TensorFlow)