Dassl is a PyTorch toolbox for domain adaptation and semi-supervised learning. It has a modular design and unified interfaces, allowing fast prototyping and experimentation. With Dassl, a new method can be implemented with only a few lines of code.
You can use Dassl as a library for researching the following problems:
- Domain adaptation
- Domain generalization
- Semi-supervised learning
Dassl has implemented the following papers:
- Single-source domain adaptation
  - Semi-supervised Domain Adaptation via Minimax Entropy (ICCV'19)
  - Maximum Classifier Discrepancy for Unsupervised Domain Adaptation (CVPR'18)
  - Self-ensembling for visual domain adaptation (ICLR'18)
  - Revisiting Batch Normalization For Practical Domain Adaptation (ICLR-W'17)
  - Adversarial Discriminative Domain Adaptation (CVPR'17)
  - Domain-Adversarial Training of Neural Networks (JMLR'16)
- Multi-source domain adaptation
- Domain generalization
- Semi-supervised learning
  - FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
  - MixMatch: A Holistic Approach to Semi-Supervised Learning (NeurIPS'19)
  - Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results (NeurIPS'17)
  - Semi-supervised Learning by Entropy Minimization (NeurIPS'04)
Dassl supports datasets for the following problems:
- Domain adaptation
- Domain generalization
- Semi-supervised learning
Make sure conda is installed properly.
# Clone this repo
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
cd Dassl.pytorch/
# Create a conda environment
conda create -n dassl python=3.7
# Activate the environment
conda activate dassl
# Install dependencies
pip install -r requirements.txt
# Install torch and torchvision (select a version that suits your machine)
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
# Install this library (no need to re-build if the source code is modified)
python setup.py develop
Follow the instructions in DATASETS.md to prepare the datasets.
The main interface is implemented in tools/train.py, which does three things:
- Initialize the config with cfg = setup_cfg(args), where args contains the command-line input (see tools/train.py for the list of input arguments).
- Instantiate a trainer with build_trainer(cfg), which loads the dataset and builds a deep neural network model.
- Call trainer.train() to train and evaluate the model.
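The three steps can be sketched as a small self-contained script. This is only a toy illustration of the control flow: the setup_cfg and build_trainer functions below are stand-ins written for this example, not Dassl's implementations, and the dict-based config and ToyTrainer class are assumptions.

```python
import argparse


def setup_cfg(args):
    # Stand-in for illustration only: collect the parsed command-line
    # input into a plain dict (a real config object would be richer).
    return {"trainer": args.trainer, "output_dir": args.output_dir}


class ToyTrainer:
    """Hypothetical trainer that only records what it was asked to do."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.events = []

    def train(self):
        # Per the description above, train() both trains and evaluates.
        self.events.append("train")
        self.events.append("evaluate")
        return self.events


def build_trainer(cfg):
    # Stand-in: the real function also loads the dataset and builds the model.
    return ToyTrainer(cfg)


def main(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--trainer", type=str, default="SourceOnly")
    parser.add_argument("--output-dir", type=str, default="output")
    args = parser.parse_args(argv)

    cfg = setup_cfg(args)         # 1. initialize the config
    trainer = build_trainer(cfg)  # 2. instantiate a trainer
    return trainer.train()        # 3. train and evaluate


print(main(["--trainer", "SourceOnly", "--output-dir", "output/demo"]))
```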
Below is an example of training a source-only baseline on Office-31, a popular domain adaptation dataset:
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31
$DATA denotes the path to the dataset folder. --dataset-config-file loads the common settings for the dataset, such as image size and model architecture. --config-file loads the algorithm-specific settings, such as hyper-parameters and optimization parameters.
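One way to picture how two settings files combine is a layered merge. The sketch below uses plain dicts and a hypothetical merge_settings helper; the keys and values are made up for illustration, and the rule that the later source wins on conflicts is an assumption, not something stated above.

```python
def merge_settings(base, override):
    """Recursively merge `override` into a copy of `base`; `override` wins."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_settings(merged[key], value)
        else:
            merged[key] = value
    return merged


# Hypothetical contents of a dataset config (common settings).
dataset_cfg = {"input": {"size": 224}, "model": {"backbone": "resnet50"}}
# Hypothetical contents of an algorithm config (hyper-parameters, optimization).
trainer_cfg = {"optim": {"lr": 0.002, "max_epoch": 20}}

cfg = merge_settings(dataset_cfg, trainer_cfg)
print(cfg)
```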
For multi-source domain adaptation, simply pass more sources to --source-domains. For instance, to train a source-only baseline on miniDomainNet, run:
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains clipart painting real \
--target-domains sketch \
--dataset-config-file configs/datasets/da/mini_domainnet.yaml \
--config-file configs/trainers/da/source_only/mini_domainnet.yaml \
--output-dir output/source_only_minidn
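A flag that accepts several space-separated values, as --source-domains does in the command above, can be declared with argparse's nargs="+" option. The parser below is a minimal sketch under that assumption; it is not copied from tools/train.py.

```python
import argparse


def parse_domains(argv):
    parser = argparse.ArgumentParser()
    # nargs="+" gathers one or more values following the flag into a list.
    parser.add_argument("--source-domains", type=str, nargs="+", default=[])
    parser.add_argument("--target-domains", type=str, nargs="+", default=[])
    return parser.parse_args(argv)


args = parse_domains(
    ["--source-domains", "clipart", "painting", "real", "--target-domains", "sketch"]
)
print(args.source_domains)  # ['clipart', 'painting', 'real']
print(args.target_domains)  # ['sketch']
```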
After training finishes, the model weights are saved under the specified output directory, along with a log file and a TensorBoard file for visualization.
To test a model, pass --eval-only, which tells the script to run trainer.test(). You also need to provide the trained model and specify which model file (i.e., the epoch at which it was saved) to use. For example, to use model.pth.tar-20 saved at output/source_only_office31/model, run:
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31_test \
--eval-only \
--model-dir output/source_only_office31 \
--load-epoch 20
Note that --model-dir takes as input the directory path that was specified in --output-dir in the training stage.
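To make the relationship between the two flags and the file on disk concrete, the helper below rebuilds the checkpoint path from the example. checkpoint_path is a hypothetical function, not part of Dassl, and it assumes the layout <model-dir>/model/model.pth.tar-<epoch> implied by the path quoted in the example.

```python
import posixpath


def checkpoint_path(model_dir, epoch, name="model"):
    # Assumed layout: <model_dir>/<name>/model.pth.tar-<epoch>.
    # posixpath keeps forward slashes regardless of the operating system.
    return posixpath.join(model_dir, name, f"model.pth.tar-{epoch}")


path = checkpoint_path("output/source_only_office31", 20)
print(path)  # output/source_only_office31/model/model.pth.tar-20
```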
A good practice is to go through dassl/engine/trainer.py to get familiar with the base trainer classes, which provide generic functions and training loops. To write a trainer class for domain adaptation or semi-supervised learning, the new class can subclass TrainerXU. For domain generalization, the new class can subclass TrainerX. TrainerXU and TrainerX differ mainly in whether they use a data loader for unlabeled data. With the base classes, a new trainer may only need to implement the forward_backward() method, which performs loss computation and model update. See dassl/engine/da/source_only.py for an example.
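The division of labor described here, where the base class owns the training loop and the subclass supplies only forward_backward(), can be illustrated with a dependency-free toy. ToyTrainerX and ToyLinearTrainer below are hypothetical classes written for this sketch, not Dassl's TrainerX. A real trainer would operate on PyTorch models and optimizers rather than a single scalar weight.

```python
class ToyTrainerX:
    """Hypothetical base class that owns the generic training loop."""

    def __init__(self, batches, max_epoch=50):
        self.batches = batches  # stands in for a labeled data loader
        self.max_epoch = max_epoch

    def train(self):
        loss = None
        for _ in range(self.max_epoch):
            for batch in self.batches:
                loss = self.forward_backward(batch)
        return loss

    def forward_backward(self, batch):
        raise NotImplementedError


class ToyLinearTrainer(ToyTrainerX):
    """Fits y = w * x by gradient descent; only forward_backward() is new."""

    def __init__(self, batches, lr=0.05, **kwargs):
        super().__init__(batches, **kwargs)
        self.w = 0.0
        self.lr = lr

    def forward_backward(self, batch):
        xs, ys = batch
        n = len(xs)
        # Loss computation: mean squared error of the current prediction.
        loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / n
        # Analytic gradient of the loss with respect to w.
        grad = sum(2 * (self.w * x - y) * x for x, y in zip(xs, ys)) / n
        # Model update: one gradient-descent step.
        self.w -= self.lr * grad
        return loss


# Two mini-batches drawn from y = 2x.
batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]
trainer = ToyLinearTrainer(batches, lr=0.05, max_epoch=50)
final_loss = trainer.train()
print(trainer.w, final_loss)
```

The subclass never touches the epoch or batch loops, which is why a new method can stay short.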
Please cite the following paper if you find Dassl useful for your research.
@article{zhou2020domain,
title={Domain Adaptive Ensemble Learning},
author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
journal={arXiv preprint arXiv:2003.07325},
year={2020}
}