
Code for the Causal Semantic Generative model (CSG), proposed in "Learning Causal Semantic Representation for Out-of-Distribution Prediction" (NeurIPS 2021).


Learning Causal Semantic Representation for Out-of-Distribution Prediction

This repository is the official implementation of "Learning Causal Semantic Representation for Out-of-Distribution Prediction" (NeurIPS 2021).

Chang Liu <changliu@microsoft.com>, Xinwei Sun, Jindong Wang, Haoyue Tang, Tao Li, Tao Qin, Wei Chen, Tie-Yan Liu.
[Paper & Appendix] [Slides] [Video] [Poster]

Introduction

[Figure: graphical summary]

The work proposes a Causal Semantic Generative model (CSG) for OOD generalization (single-source domain generalization) and domain adaptation. The model is developed following a causal reasoning process, and prediction is made by leveraging the causal invariance principle. Training and prediction algorithms are developed based on variational Bayes with a novel design. Theoretical guarantees on the identifiability of the causal factor and the benefits for OOD prediction are presented.
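As a rough sketch of what such a model looks like, assume an input x, a label y, a latent semantic factor s that causes y, and a latent variation factor v that only affects x. These symbols are assumptions for illustration and are not defined in this README; the paper gives the actual formulation and the modified inference model used in training. Under that assumption, the joint distribution and a generic variational lower bound can be written as:

```latex
% Assumed notation (not defined in this README):
% x: input, y: label, s: semantic latent factor, v: variation latent factor,
% q: a variational (inference) distribution over the latent factors.
\begin{align}
p(s, v, x, y) &= p(s, v)\, p(x \mid s, v)\, p(y \mid s), \\
\log p(x, y) &\ge \mathbb{E}_{q(s, v \mid x, y)}
  \left[ \log \frac{p(s, v)\, p(x \mid s, v)\, p(y \mid s)}{q(s, v \mid x, y)} \right].
\end{align}
```

Read this way, the causal invariance principle would correspond to the generative mechanisms p(x | s, v) and p(y | s) being shared across domains, while the prior p(s, v) may differ between the training and test domains.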

This codebase implements the CSG methods, and implements or integrates various baselines. Most domain adaptation baselines (except BNM) use the dalib package. The experiment setups on the PACS and VLCS datasets are adopted from the domainbed repository. Authorship is noted in each file or module.

Requirements

The code requires Python >= 3.6 and is based on PyTorch. To install the requirements:

pip install -r requirements.txt

Usage

The folder a-mnist contains scripts for the experiments on the Shifted-MNIST dataset, a-imageclef for the ImageCLEF-DA dataset, and a-domainbed for the PACS and VLCS datasets (the prefix a- stands for "application").

Go to the respective folder and run the prepare_data.sh or makedata.sh script there to prepare the datasets. Then run the run_ood.sh script (for OOD generalization methods) or the run_da.sh script (for domain adaptation methods) to train the models. The evaluation result (accuracy on the test domain) is printed and written to disk along with the model and configuration. See the commands in the script files or python3 main.py --help for customized usage or hyperparameter tuning.
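The workflow above can be sketched as a small Python helper. This is a minimal sketch and not part of the repository: the function names plan_commands and run_experiment are hypothetical, and invoking the scripts with bash is an assumption. Only the script and folder names come from the text above.

```python
import subprocess
from pathlib import Path

# Script names taken from the usage notes; which data script a folder
# ships is not stated, so the first one found is used.
DATA_SCRIPTS = ("prepare_data.sh", "makedata.sh")
TRAIN_SCRIPTS = {"ood": "run_ood.sh", "da": "run_da.sh"}


def plan_commands(folder, task):
    """Return the commands for one experiment folder without running them.

    `task` is "ood" (OOD generalization) or "da" (domain adaptation).
    Running the scripts through `bash` is an assumption of this sketch.
    """
    if task not in TRAIN_SCRIPTS:
        raise ValueError(f"task must be one of {sorted(TRAIN_SCRIPTS)}")
    folder = Path(folder)
    commands = []
    # Data preparation: use whichever of the two data scripts exists.
    for name in DATA_SCRIPTS:
        if (folder / name).is_file():
            commands.append(["bash", name])
            break
    # Training and evaluation for the chosen family of methods.
    commands.append(["bash", TRAIN_SCRIPTS[task]])
    return commands


def run_experiment(folder, task):
    """Run the planned commands inside `folder`, stopping on the first failure."""
    for command in plan_commands(folder, task):
        subprocess.run(command, cwd=folder, check=True)
```

For example, run_experiment("a-mnist", "ood") would prepare the Shifted-MNIST data and then launch the OOD generalization runs, provided the call is made from the repository root.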

Citation

@inproceedings{liu2021learning,
  author = {Liu, Chang and Sun, Xinwei and Wang, Jindong and Tang, Haoyue and Li, Tao and Qin, Tao and Chen, Wei and Liu, Tie-Yan},
  booktitle = {Advances in Neural Information Processing Systems},
  editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J. Wortman Vaughan},
  pages = {6155--6170},
  publisher = {Curran Associates, Inc.},
  title = {Learning Causal Semantic Representation for Out-of-Distribution Prediction},
  url = {https://proceedings.neurips.cc/paper/2021/file/310614fca8fb8e5491295336298c340f-Paper.pdf},
  volume = {34},
  year = {2021}
}