Towards Principled Disentanglement for Domain Generalization

License: MIT

DDG is a PyTorch implementation of "Towards Principled Disentanglement for Domain Generalization", built on DomainBed.

Available datasets

The currently available datasets are:

Send us a PR to add your dataset! Any custom image dataset with folder structure dataset/domain/class/image.xyz is readily usable. While we include some datasets from the WILDS project, please use their official code if you wish to participate in their leaderboard.
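To make the expected layout concrete, here is a minimal sketch that indexes a folder organized as dataset/domain/class/image.xyz. The helper name `index_image_folder` is hypothetical and is not part of this repo or DomainBed. It only shows how the domain and class labels follow from the directory names.

```python
from pathlib import Path


def index_image_folder(root):
    """Return (path, domain, class) triples for files under root/domain/class/.

    Hypothetical helper for illustration; not part of DDG or DomainBed.
    """
    samples = []
    # First directory level: one folder per domain.
    for domain_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        # Second directory level: one folder per class.
        for class_dir in sorted(p for p in domain_dir.iterdir() if p.is_dir()):
            # Files inside the class folder are the images.
            for image_path in sorted(p for p in class_dir.iterdir() if p.is_file()):
                samples.append((image_path, domain_dir.name, class_dir.name))
    return samples
```

Calling `index_image_folder("/my/datasets/path/MyDataset")` on such a tree yields one triple per image, sorted by domain, then class, then file name.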

Available model selection criteria

Model selection criteria differ in what data is used to choose the best hyper-parameters for a given model:

  • IIDAccuracySelectionMethod: A random subset from the data of the training domains.
  • LeaveOneOutSelectionMethod: A random subset from the data of a held-out (not training, not testing) domain.
  • OracleSelectionMethod: A random subset from the data of the test domain.
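The difference between the three criteria can be sketched as follows. This is a simplified illustration, not the actual DomainBed classes: the function `select_config`, the record layout, and the criterion strings are all assumptions made for the example. Each record is assumed to hold the accuracy of one hyper-parameter configuration on a random validation subset of each domain.

```python
from statistics import mean


def select_config(records, test_domain, criterion, held_out_domain=None):
    """Return the name of the configuration that scores best under `criterion`.

    records: list of {"name": str, "val_acc": {domain: accuracy}} (assumed layout).
    criterion: "iid", "leave_one_out", or "oracle" (hypothetical labels).
    """
    def score(record):
        acc = record["val_acc"]
        if criterion == "iid":
            # Validation data drawn from the training domains only.
            return mean(a for d, a in acc.items() if d != test_domain)
        if criterion == "leave_one_out":
            # Validation data from a domain used for neither training nor testing.
            return acc[held_out_domain]
        if criterion == "oracle":
            # Validation data from the test domain itself.
            return acc[test_domain]
        raise ValueError(f"unknown criterion: {criterion}")

    return max(records, key=score)["name"]
```

The same set of runs can therefore yield a different "best" configuration depending on which data the criterion is allowed to see.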

Quick start

Download the datasets:

python scripts/download.py \
       --data-dir /my/datasets/path

Train a model:

python train.py \
       --data-dir /my/datasets/path \
       --algorithm ERM \
       --dataset RotatedMNIST

Pretrain the decoder of the DDG model:

python train.py \
       --data-dir /my/datasets/path \
       --algorithm DDG \
       --dataset PACS \
       --stage 0

Train the DDG model with the pretrained decoder:

python train.py \
       --data-dir /my/datasets/path \
       --algorithm DDG \
       --gen-dir /my/models/model.pkl \
       --dataset PACS \
       --stage 1
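When scripting both stages, it can help to build the two command lines from one place. The sketch below assumes `train.py` is in the current working directory and uses only the flags shown in the commands above. The helper name `ddg_train_command` is hypothetical, and the checkpoint path is the same placeholder used above.

```python
import sys


def ddg_train_command(data_dir, dataset, stage, gen_dir=None):
    """Build the argument list for one DDG training stage.

    Hypothetical helper; it only assembles the flags used in this README.
    """
    cmd = [
        sys.executable, "train.py",
        "--data-dir", data_dir,
        "--algorithm", "DDG",
        "--dataset", dataset,
        "--stage", str(stage),
    ]
    if gen_dir is not None:
        # Path to the pretrained decoder, as passed in the stage 1 command.
        cmd += ["--gen-dir", gen_dir]
    return cmd


# Example usage (commented out so that importing this sketch starts no training):
# import subprocess
# subprocess.run(ddg_train_command("/my/datasets/path", "PACS", 0), check=True)
# subprocess.run(
#     ddg_train_command("/my/datasets/path", "PACS", 1,
#                       gen_dir="/my/models/model.pkl"),
#     check=True,
# )
```

Passing the arguments as a list avoids shell line-continuation mistakes such as a missing trailing backslash.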

Citation

If you find this repo useful, please cite:

@misc{zhang2021principled,
      title={Towards Principled Disentanglement for Domain Generalization}, 
      author={Hanlin Zhang and Yi-Fan Zhang and Weiyang Liu and Adrian Weller and Bernhard Schölkopf and Eric P. Xing},
      year={2021},
      eprint={2111.13839},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}