This repo contains the source code of our project, "On-Device Domain Generalization," which studies how to improve tiny neural networks' domain generalization (DG) performance, specifically for mobile DG applications. In the paper, we present a systematic study from which we find that knowledge distillation (KD) outperforms commonly used DG methods by a large margin under the on-device DG setting. We further propose a simple idea, called out-of-distribution knowledge distillation (OKD), which extends KD by teaching the student how the teacher handles out-of-distribution data synthesized via data augmentations. We also provide a new suite of DG datasets, named DOSCO-2k, which are built on top of existing vision datasets (much more diverse than existing DG datasets) by synthesizing contextual domain shift using a neural network pretrained on the Places dataset.
- [Oct 2022] Release of source code.
This code is built on top of the awesome toolbox, Dassl.pytorch, so you need to install the `dassl` environment first. Simply follow the instructions described here to install `dassl` as well as PyTorch. After that, run `pip install -r requirements.txt` under `on-device-dg/` to install a few more packages (remember to activate the `dassl` environment via `conda activate dassl` before installing the new packages).
We suggest you download and put all datasets under the same folder, e.g., `on-device-dg/data/`.
- PACS & OfficeHome: These two datasets are small (both around 200MB) so we suggest you directly run the code, which will automatically download and preprocess the datasets.
- DOSCO-2k: All datasets from the DOSCO benchmark can be downloaded automatically once you run the code (like PACS and OfficeHome), but we suggest you manually download them first. They can be downloaded from this Google Drive link.
- Pretrained teacher models (ResNet50): The pretrained ERM models based on ResNet50, i.e., KD's teacher as reported in the paper, can be downloaded here. Please download and extract the file under `on-device-dg/`. To reproduce the results of KD and OKD, you should use these pretrained teacher models.
- PlacesViT: The model weights can be downloaded here. Please put the weights under `on-device-dg/tools/`. The feature extraction code is provided in `on-device-dg/tools/featext.py`.
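As a quick sanity check before training, the layout described above can be verified with a small shell helper. This is only a sketch: `check_layout` is a hypothetical function, not part of the repo, and it assumes the default `data/` folder, the teacher models extracted to `pretrained/`, and the feature extraction script at `tools/featext.py`. The PlacesViT weight file is not checked because its file name is not stated here.

```shell
# Hypothetical helper (not part of the repo): report which expected paths exist.
# Pass the path to on-device-dg/ as the first argument (defaults to the current directory).
check_layout() {
  root="${1:-.}"
  status=0
  for p in data pretrained tools/featext.py; do
    if [ -e "${root}/${p}" ]; then
      echo "ok       ${p}"
    else
      echo "missing  ${p}"
      status=1
    fi
  done
  # Return non-zero if anything is missing so the caller can react.
  return "${status}"
}

check_layout . || echo "Some paths are missing; see the download notes above."
```

If you store the datasets somewhere other than `data/`, a `missing  data` line is expected and harmless.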
The running scripts are provided in `on-device-dg/scripts/`:
- `generic.sh`: This can fit most trainers like `Vanilla`.
- `kd.sh`: This is used for those KD-based trainers in `on-device-dg/trainers/` (except OKD).
- `okd.sh`: This is used for OKD, which mainly differs from `kd.sh` in the `Aug` argument (it chooses which augmentation method to use for the OOD data generator).
The `DATA_ROOT` argument is set to `./data/` by default. Feel free to change the path.
Below are the example commands used to reproduce the results on DOSCO-2k's P-Air using MobileNetV3-Small (should be run under `on-device-dg/`):
- ERM: `bash scripts/generic.sh Vanilla p_air mobilenet_v3_small 2k`
- RSC: `bash scripts/generic.sh RSC p_air mobilenet_v3_small 2k`
- MixStyle: `bash scripts/generic.sh Vanilla p_air mobilenet_v3_small_ms_l12.yaml 2k`
- EFDMix: `bash scripts/generic.sh Vanilla p_air mobilenet_v3_small_efdmix_l12.yaml 2k`
- KD: `bash scripts/kd.sh KD p_air mobilenet_v3_small 2k`
- OKD: `bash scripts/okd.sh OKD fusion p_air mobilenet_v3_small 2k`
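To repeat the same six runs on another DOSCO-2k dataset, only the dataset argument changes. The sketch below is a dry run: it prints the commands rather than executing them, so you can inspect them first. `print_dosco_cmds` is a hypothetical helper, not a script shipped with the repo, and it keeps the MobileNetV3-Small configs from the list above unchanged.

```shell
# Hypothetical dry-run helper: print (do not execute) the six example
# commands for a given DOSCO-2k dataset name, e.g. p_air or p_cars.
print_dosco_cmds() {
  dataset="$1"
  echo "bash scripts/generic.sh Vanilla ${dataset} mobilenet_v3_small 2k"                   # ERM
  echo "bash scripts/generic.sh RSC ${dataset} mobilenet_v3_small 2k"                       # RSC
  echo "bash scripts/generic.sh Vanilla ${dataset} mobilenet_v3_small_ms_l12.yaml 2k"       # MixStyle
  echo "bash scripts/generic.sh Vanilla ${dataset} mobilenet_v3_small_efdmix_l12.yaml 2k"   # EFDMix
  echo "bash scripts/kd.sh KD ${dataset} mobilenet_v3_small 2k"                             # KD
  echo "bash scripts/okd.sh OKD fusion ${dataset} mobilenet_v3_small 2k"                    # OKD
}

print_dosco_cmds p_cars
```

Once the printed commands look right, they can be run one at a time from `on-device-dg/`.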
Some notes:
- MixStyle and EFDMix use the same trainer as ERM, i.e., `Vanilla`.
- To use a different dataset, simply change `p_air`. Note that the dataset names should match the file names in `on-device-dg/configs/datasets/`, such as `p_cars` for P-Cars and `p_ctech` for P-Ctech.
- To use a different architecture like MobileNetV2-Tiny or MCUNet studied in the paper, simply change `mobilenet_v3_small` to `mobilenet_v2_tiny` or `mcunet`. (The model names should match the file names in `on-device-dg/configs/hparam`.)
- To reproduce the results on PACS and OfficeHome, you need to (i) change `p_air` to `pacs` or `oh`, (ii) change `2k` to `full`, and (iii) add an index number from `{1, 2, 3, 4}` at the end of the argument list. For example, to run OKD on PACS, which has four settings (each using one of the four domains as the test domain), the command template is `bash scripts/okd.sh OKD fusion pacs mobilenet_v3_small full {TIDX}`, where `TIDX` is 1, 2, 3, or 4.
- After you obtain the results of three seeds, you can use `parse_test_res.py` to automatically compute the average results. For a quick try, assuming you have downloaded the pretrained teacher models to `on-device-dg/pretrained`, run `python parse_test_res.py pretrained/Vanilla/p_air/env_2k/resnet50/` to get the average results for the P-Air dataset (`../resnet50/` should contain three seed folders, each containing a `log.txt` file). Note that for PACS and OfficeHome, the `../resnet50/` folder contains four sets of results, each corresponding to a test domain, so you need to use `python parse_test_res.py pretrained/Vanilla/pacs/env_full/resnet50/ --multi-exp`.
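The PACS and OfficeHome note can be turned into a small dry-run loop that covers all four test-domain indices. This is a sketch under the three steps stated in that note: the dataset name is `pacs` or `oh`, the `full` setting replaces `2k`, and the index is appended last. `print_loo_cmds` is a hypothetical helper and only prints the commands.

```shell
# Hypothetical dry-run helper: print the OKD command for each of the four
# test-domain indices on PACS ("pacs") or OfficeHome ("oh").
print_loo_cmds() {
  dataset="$1"
  for tidx in 1 2 3 4; do
    echo "bash scripts/okd.sh OKD fusion ${dataset} mobilenet_v3_small full ${tidx}"
  done
}

print_loo_cmds pacs
```

After all four settings have finished, the results can be averaged with `parse_test_res.py` and the `--multi-exp` flag as described in the last note.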
@article{zhou2022device,
title={On-Device Domain Generalization},
author={Zhou, Kaiyang and Zhang, Yuanhan and Zang, Yuhang and Yang, Jingkang and Loy, Chen Change and Liu, Ziwei},
journal={arXiv preprint arXiv:2209.07521},
year={2022}
}