MADAOT

Code for "Margin-aware Adversarial Domain Adaptation with Optimal Transport" paper



Dependencies:

  • Numpy >= 1.18.1
  • POT >= 0.6.0
  • CVXPY >= 1.0.25
  • MOSEK >= 9.1.9
  • Scikit-Learn >= 0.22.1
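The snippet below is a minimal sketch for checking that an environment meets the minimum versions listed above. It is not part of the repository. The PyPI distribution names ("POT", "cvxpy", "Mosek", "scikit-learn") are assumptions, and the simple parser keeps only the first three numeric components and ignores pre-release tags.

```python
from importlib import metadata

# Minimum versions from the dependency list; distribution names are assumed.
MIN_VERSIONS = {
    "numpy": "1.18.1",
    "POT": "0.6.0",
    "cvxpy": "1.0.25",
    "Mosek": "9.1.9",
    "scikit-learn": "0.22.1",
}


def parse_version(v):
    """Turn '1.18.1' into (1, 18, 1); trailing non-digits such as 'rc1' are dropped."""
    parts = []
    for piece in v.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)  # pad '0.6' to (0, 6, 0)
    return tuple(parts)


def check(min_versions=MIN_VERSIONS):
    """Return a status string per package: 'ok', 'missing' or 'too old (...)'."""
    report = {}
    for name, minimum in min_versions.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        if parse_version(installed) >= parse_version(minimum):
            report[name] = "ok"
        else:
            report[name] = f"too old ({installed} < {minimum})"
    return report


if __name__ == "__main__":
    for name, status in check().items():
        print(f"{name:>12}: {status}")
```

Note that MOSEK also needs a valid license to run, which a version check cannot detect.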

Scripts for experiments:

  • Moons:
    • cross validation: cross_valid.py (might take up to 90 minutes to run)
    • testing: postprocessing_cross_valid.py
  • Amazon:
    • Data download link
    • testing: postprocessing_cross_valid_amazon.py (with fixed hyperparameters indicated in the main paper)
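For readers who want a toy source/target pair resembling the Moons experiment, here is a minimal NumPy-only sketch. It is an assumption, not the data pipeline of cross_valid.py: the source is a labelled two-moons sample, and the target is an independent sample rotated around its centroid. The helper names, the noise level and the 30 degree angle are all hypothetical.

```python
import numpy as np


def make_two_moons(n_samples, noise=0.1, rng=None):
    """Two interleaving half circles, labels in {-1, +1} (hypothetical helper)."""
    rng = np.random.default_rng(rng)
    n_upper = n_samples // 2
    n_lower = n_samples - n_upper
    t_upper = np.linspace(0, np.pi, n_upper)
    t_lower = np.linspace(0, np.pi, n_lower)
    upper = np.column_stack([np.cos(t_upper), np.sin(t_upper)])
    lower = np.column_stack([1 - np.cos(t_lower), 0.5 - np.sin(t_lower)])
    X = np.vstack([upper, lower]) + noise * rng.standard_normal((n_samples, 2))
    y = np.concatenate([-np.ones(n_upper), np.ones(n_lower)])
    return X, y


def rotate(X, degrees):
    """Rotate 2-D points counter-clockwise around their centroid."""
    theta = np.deg2rad(degrees)
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    center = X.mean(axis=0)
    return (X - center) @ R.T + center


# Labelled source domain; target domain is the same distribution, rotated.
Xs, ys = make_two_moons(300, noise=0.1, rng=0)
Xt, yt = make_two_moons(300, noise=0.1, rng=1)
Xt = rotate(Xt, 30)  # target labels yt would be hidden during adaptation
```

Larger rotation angles make the adaptation problem harder, which is why such angles are a natural axis to vary in toy experiments.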

Scripts for figures:

  • Moons: postprocessing_cross_valid.py
  • Loss functions: loss_funcs.py
  • Smooth proxies (supplementary): proxies.py (uncomment the code at the end of the script)

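This README does not say which losses or proxies loss_funcs.py and proxies.py implement. As a hedged illustration of what a smooth proxy of a margin loss can look like, the sketch below compares the hinge loss max(0, 1 - m) with its softplus smoothing log(1 + exp(k(1 - m))) / k, where the sharpness k is a hypothetical parameter. The smooth version is never below the hinge and exceeds it by at most log(2)/k, with the largest gap at m = 1.

```python
import numpy as np


def hinge(margin):
    """Hinge loss on the margin m = y * f(x)."""
    return np.maximum(0.0, 1.0 - margin)


def smooth_hinge(margin, k=10.0):
    """Softplus smoothing of the hinge; approaches hinge(margin) as k grows."""
    z = k * (1.0 - margin)
    return np.logaddexp(0.0, z) / k  # numerically stable log(1 + exp(z)) / k


margins = np.linspace(-3.0, 3.0, 601)
gap = smooth_hinge(margins) - hinge(margins)
print(gap.max())  # largest gap, reached at margin = 1
```

Increasing k tightens the approximation at the cost of larger gradients near the hinge point.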
Other scripts:

  • Main class for our algorithm: madaot.py
  • Cross validation (supports parallelism): myDA.py
  • Algorithm computing the transport plan at each step (described in Blankenship and Falk, 1976): advEmd.py
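Blankenship and Falk (1976) treat problems with infinitely many constraints by an exchange (cutting-plane) scheme. It solves a relaxation over a finite set of constraint indices, then finds the most violated index for the current solution, adds it, and repeats until no constraint is violated beyond a tolerance. The sketch below shows that loop on a toy minimax problem, min over x of max over y in [0, 1] of (x - y)^2, whose solution is x = 0.5 with value 0.25. It is not advEmd.py. The function name and toy objective are hypothetical, and brute-force grids stand in for the optimization solvers a real implementation would presumably call for the inner and outer problems.

```python
import numpy as np


def cutting_plane_minimax(f, x_grid, y_grid, y0, tol=1e-9, max_iter=50):
    """Exchange method for min_x max_y f(x, y) (hypothetical toy helper).

    f must accept a NumPy array in either argument and broadcast.
    Returns (x, worst-case value, list of active y's).
    """
    active = [y0]
    for _ in range(max_iter):
        # Relaxed outer problem: only the finitely many active y's constrain x.
        relaxed = np.max([f(x_grid, y) for y in active], axis=0)
        i = int(np.argmin(relaxed))
        x_k, lower = x_grid[i], relaxed[i]
        # Separation step: most violated y for the current x.
        values = f(x_k, y_grid)
        j = int(np.argmax(values))
        upper = values[j]
        if upper <= lower + tol:  # no y beats the relaxation: x_k is optimal on the grids
            return x_k, upper, active
        active.append(y_grid[j])
    return x_k, upper, active


x_opt, value, active = cutting_plane_minimax(
    lambda x, y: (x - y) ** 2,
    x_grid=np.linspace(-2.0, 2.0, 401),
    y_grid=np.linspace(0.0, 1.0, 101),
    y0=0.0,
)
```

On this toy problem the loop adds only one extra constraint (y = 1) before the relaxation becomes exact, which illustrates why exchange methods can be cheap when few constraints are active at the optimum.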