Dassl is a PyTorch toolbox initially developed for our project Domain Adaptive Ensemble Learning (DAEL) to support research in domain adaptation and domain generalization, since in DAEL we study how to unify these two problems in a single learning framework. Given that domain adaptation is closely related to semi-supervised learning (both study how to exploit unlabeled data), we also incorporate components that support research on the latter.
Why the name "Dassl"? Dassl combines the initials of domain adaptation (DA) and semi-supervised learning (SSL), which sounds natural and informative.
Dassl has a modular design and unified interfaces, allowing fast prototyping of and experimentation with new DA/DG/SSL methods. With Dassl, a new method can be implemented with only a few lines of code. Don't believe it? Take a look at the `engine` folder, which contains the implementations of many existing methods (then you will come back and star this repo). :-)
Basically, Dassl is perfect for doing research in the following areas:
- Domain adaptation
- Domain generalization
- Semi-supervised learning
BUT, thanks to the neat design, Dassl can also be used as a codebase to develop any deep learning project. :-)
We don't provide detailed documentation for Dassl, unlike another project of ours. This is because Dassl is developed for research purposes and, as researchers, we think it is important to be able to read source code, so we highly encourage you to do so. It is definitely not because we are lazy. :-)
- [Aug 2021] `v0.4.0`: The most noteworthy update is the addition of a learning rate warmup scheduler. The implementation is detailed here and the config variables are specified here.
- [Jul 2021] `v0.3.4`: Adds a new function `generate_fewshot_dataset()` to the base dataset class, which allows for the generation of a few-shot learning setting. One can customize a few-shot dataset by specifying `_C.DATASET.NUM_SHOTS` and passing it to `generate_fewshot_dataset()`.
- [Jul 2021] `v0.3.2`: Adds `_C.INPUT.INTERPOLATION` (default: `bilinear`). Available interpolation modes are `bilinear`, `nearest`, and `bicubic`.
- [Jul 2021] `v0.3.1`: Now you can use `*.register(force=True)` to replace previously registered modules.
- [Jul 2021] `v0.3.0`: Allows deploying the model with the best validation performance for the final test (for the purpose of model selection). Specifically, a new config variable named `_C.TEST.FINAL_MODEL` is introduced, which takes either `"last_step"` (default) or `"best_val"`. When set to `"best_val"`, the model will be evaluated on the `val` set after each epoch, and the one with the best validation performance will be saved and used for the final test (see this code).
- [Jul 2021] `v0.2.7`: Adds the attribute `classnames` to the base dataset class. Now you can get a list of class names ordered by numeric labels by calling `trainer.dm.dataset.classnames`.
- [Jun 2021] `v0.2.6`: Merges `MixStyle2` into `MixStyle`. A new variable `self.mix` is used to switch between random mixing and cross-domain mixing. Please see this for more details on the new features.
- [Jun 2021] `v0.2.5`: Fixes a bug in the calculation of per-class recognition accuracy.
- [Jun 2021] `v0.2.4`: Adds `extend_cfg(cfg)` to `train.py`. This function is particularly useful when you build your own methods on top of Dassl.pytorch and need to define some custom variables. Please see the repository mixstyle-release or ssdg-benchmark for examples.
- [Jun 2021] New benchmarks for semi-supervised domain generalization at https://github.com/KaiyangZhou/ssdg-benchmark.
- [Apr 2021] Did you know that you can use `tools/parse_test_res.py` to read the log files and automatically calculate and print out the results, including the mean and standard deviation? Check the instructions in `tools/parse_test_res.py` for more details.
- [Apr 2021] `v0.2.3`: A MixStyle layer can now be deactivated or activated by using `model.apply(deactivate_mixstyle)` or `model.apply(activate_mixstyle)` without modifying the source code. See `dassl/modeling/ops/mixstyle.py` for the details.
- [Apr 2021] `v0.2.2`: Adds `RandomClassSampler`, which forms a minibatch by sampling a certain number of classes and then a certain number of images from each of them (the code is modified from Torchreid).
- [Apr 2021] `v0.2.1`: Slightly adjusts the ordering in `setup_cfg()` (see `tools/train.py`).
- [Apr 2021] `v0.2.0`: Adds `_C.DATASET.ALL_AS_UNLABELED` (for the SSL setting) to the config variable list. When this variable is set to `True`, all labeled data will also be included in the unlabeled data set.
- [Apr 2021] `v0.1.9`: Adds VLCS to the benchmark datasets (see `dassl/data/datasets/dg/vlcs.py`).
- [Mar 2021] `v0.1.8`: Allows `optim` and `sched` to be `None` in `register_model()`.
- [Mar 2021] `v0.1.7`: Adds MixStyle models to `dassl/modeling/backbone/resnet.py`. The training configs in `configs/trainers/dg/vanilla` can be used to train MixStyle models.
- [Mar 2021] `v0.1.6`: Adds CIFAR-10/100-C to the benchmark datasets for evaluating a model's robustness to image corruptions.
- [Mar 2021] We have just released a survey on domain generalization at https://arxiv.org/abs/2103.02503, which summarizes the ten-year development of this topic, covering its history, related problems, datasets, methodologies, potential directions, and so on.
- [Jan 2021] Our recent work, MixStyle (mixing instance-level feature statistics of samples of different domains for improving domain generalization), has been accepted to ICLR'21. The code is available at https://github.com/KaiyangZhou/mixstyle-release, where the cross-domain image classification part is based on Dassl.pytorch.
- [May 2020] `v0.1.3`: Adds the `Digit-Single` dataset for benchmarking single-source DG methods. The corresponding CNN model is `dassl/modeling/backbone/cnn_digitsingle.py` and the dataset config file is `configs/datasets/dg/digit_single.yaml`. See Volpi et al. NIPS'18 for how to do evaluation.
- [May 2020] `v0.1.2`: 1) Adds EfficientNet models (B0-B7) (credit to https://github.com/lukemelas/EfficientNet-PyTorch). To use EfficientNet, set `MODEL.BACKBONE.NAME` to `efficientnet_b{N}` where `N={0, ..., 7}`. 2) `dassl/modeling/models` is renamed to `dassl/modeling/network` (`build_model()` to `build_network()` and `MODEL_REGISTRY` to `NETWORK_REGISTRY`).
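The few-shot option added in `v0.3.4` essentially keeps a fixed number of labeled samples per class. Below is a minimal pure-Python sketch of that idea, assuming each sample is an `(impath, label)` tuple; the function `sample_few_shot` is hypothetical and is not Dassl's `generate_fewshot_dataset()`, which may, for example, handle classes with too few samples differently.

```python
import random
from collections import defaultdict


def sample_few_shot(items, num_shots, seed=0):
    """Keep up to ``num_shots`` randomly chosen items per class.

    ``items`` is a list of (impath, label) tuples. Classes with fewer than
    ``num_shots`` items are kept whole (no oversampling in this sketch).
    """
    rng = random.Random(seed)  # local RNG so the subset is reproducible
    by_label = defaultdict(list)
    for item in items:
        by_label[item[1]].append(item)

    few_shot = []
    for label in sorted(by_label):
        group = by_label[label]
        k = min(num_shots, len(group))
        few_shot.extend(rng.sample(group, k))  # sample without replacement
    return few_shot


# Usage: a toy dataset with 3 classes and 10 images each
data = [(f"img_{c}_{i}.jpg", c) for c in range(3) for i in range(10)]
subset = sample_few_shot(data, num_shots=4)
print(len(subset))  # 12
```

Using a dedicated `random.Random` instance keeps the selection reproducible without touching the global random state used elsewhere in training.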
Dassl has implemented the following methods:
- Single-source domain adaptation
  - Semi-supervised Domain Adaptation via Minimax Entropy (ICCV'19) [dassl/engine/da/mme.py]
  - Maximum Classifier Discrepancy for Unsupervised Domain Adaptation (CVPR'18) [dassl/engine/da/mcd.py]
  - Self-ensembling for visual domain adaptation (ICLR'18) [dassl/engine/da/self_ensembling.py]
  - Revisiting Batch Normalization For Practical Domain Adaptation (ICLR-W'17) [dassl/engine/da/adabn.py]
  - Adversarial Discriminative Domain Adaptation (CVPR'17) [dassl/engine/da/adda.py]
  - Domain-Adversarial Training of Neural Networks (JMLR'16) [dassl/engine/da/dann.py]
- Multi-source domain adaptation
- Domain generalization
- Semi-supervised learning
  - FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence [dassl/engine/ssl/fixmatch.py]
  - MixMatch: A Holistic Approach to Semi-Supervised Learning (NeurIPS'19) [dassl/engine/ssl/mixmatch.py]
  - Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results (NeurIPS'17) [dassl/engine/ssl/mean_teacher.py]
  - Semi-supervised Learning by Entropy Minimization (NeurIPS'04) [dassl/engine/ssl/entmin.py]
Feel free to make a PR to add your methods here to make it easier for others to benchmark!
Dassl supports the following datasets:
- Domain adaptation
- Domain generalization
- Semi-supervised learning
Make sure conda is installed properly.

```bash
# Clone this repo
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
cd Dassl.pytorch/

# Create a conda environment
conda create -n dassl python=3.7

# Activate the environment
conda activate dassl

# Install dependencies
pip install -r requirements.txt

# Install torch (version >= 1.7.1) and torchvision
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

# Install this library (no need to re-build if the source code is modified)
python setup.py develop
```
Follow the instructions in DATASETS.md to preprocess the datasets.
The main interface is implemented in `tools/train.py`, which basically does the following:
- initialize the config with `cfg = setup_cfg(args)`, where `args` contains the command-line input (see `tools/train.py` for the list of input arguments);
- instantiate a `trainer` with `build_trainer(cfg)`, which loads the dataset and builds a deep neural network model;
- call `trainer.train()` for training and evaluating the model.
Below we provide an example for training a source-only baseline on the popular domain adaptation dataset, Office-31:

```bash
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31
```
`$DATA` denotes the location where datasets are installed. `--dataset-config-file` loads the common settings for the dataset (Office-31 in this case), such as image size and model architecture. `--config-file` loads the algorithm-specific settings, such as hyper-parameters and optimization parameters.
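To make the layering concrete, here is a pure-Python sketch that models each config source as a nested dict and merges them in order, so later sources override earlier ones. The order shown (defaults, then the dataset config, then the method config) and all key values are assumptions for illustration; Dassl itself builds its config with yacs `CfgNode` objects inside `setup_cfg()` in `tools/train.py`, and `merge_cfg` is a hypothetical helper.

```python
import copy


def merge_cfg(base, override):
    """Recursively merge ``override`` into a copy of ``base``.

    Keys in ``override`` must already exist in ``base``, mimicking the
    strict behavior of yacs-style configs that reject unknown keys.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if key not in merged:
            raise KeyError(f"Non-existent config key: {key}")
        if isinstance(value, dict) and isinstance(merged[key], dict):
            merged[key] = merge_cfg(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical values, for illustration only
defaults = {
    "INPUT": {"SIZE": (224, 224)},
    "MODEL": {"BACKBONE": {"NAME": ""}},
    "OPTIM": {"LR": 0.0003, "MAX_EPOCH": 10},
}
dataset_cfg = {"INPUT": {"SIZE": (224, 224)}, "MODEL": {"BACKBONE": {"NAME": "resnet50"}}}
method_cfg = {"OPTIM": {"LR": 0.002, "MAX_EPOCH": 20}}

cfg = merge_cfg(merge_cfg(defaults, dataset_cfg), method_cfg)
print(cfg["MODEL"]["BACKBONE"]["NAME"], cfg["OPTIM"]["LR"])  # resnet50 0.002
```

The strict key check is also why custom variables for a new method need to be declared up front (which is what `extend_cfg(cfg)` is for) before any file that sets them is merged.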
To use multiple sources, namely the multi-source domain adaptation task, one just needs to add more sources to `--source-domains`. For instance, to train a source-only baseline on miniDomainNet, one can do:

```bash
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains clipart painting real \
--target-domains sketch \
--dataset-config-file configs/datasets/da/mini_domainnet.yaml \
--config-file configs/trainers/da/source_only/mini_domainnet.yaml \
--output-dir output/source_only_minidn
```
After the training finishes, the model weights will be saved under the specified output directory, along with a log file and a tensorboard file for visualization.
To print out the results saved in the log file (so you do not need to exhaustively go through all log files and calculate the mean/std by yourself), you can use `tools/parse_test_res.py`. The instructions can be found in the code.
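The kind of aggregation that `tools/parse_test_res.py` automates can be sketched as follows: take the last reported accuracy from each run's log and report the mean and standard deviation across runs. The log line format `* accuracy: 85.3%`, the use of the population standard deviation, and the helper names are all assumptions; check `tools/parse_test_res.py` for the patterns and statistics it actually uses.

```python
import re
import statistics

# Assumed log line format, e.g. "* accuracy: 85.3%"
ACC_PATTERN = re.compile(r"\* accuracy: ([\d.]+)%")


def last_accuracy(log_text):
    """Return the last accuracy value found in one log, or None."""
    matches = ACC_PATTERN.findall(log_text)
    return float(matches[-1]) if matches else None


def summarize(logs):
    """Mean and population standard deviation over runs that report accuracy."""
    accs = [a for a in (last_accuracy(t) for t in logs) if a is not None]
    return statistics.mean(accs), statistics.pstdev(accs)


logs = [
    "epoch 20 done\n* accuracy: 80.0%\n",
    "* accuracy: 70.0%\n* accuracy: 82.0%\n",  # only the last value counts
    "* accuracy: 84.0%\n",
]
mean, std = summarize(logs)
print(f"{mean:.2f} +- {std:.2f}")  # 82.00 +- 1.63
```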
For other trainers such as `MCD`, you can set `--trainer MCD` while keeping the config file unchanged, i.e. using the same training parameters as `SourceOnly` (in the simplest case). To modify the hyper-parameters in MCD, like `N_STEP_F` (number of steps to update the feature extractor), you can append `TRAINER.MCD.N_STEP_F 4` to the existing input arguments (otherwise the default value will be used). Alternatively, you can create a new `.yaml` config file to store your custom settings. See here for a complete list of algorithm-specific hyper-parameters.
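Appending `TRAINER.MCD.N_STEP_F 4` works because the trailing arguments are read as alternating dotted keys and values. The sketch below imitates that idea in plain Python on a nested dict, converting values with `ast.literal_eval`. It is a simplified stand-in for yacs's `CfgNode.merge_from_list()` (it omits details such as type checking), not its actual code, and `merge_from_list` here as well as the default values shown are hypothetical.

```python
import ast


def merge_from_list(cfg, opts):
    """Apply ["A.B.C", "value", ...] overrides to a nested dict in place."""
    if len(opts) % 2 != 0:
        raise ValueError("Overrides must be given as KEY VALUE pairs")
    for full_key, raw in zip(opts[0::2], opts[1::2]):
        *parents, leaf = full_key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # raises KeyError for an unknown section
        if leaf not in node:
            raise KeyError(f"Non-existent config key: {full_key}")
        try:
            value = ast.literal_eval(raw)  # "4" -> 4, "0.1" -> 0.1
        except (ValueError, SyntaxError):
            value = raw  # plain strings such as "adam" stay strings
        node[leaf] = value
    return cfg


cfg = {"TRAINER": {"MCD": {"N_STEP_F": 1}}, "OPTIM": {"NAME": "sgd"}}  # hypothetical defaults
merge_from_list(cfg, ["TRAINER.MCD.N_STEP_F", "4", "OPTIM.NAME", "adam"])
print(cfg)  # {'TRAINER': {'MCD': {'N_STEP_F': 4}}, 'OPTIM': {'NAME': 'adam'}}
```

Rejecting unknown keys means a misspelled override fails loudly instead of being silently ignored.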
Model testing can be done by using `--eval-only`, which asks the code to run `trainer.test()`. You also need to provide the trained model and specify which model file (i.e. saved at which epoch) to use. For example, to use `model.pth.tar-20` saved at `output/source_only_office31/model`, you can do:

```bash
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31_test \
--eval-only \
--model-dir output/source_only_office31 \
--load-epoch 20
```
Note that `--model-dir` takes as input the directory path that was specified in `--output-dir` in the training stage.
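If you are unsure which values `--load-epoch` accepts, you can list the checkpoints saved during training. The sketch below assumes checkpoints follow the `model.pth.tar-<epoch>` naming under `<output-dir>/model` used in the example above; the helper `available_epochs` is hypothetical and not part of Dassl.

```python
import os
import re
import tempfile

CKPT_PATTERN = re.compile(r"^model\.pth\.tar-(\d+)$")


def available_epochs(model_dir, subdir="model"):
    """Return the sorted epochs of checkpoints saved as model.pth.tar-<epoch>."""
    ckpt_dir = os.path.join(model_dir, subdir)
    if not os.path.isdir(ckpt_dir):
        return []
    epochs = []
    for name in os.listdir(ckpt_dir):
        match = CKPT_PATTERN.match(name)
        if match:
            epochs.append(int(match.group(1)))
    return sorted(epochs)


# Usage with a throwaway directory that mimics output/source_only_office31
with tempfile.TemporaryDirectory() as out_dir:
    os.makedirs(os.path.join(out_dir, "model"))
    for name in ["model.pth.tar-10", "model.pth.tar-20", "checkpoint"]:
        open(os.path.join(out_dir, "model", name), "w").close()
    print(available_epochs(out_dir))  # [10, 20]
```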
A good practice is to go through `dassl/engine/trainer.py` to get familiar with the base trainer classes, which provide generic functions and training loops. To write a trainer class for domain adaptation or semi-supervised learning, the new class can subclass `TrainerXU`. For domain generalization, the new class can subclass `TrainerX`. In particular, `TrainerXU` and `TrainerX` mainly differ in whether a data loader for unlabeled data is used. With the base classes, a new trainer may only need to implement the `forward_backward()` method, which performs loss computation and model update. See `dassl/engine/da/source_only.py` for an example.
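The division of labor between the base classes and `forward_backward()` follows the template-method pattern: the base class owns the loop, and a subclass only defines what happens on one batch. The toy classes below (`ToyTrainerX`, `ToyTrainerXU`, `MeanLossTrainer`) are hypothetical, framework-free stand-ins written only to show that structure; they are not Dassl's `TrainerX`/`TrainerXU` and they skip models, optimizers, and evaluation.

```python
class ToyTrainerX:
    """Owns the epoch/batch loop over a labeled loader (toy stand-in)."""

    def __init__(self, train_loader_x, max_epoch=1):
        self.train_loader_x = train_loader_x
        self.max_epoch = max_epoch
        self.history = []

    def train(self):
        for epoch in range(self.max_epoch):
            self.epoch = epoch
            self.run_epoch()

    def run_epoch(self):
        for batch_x in self.train_loader_x:
            self.history.append(self.forward_backward(batch_x))

    def forward_backward(self, batch_x):
        raise NotImplementedError  # subclasses compute the loss and update here


class ToyTrainerXU(ToyTrainerX):
    """Same loop, but each step also receives an unlabeled batch."""

    def __init__(self, train_loader_x, train_loader_u, max_epoch=1):
        super().__init__(train_loader_x, max_epoch)
        self.train_loader_u = train_loader_u

    def run_epoch(self):
        # zip() stops at the shorter loader; real trainers may handle this differently
        for batch_x, batch_u in zip(self.train_loader_x, self.train_loader_u):
            self.history.append(self.forward_backward(batch_x, batch_u))

    def forward_backward(self, batch_x, batch_u):
        raise NotImplementedError


class MeanLossTrainer(ToyTrainerXU):
    """A 'method' only implements forward_backward(); the loss here is a dummy value."""

    def forward_backward(self, batch_x, batch_u):
        loss = sum(batch_x) / len(batch_x)  # placeholder for a real loss on labeled data
        return {"loss": loss, "num_unlabeled": len(batch_u)}


trainer = MeanLossTrainer([[1, 3], [2, 4]], [[0, 0, 0], [0]], max_epoch=2)
trainer.train()
print(trainer.history[0])  # {'loss': 2.0, 'num_unlabeled': 3}
```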
`backbone` corresponds to a convolutional neural network model that performs feature extraction. `head` (an optional module) is mounted on top of `backbone` for further processing and can be, for example, an MLP. `backbone` and `head` are the basic building blocks for constructing a `SimpleNet()` (see `dassl/engine/trainer.py`), which serves as the primary model for a task. `network` contains custom neural network models, such as an image generator.
To add a new module, namely a backbone/head/network, you need to first register the module using the corresponding registry, i.e. `BACKBONE_REGISTRY` for `backbone`, `HEAD_REGISTRY` for `head`, and `NETWORK_REGISTRY` for `network`. Note that for a new `backbone`, we require the model to subclass `Backbone` as defined in `dassl/modeling/backbone/backbone.py` and to specify the `self._out_features` attribute.
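Conceptually, a registry is a name-to-object map filled by a decorator, which is what lets a config string such as `MODEL.BACKBONE.NAME` select a builder function. Here is a minimal pure-Python sketch, including a `force` flag in the spirit of `*.register(force=True)` from `v0.3.1`. The `Registry` class and `TOY_BACKBONE_REGISTRY` below are illustrations, not Dassl's implementation, and their error messages are made up.

```python
class Registry:
    """Map names to objects; register() works as a decorator."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, force):
        if name in self._obj_map and not force:
            raise KeyError(f"'{name}' is already registered in {self._name}")
        self._obj_map[name] = obj

    def register(self, obj=None, force=False):
        if obj is None:
            # Used as @registry.register() or @registry.register(force=True)
            def decorator(func_or_class):
                self._do_register(func_or_class.__name__, func_or_class, force)
                return func_or_class
            return decorator
        # Used as registry.register(obj)
        self._do_register(obj.__name__, obj, force)
        return obj

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(f"'{name}' not found in {self._name}")
        return self._obj_map[name]


TOY_BACKBONE_REGISTRY = Registry("TOY_BACKBONE")


@TOY_BACKBONE_REGISTRY.register()
def toy_backbone(**kwargs):
    return "v1"


@TOY_BACKBONE_REGISTRY.register(force=True)  # replaces the earlier entry
def toy_backbone(**kwargs):
    return "v2"


print(TOY_BACKBONE_REGISTRY.get("toy_backbone")())  # v2
```

Registering by `__name__` is why the function name (here `toy_backbone`) is the string you would put in the config.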
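Below is an example of how to add a new `backbone`:

```python
import torch.nn as nn

from dassl.modeling import Backbone, BACKBONE_REGISTRY


class MyBackbone(Backbone):

    def __init__(self):
        super().__init__()
        # Create layers (a toy conv stem; replace with your own architecture)
        self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.AdaptiveAvgPool2d(1)
        # Dimension of the feature vector returned by forward()
        self._out_features = 64

    def forward(self, x):
        # Extract and return features of shape (batch_size, 64)
        x = self.relu(self.conv(x))
        return self.pool(x).flatten(1)


@BACKBONE_REGISTRY.register()
def my_backbone(**kwargs):
    return MyBackbone()
```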
Then, you can set `MODEL.BACKBONE.NAME` to `my_backbone` to use your own architecture. For more details, please refer to the source code in `dassl/modeling`.
We would like to share here our research relevant to Dassl.
- MixStyle Neural Networks for Domain Generalization and Adaptation, arXiv preprint, 2021.
- Semi-Supervised Domain Generalization with Stochastic StyleMatch, arXiv preprint, 2021.
- Domain Generalization in Vision: A Survey, arXiv preprint, 2021.
- Domain Generalization with MixStyle, in ICLR 2021.
- Learning to Generate Novel Domains for Domain Generalization, in ECCV 2020.
- Deep Domain-Adversarial Image Generation for Domain Generalisation, in AAAI 2020.
- Domain Adaptive Ensemble Learning, arXiv preprint, 2020.
If you find this code useful in your research, please give credit to the following paper:

```bibtex
@article{zhou2020domain,
  title={Domain Adaptive Ensemble Learning},
  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
  journal={arXiv preprint arXiv:2003.07325},
  year={2020}
}
```