Dassl.pytorch

A PyTorch toolbox for domain adaptation and semi-supervised learning.


Dassl

Dassl is a PyTorch toolbox for domain adaptation and semi-supervised learning. It has a modular design and unified interfaces, allowing fast prototyping and experimentation. With Dassl, a new method can be implemented with only a few lines of code.

You can use Dassl as a library for researching the following problems:

  • Domain adaptation
  • Domain generalization
  • Semi-supervised learning

Overview

Dassl has implemented the following papers:

Dassl supports the following datasets.

Get started

Installation

Make sure conda is installed properly.

# Clone this repo
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
cd Dassl.pytorch/

# Create a conda environment
conda create -n dassl python=3.7

# Activate the environment
conda activate dassl

# Install dependencies
pip install -r requirements.txt

# Install torch and torchvision (select a version that suits your machine)
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

# Install this library (no need to re-build if the source code is modified)
python setup.py develop

Follow the instructions in DATASETS.md to prepare the datasets.

Training

The main interface is implemented in tools/train.py, which basically does three things:

  1. Initialize the config with cfg = setup_cfg(args) where args contains the command-line input (see tools/train.py for the list of input arguments).
  2. Instantiate a trainer with build_trainer(cfg) which loads the dataset and builds a deep neural network model.
  3. Call trainer.train() for training and evaluating the model.

Below is an example of training a source-only baseline on Office-31, a popular domain adaptation dataset:

CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31

$DATA denotes the path to the dataset folder. --dataset-config-file loads the common setting for the dataset such as image size and model architecture. --config-file loads the algorithm-specific setting such as hyper-parameters and optimization parameters.
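The layering of the two config files can be pictured as a recursive merge in which later layers override earlier ones. The sketch below is only an illustration using plain dictionaries: the helper merge_config, the key names (INPUT, MODEL, OPTIM) and the values are hypothetical and are not taken from Dassl's actual config files or config loader.

```python
from copy import deepcopy


def merge_config(base, override):
    """Return a new dict in which `override` is merged recursively into `base`."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides hold a nested section: merge key by key.
            merged[key] = merge_config(merged[key], value)
        else:
            # Otherwise the later layer replaces the earlier value.
            merged[key] = deepcopy(value)
    return merged


# Hypothetical contents standing in for the two YAML files.
dataset_cfg = {
    "INPUT": {"SIZE": [224, 224]},
    "MODEL": {"BACKBONE": "resnet50"},
}
method_cfg = {
    "OPTIM": {"LR": 0.002, "MAX_EPOCH": 20},
}

# Dataset-level settings first, then algorithm-specific settings on top.
cfg = merge_config(dataset_cfg, method_cfg)

# A later layer wins when both define the same key.
cfg = merge_config(cfg, {"OPTIM": {"LR": 0.001}})
```

Keeping the dataset settings and the algorithm settings in separate layers means the same dataset file can be reused across trainers, and the same trainer file across datasets.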

For multi-source domain adaptation, simply add more sources to --source-domains. For instance, to train a source-only baseline on miniDomainNet, run:

CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains clipart painting real \
--target-domains sketch \
--dataset-config-file configs/datasets/da/mini_domainnet.yaml \
--config-file configs/trainers/da/source_only/mini_domainnet.yaml \
--output-dir output/source_only_minidn
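The way --source-domains accepts one or several names can be reproduced with argparse's nargs="+" option. The parser below is a minimal sketch covering only the flags that appear in the commands above. The defaults and help strings are assumptions, and the real parser in tools/train.py may define these arguments differently.

```python
import argparse


def build_parser():
    """Build a parser for the flags used in the training commands (sketch only)."""
    parser = argparse.ArgumentParser(description="Sketch of the training CLI")
    parser.add_argument("--root", type=str, default="", help="path to the dataset folder")
    parser.add_argument("--trainer", type=str, default="", help="name of the trainer")
    # nargs="+" collects one or more space-separated values into a list.
    parser.add_argument("--source-domains", type=str, nargs="+", default=[])
    parser.add_argument("--target-domains", type=str, nargs="+", default=[])
    parser.add_argument("--dataset-config-file", type=str, default="")
    parser.add_argument("--config-file", type=str, default="")
    parser.add_argument("--output-dir", type=str, default="")
    return parser


# Parse the multi-source example from an explicit argument list.
args = build_parser().parse_args([
    "--trainer", "SourceOnly",
    "--source-domains", "clipart", "painting", "real",
    "--target-domains", "sketch",
    "--output-dir", "output/source_only_minidn",
])
```

After parsing, args.source_domains is a list of three names, while a single-source run yields a list with one element, so downstream code can treat both cases the same way.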

After the training finishes, the model weights will be saved under the specified output directory, along with a log file and a tensorboard file for visualization.

Test

To test a trained model, pass --eval-only, which tells the script to run trainer.test(). You also need to provide the trained model and specify which model file (i.e. saved at which epoch) to use. For example, to use model.pth.tar-20 saved at output/source_only_office31/model, run:

CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31_test \
--eval-only \
--model-dir output/source_only_office31 \
--load-epoch 20

Note that --model-dir takes as input the directory path which was specified in --output-dir in the training stage.
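The relationship between --model-dir, --load-epoch and the checkpoint file in the example above can be written out as a small path helper. This is a sketch inferred from the example paths only: the function checkpoint_path is hypothetical, and the "model" subfolder name is taken from the example rather than from the library's loading code.

```python
from pathlib import PurePosixPath


def checkpoint_path(model_dir, epoch, name="model"):
    """Compose the checkpoint path from the training output dir and the epoch.

    `name` is the subfolder holding the weights; "model" matches the example
    in the text and is an assumption for other setups.
    """
    return str(PurePosixPath(model_dir) / name / f"model.pth.tar-{epoch}")


path = checkpoint_path("output/source_only_office31", 20)
```

Here path evaluates to the file named in the text, model.pth.tar-20 under output/source_only_office31/model.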

Write a new trainer

A good practice is to go through dassl/engine/trainer.py to get familiar with the base trainer classes, which provide generic functions and training loops. To write a trainer class for domain adaptation or semi-supervised learning, the new class can subclass TrainerXU. For domain generalization, the new class can subclass TrainerX. TrainerXU and TrainerX differ mainly in whether a data loader for unlabeled data is used. With the base classes, a new trainer may only need to implement the forward_backward() method, which performs loss computation and model update. See dassl/engine/da/source_only.py for an example.
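The division of labor described above, where the base class owns the training loop and the subclass supplies only forward_backward(), can be illustrated without PyTorch. The sketch below is not Dassl code: BaseTrainer and LinearSourceOnly are hypothetical stand-ins, the model is a single scalar weight, and the gradient is computed by hand. Only the method name forward_backward() comes from the text.

```python
class BaseTrainer:
    """Stand-in base class: owns the loop and delegates the update step."""

    def __init__(self, batches, max_epoch):
        self.batches = batches
        self.max_epoch = max_epoch

    def train(self):
        last_loss = None
        for _ in range(self.max_epoch):
            for batch in self.batches:
                # The subclass decides how loss and update are computed.
                last_loss = self.forward_backward(batch)
        return last_loss

    def forward_backward(self, batch):
        raise NotImplementedError


class LinearSourceOnly(BaseTrainer):
    """Fits y = w * x on labeled batches with plain gradient descent."""

    def __init__(self, batches, max_epoch, lr=0.05):
        super().__init__(batches, max_epoch)
        self.w = 0.0
        self.lr = lr

    def forward_backward(self, batch):
        xs, ys = batch
        n = len(xs)
        # Forward: mean squared error of the current predictions.
        loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / n
        # Backward: gradient of the loss with respect to w.
        grad = sum(2 * (self.w * x - y) * x for x, y in zip(xs, ys)) / n
        # Update: one gradient-descent step.
        self.w -= self.lr * grad
        return loss


# Labeled toy data generated by y = 2 * x, split into two batches.
batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]
trainer = LinearSourceOnly(batches, max_epoch=50, lr=0.05)
final_loss = trainer.train()
```

In a real trainer, the body of forward_backward() would run the network on the batch, compute the loss, and step the optimizer, while the surrounding loop would be inherited unchanged from the base class.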

Citation

Please cite the following paper if you find Dassl useful to your research.

@article{zhou2020domain,
  title={Domain Adaptive Ensemble Learning},
  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
  journal={arXiv preprint arXiv:2003.07325},
  year={2020}
}