
[ICLR2022] CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation

Primary language: Python. License: MIT.

CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation

Introduction

This is the official code of CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation.

[Figure: CDTrans framework]

Results

Table 1 [UDA results on Office-31]

Methods          | Avg. | A->D | A->W | D->A | D->W | W->A | W->D
Baseline(DeiT-S) | 86.7 | 87.6 | 86.9 | 74.9 | 97.7 | 73.5 | 99.6
CDTrans(DeiT-S)  | 90.4 | 94.6 | 93.5 | 78.4 | 98.2 | 78   | 99.6
Baseline(DeiT-B) | 88.8 | 90.8 | 90.4 | 76.8 | 98.2 | 76.4 | 100
CDTrans(DeiT-B)  | 92.6 | 97   | 96.7 | 81.1 | 99   | 81.9 | 100

Table 2 [UDA results on Office-Home]

Methods          | Avg. | Ar->Cl | Ar->Pr | Ar->Re | Cl->Ar | Cl->Pr | Cl->Re | Pr->Ar | Pr->Cl | Pr->Re | Re->Ar | Re->Cl | Re->Pr
Baseline(DeiT-S) | 69.8 | 55.6   | 73     | 79.4   | 70.6   | 72.9   | 76.3   | 67.5   | 51     | 81     | 74.5   | 53.2   | 82.7
CDTrans(DeiT-S)  | 74.7 | 60.6   | 79.5   | 82.4   | 75.6   | 81.0   | 82.3   | 72.5   | 56.7   | 84.4   | 77.0   | 59.1   | 85.5
Baseline(DeiT-B) | 74.8 | 61.8   | 79.5   | 84.3   | 75.4   | 78.8   | 81.2   | 72.8   | 55.7   | 84.4   | 78.3   | 59.3   | 86
CDTrans(DeiT-B)  | 80.5 | 68.8   | 85     | 86.9   | 81.5   | 87.1   | 87.3   | 79.6   | 63.3   | 88.2   | 82     | 66     | 90.6

Table 3 [UDA results on VisDA-2017]

Methods          | Per-class avg. | plane | bcycl | bus   | car   | horse | knife | mcycl | person | plant | sktbrd | train | truck
Baseline(DeiT-B) | 67.3           | 98.1  | 48.1  | 84.6  | 65.2  | 76.3  | 59.4  | 94.5  | 11.8   | 89.5  | 52.2   | 94.5  | 34.1
CDTrans(DeiT-B)  | 88.4           | 97.7  | 86.39 | 86.87 | 83.33 | 97.76 | 97.16 | 95.93 | 84.08  | 97.93 | 83.47  | 94.59 | 55.3
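The Per-class column above is the unweighted mean of the twelve class accuracies (averaging the CDTrans(DeiT-B) row gives 88.38, i.e. 88.4), so each class counts equally regardless of how many validation images it has. The sketch below illustrates that metric from true and predicted labels; it is not the evaluation code in test.py, and the function names are hypothetical.

```python
from typing import List, Sequence


def per_class_accuracy(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> List[float]:
    """Accuracy (%) of each class over the samples whose true label is that class."""
    correct = [0] * num_classes
    total = [0] * num_classes
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        if t == p:
            correct[t] += 1
    # Classes without samples get 0.0 here; this is a choice of the sketch, not of CDTrans.
    return [100.0 * c / n if n else 0.0 for c, n in zip(correct, total)]


def mean_per_class_accuracy(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> float:
    """Unweighted mean of the per-class accuracies (each class counts equally)."""
    return sum(per_class_accuracy(y_true, y_pred, num_classes)) / num_classes


# Toy data: class 0 has four samples (three correct), classes 1 and 2 have one sample each.
y_true = [0, 0, 0, 0, 1, 2]
y_pred = [0, 0, 0, 1, 1, 0]
print(per_class_accuracy(y_true, y_pred, 3))       # [75.0, 100.0, 0.0]
print(mean_per_class_accuracy(y_true, y_pred, 3))  # about 58.33, while plain accuracy is 4/6

# Averaging the twelve CDTrans(DeiT-B) class accuracies from Table 3.
cdtrans_visda = [97.7, 86.39, 86.87, 83.33, 97.76, 97.16, 95.93, 84.08, 97.93, 83.47, 94.59, 55.3]
print(round(sum(cdtrans_visda) / len(cdtrans_visda), 1))  # 88.4
```

The toy example shows why the distinction matters: a frequent class that is classified well can hide a rare class that is always wrong if plain accuracy is reported instead.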

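Table 4 [UDA results on DomainNet]

Base-S/Base-B are the DeiT-Small/DeiT-Base baselines and CDTrans-S/CDTrans-B the corresponding CDTrans models. In each block, rows are source domains and columns are target domains; Avg. is the mean over the five tasks in that row or column, and the bottom-right value is the mean over all 30 tasks.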

Base-S | clp  | info | pnt  | qdr  | rel  | skt  | Avg.
clp    | -    | 21.2 | 44.2 | 15.3 | 59.9 | 46.0 | 37.3
info   | 36.8 | -    | 39.4 | 5.4  | 52.1 | 32.6 | 33.3
pnt    | 47.1 | 21.7 | -    | 5.7  | 60.2 | 39.9 | 34.9
qdr    | 25.0 | 3.3  | 10.4 | -    | 18.8 | 14.0 | 14.3
rel    | 54.8 | 23.9 | 52.6 | 7.4  | -    | 40.1 | 35.8
skt    | 55.6 | 18.6 | 42.7 | 14.9 | 55.7 | -    | 37.5
Avg.   | 43.9 | 17.7 | 37.9 | 9.7  | 49.3 | 34.5 | 32.2

CDTrans-S | clp   | info  | pnt   | qdr  | rel  | skt   | Avg.
clp       | -     | 25.3  | 52.5  | 23.2 | 68.3 | 53.2  | 44.5
info      | 47.6  | -     | 48.3  | 9.9  | 62.8 | 41.1  | 41.9
pnt       | 55.4  | 24.5  | -     | 11.7 | 67.4 | 48.0  | 41.4
qdr       | 36.6  | 5.3   | 19.3  | -    | 33.8 | 22.7  | 23.5
rel       | 61.5  | 28.1  | 56.8  | 12.8 | -    | 47.2  | 41.3
skt       | 64.3  | 26.1  | 53.2  | 23.9 | 66.2 | -     | 46.7
Avg.      | 53.08 | 21.86 | 46.02 | 16.3 | 59.7 | 42.44 | 39.9

Base-B | clp  | info | pnt  | qdr  | rel  | skt  | Avg.
clp    | -    | 24.2 | 48.9 | 15.5 | 63.9 | 50.7 | 40.6
info   | 43.5 | -    | 44.9 | 6.5  | 58.8 | 37.6 | 38.3
pnt    | 52.8 | 23.3 | -    | 6.6  | 64.6 | 44.5 | 38.4
qdr    | 31.8 | 6.1  | 15.6 | -    | 23.4 | 18.9 | 19.2
rel    | 58.9 | 26.3 | 56.7 | 9.1  | -    | 45.0 | 39.2
skt    | 60.0 | 21.1 | 48.4 | 16.6 | 61.7 | -    | 41.6
Avg.   | 49.4 | 20.2 | 42.9 | 10.9 | 54.5 | 39.3 | 36.2

CDTrans-B | clp  | info | pnt  | qdr  | rel  | skt  | Avg.
clp       | -    | 29.4 | 57.2 | 26.0 | 72.6 | 58.1 | 48.7
info      | 57.0 | -    | 54.4 | 12.8 | 69.5 | 48.4 | 48.4
pnt       | 62.9 | 27.4 | -    | 15.8 | 72.1 | 53.9 | 46.4
qdr       | 44.6 | 8.9  | 29.0 | -    | 42.6 | 28.5 | 30.7
rel       | 66.2 | 31.0 | 61.5 | 16.2 | -    | 52.9 | 45.6
skt       | 69.0 | 29.6 | 59.0 | 27.2 | 72.5 | -    | 51.5
Avg.      | 59.9 | 25.3 | 52.2 | 19.6 | 65.9 | 48.4 | 45.2
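Each Table 4 block is a 6x6 matrix over the six DomainNet domains with an empty diagonal, since a domain is never adapted to itself. The Avg. column and Avg. row are means over the five off-diagonal entries, and the corner value is the mean over all 30 entries. The following sketch (hypothetical helper names, with None standing in for the "-" cells) recomputes those averages from the Base-S block as a consistency check.

```python
from typing import List, Optional

DOMAINS = ["clp", "info", "pnt", "qdr", "rel", "skt"]
Matrix = List[List[Optional[float]]]


def row_averages(m: Matrix) -> List[float]:
    """Mean of each row, skipping the None entries on the diagonal."""
    return [sum(v for v in row if v is not None) / sum(v is not None for v in row) for row in m]


def column_averages(m: Matrix) -> List[float]:
    """Mean of each column, computed by transposing and reusing row_averages."""
    return row_averages([list(col) for col in zip(*m)])


def overall_average(m: Matrix) -> float:
    """Mean over every non-diagonal entry (30 tasks for 6 domains)."""
    values = [v for row in m for v in row if v is not None]
    return sum(values) / len(values)


# Base-S block from Table 4; None marks the "-" diagonal cells.
base_s: Matrix = [
    [None, 21.2, 44.2, 15.3, 59.9, 46.0],
    [36.8, None, 39.4, 5.4, 52.1, 32.6],
    [47.1, 21.7, None, 5.7, 60.2, 39.9],
    [25.0, 3.3, 10.4, None, 18.8, 14.0],
    [54.8, 23.9, 52.6, 7.4, None, 40.1],
    [55.6, 18.6, 42.7, 14.9, 55.7, None],
]

if __name__ == "__main__":
    for name, avg in zip(DOMAINS, row_averages(base_s)):
        print(f"{name}: {avg:.1f}")
    print("column avg:", [round(a, 1) for a in column_averages(base_s)])
    print("overall:", round(overall_average(base_s), 1))
```

Rounded to one decimal, these reproduce the Base-S Avg. column (37.3 to 37.5), the Avg. row (43.9 to 34.5) and the overall 32.2.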

Requirements

Installation

pip install -r requirements.txt
(Environment used: Python 3.7, an NVIDIA V100 GPU, CUDA 10.1 and cudatoolkit 10.1.)
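To compare a local setup with the environment listed above, a quick check such as the following can help. It is a hypothetical helper, not part of the repository; PyTorch is imported only if it is installed, and the reference versions are taken from the note above.

```python
import sys
from typing import Dict, List, Optional

# Reference environment from the note above.
REFERENCE_PYTHON = (3, 7)
REFERENCE_CUDA = "10.1"


def describe_environment() -> Dict[str, Optional[str]]:
    """Collect the Python version and, if PyTorch is installed, its CUDA build and first GPU name."""
    info: Dict[str, Optional[str]] = {
        "python": "{}.{}.{}".format(*sys.version_info[:3]),
        "torch": None,
        "torch_cuda": None,
        "gpu": None,
    }
    try:
        import torch  # optional for this check
    except ImportError:
        return info
    info["torch"] = torch.__version__
    info["torch_cuda"] = torch.version.cuda  # None for CPU-only builds
    if torch.cuda.is_available():
        info["gpu"] = torch.cuda.get_device_name(0)
    return info


def environment_warnings(info: Dict[str, Optional[str]]) -> List[str]:
    """List differences from the reference environment (an empty list means no mismatch found)."""
    messages = []
    major_minor = tuple(int(x) for x in str(info["python"]).split(".")[:2])
    if major_minor != REFERENCE_PYTHON:
        messages.append(f"Python {info['python']} differs from the reference 3.7")
    if info["torch_cuda"] is not None and info["torch_cuda"] != REFERENCE_CUDA:
        messages.append(f"PyTorch was built with CUDA {info['torch_cuda']}, not {REFERENCE_CUDA}")
    return messages


if __name__ == "__main__":
    env = describe_environment()
    for key, value in env.items():
        print(f"{key}: {value}")
    for message in environment_warnings(env):
        print("warning:", message)
```

A mismatch is not necessarily an error; it only flags where results might differ from the reported setup.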

Prepare Datasets

Download the UDA datasets: Office-31, Office-Home, VisDA-2017 and DomainNet.

Then unzip them and arrange them under the data directory as shown below. Note that each dataset folder must contain the .txt list files giving the path and label of every image; these files are already provided in data/the_dataset of this project. Also rename the Office-Home domain directory 'Real World' to 'Real_World' so the dataset loader can find it; otherwise you may encounter "FileNotFoundError: [Errno 2] No such file or directory".

data
├── OfficeHomeDataset
│   ├── class_name
│   │   └── images
│   └── *.txt
├── domainnet
│   ├── class_name
│   │   └── images
│   └── *.txt
├── office31
│   ├── class_name
│   │   └── images
│   └── *.txt
└── visda
    ├── train
    │   ├── class_name
    │   │   └── images
    │   └── *.txt
    └── validation
        ├── class_name
        │   └── images
        └── *.txt
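Before training, it can save time to confirm that every path in a dataset's .txt list files resolves on disk, since a missing file otherwise only surfaces as the FileNotFoundError mentioned above. This sketch assumes each line holds an image path followed by an integer label, separated by whitespace, with the path relative to a root directory you choose (inspect a few lines of the shipped list files to pick the right root). It also includes a helper for the Office-Home 'Real World' rename. All function names are hypothetical.

```python
import os
import sys
from typing import List, Tuple


def read_image_list(list_file: str) -> List[Tuple[str, int]]:
    """Parse '<image path> <label>' lines; blank lines are skipped."""
    entries = []
    with open(list_file, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            # Split on the last whitespace only, so a path containing spaces stays intact.
            parts = line.rsplit(maxsplit=1)
            if len(parts) != 2:
                raise ValueError(f"{list_file}:{line_no}: expected '<path> <label>', got {line!r}")
            path, label = parts
            entries.append((path, int(label)))
    return entries


def find_missing_images(list_file: str, root: str) -> List[str]:
    """Return the listed paths that are not files under root (absolute paths are used as-is)."""
    return [p for p, _ in read_image_list(list_file) if not os.path.isfile(os.path.join(root, p))]


def rename_real_world(officehome_dir: str) -> bool:
    """Rename 'Real World' to 'Real_World' inside officehome_dir; return True if a rename happened."""
    old = os.path.join(officehome_dir, "Real World")
    new = os.path.join(officehome_dir, "Real_World")
    if os.path.isdir(old) and not os.path.exists(new):
        os.rename(old, new)
        return True
    return False


if __name__ == "__main__" and len(sys.argv) == 3:
    # Usage: python check_lists.py LIST_FILE ROOT_DIR
    missing = find_missing_images(sys.argv[1], sys.argv[2])
    print(f"{len(missing)} missing file(s)")
    for p in missing[:20]:
        print("  ", p)
```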

Prepare DeiT-Pretrained Models

For a fair comparison with respect to the pre-training dataset, we initialize our ViT-based model with DeiT parameters. Download the ImageNet-pretrained transformer models DeiT-Small and DeiT-Base and move them to the ./data/pretrainModel directory.
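A small pre-flight check can confirm that both checkpoints were placed in ./data/pretrainModel. The exact file names the configs expect are not stated here, so this sketch only assumes each checkpoint is a .pth file whose name contains deit_small or deit_base, which matches the naming of the official DeiT release files; compare against the paths in the configs and adjust PATTERNS if needed. The helper is hypothetical.

```python
import glob
import os
from typing import Dict, List

PRETRAIN_DIR = "./data/pretrainModel"
# Assumed substrings that identify each backbone's checkpoint file name.
PATTERNS = {"deit_small": "deit_small", "deit_base": "deit_base"}


def find_checkpoints(pretrain_dir: str = PRETRAIN_DIR) -> Dict[str, List[str]]:
    """Map each backbone to the .pth files in pretrain_dir whose names contain its pattern."""
    files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(pretrain_dir, "*.pth")))
    return {backbone: [f for f in files if pattern in f] for backbone, pattern in PATTERNS.items()}


if __name__ == "__main__":
    for backbone, matches in find_checkpoints().items():
        print(f"{backbone}: {', '.join(matches) if matches else 'MISSING'}")
```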

Training

We use 1 GPU for pre-training and 2 GPUs for UDA, each with 16 GB of memory.

Scripts

Command pattern

bash scripts/[pretrain/uda]/[office31/officehome/visda/domainnet]/run_*.sh [deit_base/deit_small]
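The pattern has three slots: the stage (pretrain or uda), the dataset folder, and the backbone passed as the script's first argument; the run_*.sh name differs per dataset and source domain. A hypothetical helper that validates these slots before handing the command to bash might look like this. The allowed values are copied from the pattern; the script name is passed through unchanged.

```python
import shlex
from typing import List

STAGES = ("pretrain", "uda")
DATASETS = ("office31", "officehome", "visda", "domainnet")
BACKBONES = ("deit_base", "deit_small")


def build_command(stage: str, dataset: str, script: str, backbone: str) -> List[str]:
    """Return argv for: bash scripts/<stage>/<dataset>/<script> <backbone>."""
    if stage not in STAGES:
        raise ValueError(f"stage must be one of {STAGES}, got {stage!r}")
    if dataset not in DATASETS:
        raise ValueError(f"dataset must be one of {DATASETS}, got {dataset!r}")
    if backbone not in BACKBONES:
        raise ValueError(f"backbone must be one of {BACKBONES}, got {backbone!r}")
    if not (script.startswith("run_") and script.endswith(".sh")):
        raise ValueError(f"expected a run_*.sh script name, got {script!r}")
    # Forward slashes on purpose: the path is interpreted by bash, not by the host OS.
    return ["bash", "/".join(["scripts", stage, dataset, script]), backbone]


if __name__ == "__main__":
    argv = build_command("uda", "office31", "run_office_amazon.sh", "deit_base")
    print(" ".join(shlex.quote(a) for a in argv))
```

The resulting list can be passed directly to subprocess.run, which avoids shell-quoting issues.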

For example:

DeiT-Base scripts

# Office-31     Source: Amazon   ->  Target: Dslr, Webcam
bash scripts/pretrain/office31/run_office_amazon.sh deit_base
bash scripts/uda/office31/run_office_amazon.sh deit_base

# Office-Home   Source: Art      ->  Target: Clipart, Product, Real_World
bash scripts/pretrain/officehome/run_officehome_Ar.sh deit_base
bash scripts/uda/officehome/run_officehome_Ar.sh deit_base

# VisDA-2017    Source: train    ->  Target: validation
bash scripts/pretrain/visda/run_visda.sh deit_base
bash scripts/uda/visda/run_visda.sh deit_base

# DomainNet     Source: Clipart  ->  Target: Painting, Quickdraw, Real, Sketch, Infograph
bash scripts/pretrain/domainnet/run_domainnet_clp.sh deit_base
bash scripts/uda/domainnet/run_domainnet_clp.sh deit_base

DeiT-Small scripts

Replace deit_base with deit_small to reproduce the DeiT-Small results. For example, on Office-31:

# Office-31     Source: Amazon   ->  Target: Dslr, Webcam
bash scripts/pretrain/office31/run_office_amazon.sh deit_small
bash scripts/uda/office31/run_office_amazon.sh deit_small
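Each example runs the pretrain script before the UDA script for the same source domain and backbone. A sketch that keeps that order for both backbones and stops at the first failing stage could look like this; it assumes, as the examples suggest, that UDA should start only after pretraining has finished. The function name is hypothetical, and dry_run (the default) only prints the commands.

```python
import subprocess
from typing import List, Sequence


def run_stages(dataset: str, script: str,
               backbones: Sequence[str] = ("deit_base", "deit_small"),
               dry_run: bool = True) -> List[List[str]]:
    """Run pretrain then UDA for each backbone; return the commands in execution order."""
    commands = []
    for backbone in backbones:
        for stage in ("pretrain", "uda"):
            cmd = ["bash", f"scripts/{stage}/{dataset}/{script}", backbone]
            commands.append(cmd)
            if dry_run:
                print(" ".join(cmd))
            else:
                # check=True raises CalledProcessError, so UDA never starts after a failed pretrain.
                subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    run_stages("office31", "run_office_amazon.sh")
```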

Evaluation

# For example, VisDA-2017
python test.py --config_file 'configs/uda.yml' \
    MODEL.DEVICE_ID "('0')" \
    TEST.WEIGHT "('../logs/uda/vit_base/visda/transformer_best_model.pth')" \
    DATASETS.NAMES 'VisDA' DATASETS.NAMES2 'VisDA' \
    OUTPUT_DIR '../logs/uda/vit_base/visda/' \
    DATASETS.ROOT_TRAIN_DIR './data/visda/train/train_image_list.txt' \
    DATASETS.ROOT_TRAIN_DIR2 './data/visda/train/train_image_list.txt' \
    DATASETS.ROOT_TEST_DIR './data/visda/validation/valid_image_list.txt'
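The evaluation command passes --config_file followed by KEY VALUE override pairs; given that the codebase comes from TransReID, these are presumably merged into a yacs-style config (an assumption). A hypothetical helper that assembles the same VisDA-2017 command from a few parameters, so the weight path and list files are not retyped for every run, might look like this. Values are passed as argv entries, so the shell quotes from the example are not needed.

```python
import shlex
from typing import List


def build_test_command(weight: str, output_dir: str, train_list: str, test_list: str,
                       dataset: str = "VisDA", device_id: str = "0",
                       config_file: str = "configs/uda.yml") -> List[str]:
    """Assemble argv for test.py, mirroring the overrides of the VisDA-2017 example."""
    overrides = {
        "MODEL.DEVICE_ID": f"('{device_id}')",
        "TEST.WEIGHT": f"('{weight}')",
        "DATASETS.NAMES": dataset,
        "DATASETS.NAMES2": dataset,
        "OUTPUT_DIR": output_dir,
        # The example uses the training list for both ROOT_TRAIN_DIR and ROOT_TRAIN_DIR2.
        "DATASETS.ROOT_TRAIN_DIR": train_list,
        "DATASETS.ROOT_TRAIN_DIR2": train_list,
        "DATASETS.ROOT_TEST_DIR": test_list,
    }
    cmd = ["python", "test.py", "--config_file", config_file]
    for key, value in overrides.items():
        cmd += [key, value]
    return cmd


if __name__ == "__main__":
    cmd = build_test_command(
        weight="../logs/uda/vit_base/visda/transformer_best_model.pth",
        output_dir="../logs/uda/vit_base/visda/",
        train_list="./data/visda/train/train_image_list.txt",
        test_list="./data/visda/validation/valid_image_list.txt",
    )
    print(" ".join(shlex.quote(a) for a in cmd))
```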

Acknowledgement

Our codebase is built on TransReID.