
The Codebase for Causal Proxy Model


Python 3.7 | License: CC BY-NC

Causal Proxy Models For Concept-Based Model Explanations

Zhengxuan Wu*, Karel D'Oosterlinck*, Atticus Geiger*, Amir Zur, Christopher Potts

This codebase contains implementations for our preprint Causal Proxy Models For Concept-Based Model Explanations. In this paper, we introduce two variants of CPM:

  • CPMIN: An input-based CPM that uses an auxiliary token to represent the intervention and is trained in a supervised manner to predict the counterfactual output. This model is built on an input-level intervention.
  • CPMHI: A hidden-state CPM that uses Interchange Intervention Training (IIT) to localize concept information within its representations and swaps hidden states to represent the intervention. It is likewise trained in a supervised manner to predict the counterfactual output. This model is built on a hidden-state intervention.
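To make the hidden-state intervention concrete, here is a toy sketch of the interchange (swap) operation on plain Python lists. It is not the repository's implementation: in a real CPMHI the two vectors would be one layer's activations from two forward passes (a base input and a source input), and which dimensions get swapped is an assumption here.

```python
from typing import List

Hidden = List[float]


def interchange(base: Hidden, source: Hidden, start: int, size: int) -> Hidden:
    """Return a copy of `base` whose dims [start, start + size) come from `source`.

    Toy stand-in for a hidden-state swap: `base` and `source` play the role of
    the same position's hidden vector from two different inputs.
    """
    if len(base) != len(source):
        raise ValueError("base and source must have the same width")
    if start < 0 or size < 0 or start + size > len(base):
        raise ValueError("slice out of range")
    patched = list(base)  # copy so the base run's activations stay untouched
    patched[start:start + size] = source[start:start + size]
    return patched


base = [0.0] * 8
source = [1.0] * 8
print(interchange(base, source, start=2, size=3))
# [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
```

The patched vector would then be fed through the remaining layers, so the model's output reflects the source input's value for whatever concept lives in the swapped slice.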

This codebase contains implementations and experiments for both CPMIN and CPMHI. If you experience any issues or have suggestions, please contact us through the issues page or at wuzhengx@cs.stanford.edu or karel.doosterlinck@ugent.be.

Citation

If you use this repository, please consider citing our relevant papers:

  @article{wu-etal-2021-cpm,
        title={Causal Proxy Models For Concept-Based Model Explanations}, 
        author={Wu, Zhengxuan and D'Oosterlinck, Karel and Geiger, Atticus and Zur, Amir and Potts, Christopher},
        year={2022},
        eprint={2209.14279},
        archivePrefix={arXiv},
        primaryClass={cs.LG}
  }

  @article{geiger-etal-2021-iit,
        title={Inducing Causal Structure for Interpretable Neural Networks}, 
        author={Geiger, Atticus and Wu, Zhengxuan and Lu, Hanson and Rozner, Josh and Kreiss, Elisa and Icard, Thomas and Goodman, Noah D. and Potts, Christopher},
        year={2021},
        eprint={2112.00826},
        archivePrefix={arXiv},
        primaryClass={cs.LG}
  }

Requirements

  • Python: 3.6 or 3.7
  • PyTorch version: 1.11.0
  • Transformers version: 4.21.1
  • Datasets version: 2.3.2
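A quick way to check an environment against the pinned versions listed above is a small script like the following. The helper names are our own (hypothetical, not part of the repository); the script compares each package's __version__ attribute exactly, after dropping local build tags such as +cu113.

```python
import importlib
import sys
from typing import Dict, Optional

# Versions pinned in the Requirements list.
PINNED = {"torch": "1.11.0", "transformers": "4.21.1", "datasets": "2.3.2"}


def installed_version(module_name: str) -> Optional[str]:
    """Return module.__version__, or None if not importable or no version exposed."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)


def check_requirements(pinned: Dict[str, str] = PINNED) -> Dict[str, str]:
    """Map each mismatching package (and Python itself) to a short problem description."""
    problems = {}
    if sys.version_info[:2] not in ((3, 6), (3, 7)):
        problems["python"] = "found %d.%d, expected 3.6 or 3.7" % sys.version_info[:2]
    for name, wanted in pinned.items():
        found = installed_version(name)
        if found is None:
            problems[name] = "missing or version unknown, expected %s" % wanted
        elif found.split("+")[0] != wanted:  # drop local tags such as "+cu113"
            problems[name] = "found %s, expected %s" % (found, wanted)
    return problems


if __name__ == "__main__":
    for name, problem in check_requirements().items():
        print("%s: %s" % (name, problem))
```

An empty report means the environment matches the pinned versions.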

Installation

First, clone the repository. Then run the following command to initialize the submodules:

git submodule init; git submodule update

Loading Black-box Models for CEBaB

These models are available from the CEBaB website. Here is an example of how to load one of them:

from transformers import AutoTokenizer, BertForNonlinearSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("CEBaB/bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42")

model = BertForNonlinearSequenceClassification.from_pretrained("CEBaB/bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42")
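Once loaded, the model produces one logit per class; assuming it returns the standard sequence-classification output, you would typically obtain them with model(**tokenizer(text, return_tensors="pt")).logits. The post-processing below is a plain-Python sketch that turns one 5-way logit vector into probabilities and a star rating. The mapping from class index i to an (i + 1)-star rating is an assumption; check the model's config.id2label before relying on it.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_stars(logits: List[float]) -> Tuple[int, List[float]]:
    """Return (predicted star rating, class probabilities).

    ASSUMPTION: class index i of the 5-way head corresponds to an (i + 1)-star
    review; verify against the model's config.id2label.
    """
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return best + 1, probs


stars, probs = predict_stars([0.1, 0.2, 2.5, 0.3, -1.0])
print(stars)  # 3
```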

Loading CPMs for CEBaB

We aim to make all of our CPMs public. Currently, they can be found in our Hugging Face repository.

from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("CPMs/cpm.hi.bert-base-uncased.layer.10.size.192")

model = AutoModelForSequenceClassification.from_pretrained("CPMs/cpm.hi.bert-base-uncased.layer.10.size.192")
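The checkpoint name appears to encode the CPM variant, the base architecture, the intervened layer, and the intervention size, which seem to match --interchange_hidden_layer 10 and --intervention_h_dim 192 in the CPMHI training command. That naming pattern is our inference from this one example, not a documented convention, so the hypothetical parser below should be treated as a sketch.

```python
import re
from typing import Dict, Union

# ASSUMED pattern: "<org>/cpm.<variant>.<architecture>.layer.<L>.size.<D>"
_CPM_NAME = re.compile(
    r"^(?:[^/]+/)?cpm\.(?P<variant>[a-z]+)\.(?P<architecture>.+)"
    r"\.layer\.(?P<layer>\d+)\.size\.(?P<size>\d+)$"
)


def parse_cpm_name(name: str) -> Dict[str, Union[str, int]]:
    """Split a CPM checkpoint name into its (assumed) components."""
    match = _CPM_NAME.match(name)
    if match is None:
        raise ValueError("unrecognized CPM checkpoint name: %r" % name)
    fields = match.groupdict()
    return {
        "variant": fields["variant"],
        "architecture": fields["architecture"],
        "layer": int(fields["layer"]),
        "size": int(fields["size"]),
    }


print(parse_cpm_name("CPMs/cpm.hi.bert-base-uncased.layer.10.size.192"))
# {'variant': 'hi', 'architecture': 'bert-base-uncased', 'layer': 10, 'size': 192}
```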

We also provide helpers for loading these models into our explainer module; please refer to the notebooks under the experiments folder.
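The notebooks show how the explainer module actually consumes a CPM. As a rough, hedged sketch of the idea: assume the CPM's output under an intervention (for example, the food concept set to Negative) approximates the black-box model's output on the corresponding counterfactual. An effect estimate is then the CPM's counterfactual output minus the black-box model's output on the original input, and it can be scored against a reference effect with a distance such as cosine distance. Both the estimator and the helper names below are assumptions for illustration; consult the paper and notebooks for the exact definitions.

```python
import math
from typing import List, Sequence


def estimated_effect(cpm_counterfactual_probs: Sequence[float],
                     blackbox_factual_probs: Sequence[float]) -> List[float]:
    """ASSUMED estimator: CPM counterfactual output minus black-box factual output."""
    if len(cpm_counterfactual_probs) != len(blackbox_factual_probs):
        raise ValueError("output vectors must have the same length")
    return [c - f for c, f in zip(cpm_counterfactual_probs, blackbox_factual_probs)]


def cosine_distance(u: Sequence[float], v: Sequence[float]) -> float:
    """1 - cosine similarity; undefined when either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    if norm == 0.0:
        raise ValueError("cosine distance is undefined for a zero vector")
    return 1.0 - dot / norm


factual = [0.10, 0.20, 0.40, 0.20, 0.10]         # black-box output on the original review
cpm_counterfactual = [0.05, 0.10, 0.25, 0.35, 0.25]  # CPM output under the intervention
effect = estimated_effect(cpm_counterfactual, factual)
print(effect)
```

A lower distance between the estimated effect and a reference effect (for example, one computed from a real human-written counterfactual) indicates a more faithful explanation.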

Training CPMIN

Because CPMIN intervenes on the inputs, we train it with a standard fine-tuning setup. First go to CEBaB-inclusive/eval_pipeline/, then run the following command:

python main.py \
--model_architecture bert-base-uncased \
--train_setting inclusive \
--model_output_dir model_output \
--output_dir output \
--flush_cache true \
--task_name opentable_5_way \
--batch_size 128 \
--k_array 19684
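For intuition about the input-level intervention that CPMIN learns from, the sketch below builds one training pair: the original review plus an auxiliary token naming the intervened concept and its new value, paired with the counterfactual output as the target. The token format, the helper name, and the use of [SEP] as a separator are hypothetical; the concept names and values follow CEBaB's aspect annotations as we understand them.

```python
from typing import List, Sequence, Tuple

# CEBaB's four annotated restaurant-review aspects and their (assumed) value labels.
CONCEPTS = ("ambiance", "food", "noise", "service")
VALUES = ("Positive", "Negative", "unknown")


def make_cpmin_example(text: str, concept: str, new_value: str,
                       counterfactual_target: Sequence[float]) -> Tuple[str, List[float]]:
    """Build one (input, target) pair for input-level intervention training.

    HYPOTHETICAL format: the auxiliary token is rendered as "[<concept>:<value>]"
    after a separator; the repository's actual encoding may differ. The target is
    the output to predict for the counterfactual (e.g. the black-box model's
    output distribution on the counterfactual review).
    """
    if concept not in CONCEPTS:
        raise ValueError("unknown concept: %r" % concept)
    if new_value not in VALUES:
        raise ValueError("unknown concept value: %r" % new_value)
    intervention_token = "[%s:%s]" % (concept, new_value)
    return "%s [SEP] %s" % (text, intervention_token), list(counterfactual_target)


print(make_cpmin_example("The pasta was bland.", "food", "Positive",
                         [0.0, 0.1, 0.2, 0.3, 0.4])[0])
# The pasta was bland. [SEP] [food:Positive]
```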

To train with approximate (metadata-sampled) counterfactuals instead, set --train_setting approximate; in this setting, the --k_array flag can be ignored. To use a different model architecture, change --model_architecture.
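If you launch both settings from Python, a small helper (hypothetical, not part of the repository) can assemble the CPMIN command and drop --k_array for the approximate setting. It reuses only the flags shown in the example command.

```python
import shlex
from typing import List, Optional


def cpmin_command(model_architecture: str = "bert-base-uncased",
                  train_setting: str = "inclusive",
                  k_array: Optional[int] = 19684,
                  batch_size: int = 128) -> List[str]:
    """Assemble the CPMIN training command as an argv list.

    Flags mirror the example command; --k_array is omitted for the
    "approximate" setting, where it is ignored anyway.
    """
    argv = ["python", "main.py",
            "--model_architecture", model_architecture,
            "--train_setting", train_setting,
            "--model_output_dir", "model_output",
            "--output_dir", "output",
            "--flush_cache", "true",
            "--task_name", "opentable_5_way",
            "--batch_size", str(batch_size)]
    if train_setting != "approximate" and k_array is not None:
        argv += ["--k_array", str(k_array)]
    return argv


# Print a shell-safe version of the approximate-counterfactual command.
print(" ".join(shlex.quote(arg) for arg in cpmin_command(train_setting="approximate")))
```

To run it directly, pass the list to subprocess.run(cpmin_command(), check=True, cwd="CEBaB-inclusive/eval_pipeline"); passing an argv list rather than a single string avoids shell-quoting issues.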

Training CPMHI

To train CPMHI, we adapt Interchange Intervention Training (IIT). You can use the following command; please refer to our paper for the training configurations.

python Proxy_training.py \
--model_name_or_path ./saved_models/bert-base-uncased.opentable.CEBaB.sa.5-class.exclusive.seed_42/ \
--task_name CEBaB \
--dataset_name CEBaB/CEBaB \
--do_train \
--per_device_train_batch_size 256 \
--per_device_eval_batch_size 256 \
--learning_rate 8e-05 \
--output_dir ./proxy_training_results/your_first_try/ \
--cache_dir ./train_cache/ \
--seed 42 \
--report_to none \
--logging_steps 1 \
--alpha 1.0 \
--beta 1.0 \
--gemma 3.0 \
--overwrite_output_dir \
--intervention_h_dim 192 \
--counterfactual_type true \
--k 19684 \
--interchange_hidden_layer 10 \
--save_steps 10 \
--early_stopping_patience 20

To train with approximate (metadata-sampled) counterfactuals instead, set --counterfactual_type approximate; in this setting, the --k flag can be ignored. To use a different model architecture, change --model_name_or_path; these black-box models can be downloaded from the CEBaB website.
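The flags --interchange_hidden_layer 10 and --intervention_h_dim 192 suggest that a 192-dimensional slice of layer 10's hidden state is swapped during interchange interventions. For bert-base-uncased (hidden size 768), 4 × 192 = 768, which would fit one slice per CEBaB concept (ambiance, food, noise, service). The sketch below assumes exactly that contiguous layout and concept ordering; it is a guess from the arithmetic, not something this README confirms, so check Proxy_training.py before relying on it.

```python
from typing import Dict, Sequence, Tuple

# CEBaB's four annotated aspects; the ordering here is an assumption.
CONCEPTS = ("ambiance", "food", "noise", "service")


def concept_slices(hidden_size: int = 768,
                   intervention_h_dim: int = 192,
                   concepts: Sequence[str] = CONCEPTS) -> Dict[str, Tuple[int, int]]:
    """ASSUMED layout: concept i owns hidden dims [i * h_dim, (i + 1) * h_dim)."""
    needed = intervention_h_dim * len(concepts)
    if needed > hidden_size:
        raise ValueError("%d concepts x %d dims = %d exceeds hidden size %d"
                         % (len(concepts), intervention_h_dim, needed, hidden_size))
    return {c: (i * intervention_h_dim, (i + 1) * intervention_h_dim)
            for i, c in enumerate(concepts)}


print(concept_slices())
# {'ambiance': (0, 192), 'food': (192, 384), 'noise': (384, 576), 'service': (576, 768)}
```

Under this assumption, intervening on the food concept would swap hidden dims 192 to 383 at layer 10 between the base and source runs.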