DDG is a PyTorch implementation of "Towards Principled Disentanglement for Domain Generalization", built on DomainBed.
The currently available datasets are:
- RotatedMNIST (Ghifary et al., 2015)
- VLCS (Fang et al., 2013)
- PACS (Li et al., 2017)
- WILDS (Koh et al., 2020) Camelyon17 (Bandi et al., 2019), a dataset for tumor detection in tissue
Send us a PR to add your dataset! Any custom image dataset with the folder structure dataset/domain/class/image.xyz is readily usable.
While we include some datasets from the WILDS project, please use their official code if you wish to participate in their leaderboard.
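To check that a custom dataset follows the dataset/domain/class/image.xyz layout before training, a small script along these lines may help. It is only a sketch: `index_image_folder` is a hypothetical helper, not part of this repo, and it assumes every sub-directory of the root is a domain and every sub-directory of a domain is a class.

```python
from pathlib import Path


def index_image_folder(root):
    """Map each domain name to {class name: sorted list of image paths}.

    Hypothetical helper (not part of this repo). Assumes the layout
    root/domain/class/image.xyz and ignores stray files at the
    domain and class levels.
    """
    index = {}
    for domain_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        classes = {}
        for class_dir in sorted(p for p in domain_dir.iterdir() if p.is_dir()):
            # Collect only regular files; nested directories are skipped.
            classes[class_dir.name] = sorted(
                f for f in class_dir.iterdir() if f.is_file()
            )
        index[domain_dir.name] = classes
    return index


def summarize(index):
    """Return one line per domain with its class and image counts."""
    lines = []
    for domain, classes in index.items():
        n_images = sum(len(files) for files in classes.values())
        lines.append(f"{domain}: {len(classes)} classes, {n_images} images")
    return lines
```

Calling `summarize(index_image_folder("/my/datasets/path/my_dataset"))` (the path is a placeholder) lists each domain with its class and image counts, which makes a missing class folder or an empty domain easy to spot.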
Model selection criteria differ in what data is used to choose the best hyper-parameters for a given model:
- IIDAccuracySelectionMethod: A random subset from the data of the training domains.
- LeaveOneOutSelectionMethod: A random subset from the data of a held-out (not training, not testing) domain.
- OracleSelectionMethod: A random subset from the data of the test domain.
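The three criteria differ only in which validation split ranks the candidate runs. The sketch below illustrates that idea. It is not the repo's implementation: the record keys (`train_domain_val_acc` and so on), the `select_best_run` helper, and the accuracy values are all made up for illustration.

```python
# Hypothetical mapping from selection method to the validation split it uses.
SPLIT_FOR_METHOD = {
    "IIDAccuracySelectionMethod": "train_domain_val_acc",
    "LeaveOneOutSelectionMethod": "held_out_domain_val_acc",
    "OracleSelectionMethod": "test_domain_val_acc",
}


def select_best_run(runs, method):
    """Return the run with the highest accuracy on the split used by `method`."""
    key = SPLIT_FOR_METHOD[method]
    return max(runs, key=lambda run: run[key])


# Toy records: one per hyper-parameter setting, with invented accuracies.
runs = [
    {"lr": 1e-3, "train_domain_val_acc": 0.95,
     "held_out_domain_val_acc": 0.70, "test_domain_val_acc": 0.68},
    {"lr": 1e-4, "train_domain_val_acc": 0.92,
     "held_out_domain_val_acc": 0.75, "test_domain_val_acc": 0.71},
    {"lr": 5e-5, "train_domain_val_acc": 0.90,
     "held_out_domain_val_acc": 0.74, "test_domain_val_acc": 0.73},
]

for method in SPLIT_FOR_METHOD:
    best = select_best_run(runs, method)
    print(f"{method}: lr={best['lr']}")
```

With these toy numbers each criterion picks a different learning rate, which is why reported results should always state the selection method used.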
Download the datasets:
python scripts/download.py \
--data-dir /my/datasets/path
Train a model:
python train.py \
--data-dir /my/datasets/path \
--algorithm ERM \
--dataset RotatedMNIST
Pretrain the decoder of the DDG model:
python train.py \
--data-dir /my/datasets/path \
--algorithm DDG \
--dataset PACS \
--stage 0
Train the DDG model with the pretrained decoder:
python train.py \
--data-dir /my/datasets/path \
--algorithm DDG \
--gen-dir /my/models/model.pkl \
--dataset PACS \
--stage 1
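To run both DDG stages from one script, the two invocations can be built as argument lists. This is a sketch only: `ddg_stage_commands` is a hypothetical helper, it uses just the flags shown above, and it assumes the decoder pretrained in stage 0 ends up at the path later passed to `--gen-dir`. This README does not say where stage 0 writes its checkpoint, so check that path before running stage 1.

```python
import shlex


def ddg_stage_commands(data_dir, dataset, gen_dir):
    """Build the stage 0 and stage 1 command lines as argument lists.

    Hypothetical helper (not part of this repo). `gen_dir` is assumed
    to point at the decoder checkpoint produced by stage 0.
    """
    common = [
        "python", "train.py",
        "--data-dir", data_dir,
        "--algorithm", "DDG",
        "--dataset", dataset,
    ]
    stage0 = common + ["--stage", "0"]
    stage1 = common + ["--gen-dir", gen_dir, "--stage", "1"]
    return stage0, stage1


stage0, stage1 = ddg_stage_commands(
    "/my/datasets/path", "PACS", "/my/models/model.pkl"
)
for cmd in (stage0, stage1):
    # Print the commands; to execute them, pass each list to
    # subprocess.run(cmd, check=True) from the repository root.
    print(shlex.join(cmd))
```

Building lists rather than a single shell string avoids the line-continuation pitfalls of multi-line commands and keeps the shared arguments in one place.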
If you find this repo useful, please cite:
@misc{zhang2021principled,
title={Towards Principled Disentanglement for Domain Generalization},
author={Hanlin Zhang and Yi-Fan Zhang and Weiyang Liu and Adrian Weller and Bernhard Schölkopf and Eric P. Xing},
year={2021},
eprint={2111.13839},
archivePrefix={arXiv},
primaryClass={cs.LG}
}