

Language: Python. License: MIT.

PyTorch-ADDA

A PyTorch implementation of Adversarial Discriminative Domain Adaptation (ADDA).

Environment

  • Python 3.6
  • PyTorch 0.2.0

Usage

I have only tested MNIST -> USPS adaptation. To run it, use the following command:

python3 main.py

Network

In this experiment, I use three simple networks:

  • LeNet encoder

    LeNetEncoder (
      (encoder): Sequential (
        (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
        (1): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
        (2): ReLU ()
        (3): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
        (4): Dropout2d (p=0.5)
        (5): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
        (6): ReLU ()
      )
      (fc1): Linear (800 -> 500)
    )
    
  • LeNet classifier

    LeNetClassifier (
      (fc2): Linear (500 -> 10)
    )
    
  • Discriminator

    Discriminator (
      (layer): Sequential (
        (0): Linear (500 -> 500)
        (1): ReLU ()
        (2): Linear (500 -> 500)
        (3): ReLU ()
        (4): Linear (500 -> 2)
        (5): LogSoftmax ()
      )
    )
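The printouts above can be reproduced with module definitions along the following lines. This is a sketch, not the repository's own source: it targets current PyTorch rather than 0.2.0 (so `nn.LogSoftmax` takes an explicit `dim`), and it assumes 1×28×28 inputs, which is what makes the flattened convolution output 50 × 4 × 4 = 800. The classifier printout shows only `fc2`, so any activation or dropout the original applies before it in `forward` is not visible; this sketch applies `fc2` directly.

```python
import torch
import torch.nn as nn


class LeNetEncoder(nn.Module):
    """LeNet-style feature extractor: 1x28x28 image -> 500-d feature (assumed input size)."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5),   # 28x28 -> 24x24
            nn.MaxPool2d(kernel_size=2),       # 24x24 -> 12x12
            nn.ReLU(),
            nn.Conv2d(20, 50, kernel_size=5),  # 12x12 -> 8x8
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(kernel_size=2),       # 8x8 -> 4x4
            nn.ReLU(),
        )
        self.fc1 = nn.Linear(50 * 4 * 4, 500)  # 800 -> 500

    def forward(self, x):
        conv_out = self.encoder(x)
        return self.fc1(conv_out.view(conv_out.size(0), -1))


class LeNetClassifier(nn.Module):
    """Maps a 500-d feature to 10 digit logits."""

    def __init__(self):
        super().__init__()
        self.fc2 = nn.Linear(500, 10)

    def forward(self, feat):
        return self.fc2(feat)


class Discriminator(nn.Module):
    """Predicts log-probabilities of (target, source) domain from a 500-d feature."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 2),
            nn.LogSoftmax(dim=1),  # outputs log-probabilities over 2 domains
        )

    def forward(self, feat):
        return self.layer(feat)


if __name__ == "__main__":
    images = torch.randn(4, 1, 28, 28)
    enc, clf, disc = LeNetEncoder(), LeNetClassifier(), Discriminator()
    feat = enc(images)
    print(feat.shape, clf(feat).shape, disc(feat).shape)
```

Printing a model instance gives a layer listing comparable to the ones above; the three shapes printed are `(4, 500)`, `(4, 10)` and `(4, 2)`. Because the discriminator already ends in `LogSoftmax`, `nn.NLLLoss` is the loss that matches its output.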
    

Result

|                                    | MNIST (Source) | USPS (Target) |
|------------------------------------|----------------|---------------|
| Source Encoder + Source Classifier | 99.140000%     | 83.978495%    |
| Target Encoder + Source Classifier | -              | 97.634409%    |

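Domain adaptation clearly helps: on USPS, accuracy rises from 83.98% with the source encoder to 97.63% with the adapted target encoder, both using the same source classifier.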
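The "Target Encoder + Source Classifier" row comes from ADDA's adaptation stage. The source encoder and classifier are trained on labeled MNIST and then frozen. A target encoder, initialised from the source encoder, is trained adversarially against the discriminator on unlabeled USPS images and is finally evaluated with the unchanged source classifier. Below is a minimal sketch of one adaptation step under stated assumptions. The function name `adda_step`, the label convention (1 = source, 0 = target), the Adam optimizers and learning rate, and the tiny stand-in encoder are all illustrative choices and are not taken from this repository's `main.py`.

```python
import copy

import torch
import torch.nn as nn


def adda_step(src_encoder, tgt_encoder, critic, src_images, tgt_images,
              opt_critic, opt_tgt):
    """One adversarial update (hypothetical helper); returns (critic_loss, target_loss)."""
    criterion = nn.NLLLoss()  # the critic already outputs log-probabilities

    # 1) Update the discriminator: source features -> label 1, target -> 0 (assumed convention).
    with torch.no_grad():
        feat_src = src_encoder(src_images)  # source encoder stays frozen
        feat_tgt = tgt_encoder(tgt_images)  # no gradient into the target encoder here
    feats = torch.cat([feat_src, feat_tgt], dim=0)
    labels = torch.cat([
        torch.ones(feat_src.size(0), dtype=torch.long),
        torch.zeros(feat_tgt.size(0), dtype=torch.long),
    ])
    opt_critic.zero_grad()
    loss_critic = criterion(critic(feats), labels)
    loss_critic.backward()
    opt_critic.step()

    # 2) Update the target encoder with inverted labels, so the critic
    #    is pushed to classify target features as "source".
    opt_tgt.zero_grad()
    pred_tgt = critic(tgt_encoder(tgt_images))
    fake_labels = torch.ones(tgt_images.size(0), dtype=torch.long)
    loss_tgt = criterion(pred_tgt, fake_labels)
    loss_tgt.backward()
    opt_tgt.step()  # opt_tgt holds only the target encoder's parameters
    return loss_critic.item(), loss_tgt.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    # Small stand-in for LeNetEncoder so the sketch runs on its own.
    src_encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 500))
    tgt_encoder = copy.deepcopy(src_encoder)  # target starts as a copy of the source encoder
    critic = nn.Sequential(nn.Linear(500, 500), nn.ReLU(),
                           nn.Linear(500, 2), nn.LogSoftmax(dim=1))
    opt_critic = torch.optim.Adam(critic.parameters(), lr=1e-4)
    opt_tgt = torch.optim.Adam(tgt_encoder.parameters(), lr=1e-4)
    src_batch = torch.randn(8, 1, 28, 28)  # placeholder for an MNIST batch
    tgt_batch = torch.randn(8, 1, 28, 28)  # placeholder for a USPS batch
    print(adda_step(src_encoder, tgt_encoder, critic, src_batch, tgt_batch,
                    opt_critic, opt_tgt))
```

Only the target encoder's optimizer steps in the second phase, so the source encoder and source classifier stay fixed. This is why the source classifier can be reused unchanged at evaluation time.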