Causal Distillation for Natural Language Understanding Tasks (DIITO-XXS)

This is an ONGOING research effort, so don't expect everything to be working. This is an extended implementation of our preprint Causal Distillation for Language Models, applying the method to task-specific models (i.e., the teacher model here is a fine-tuned model). The codebase for the distillation method, the distillation interchange intervention training objective (DIITO), can be found here.

We fork our main codebase from the PKD Distillation codebase to ensure a fair comparison.

Release Notes

✅ 02/21/2022 Released this codebase for others interested in applying DIITO to task-specific models.

If you experience any issues or have suggestions, please contact me either through the issues page or at wuzhengx@stanford.edu.

Main Contents

Citation

If you use this repository, please cite the following two papers: the paper for interchange intervention training, and the paper for our distillation method.

  @article{geiger-etal-2021-iit,
        title={Inducing Causal Structure for Interpretable Neural Networks}, 
        author={Geiger, Atticus and Wu, Zhengxuan and Lu, Hanson and Rozner, Josh and Kreiss, Elisa and Icard, Thomas and Goodman, Noah D. and Potts, Christopher},
        year={2021},
        eprint={2112.00826},
        archivePrefix={arXiv},
        primaryClass={cs.LG}
  }

  @article{wu-etal-2021-distill,
        title={Causal Distillation for Language Models}, 
        author={Wu, Zhengxuan and Geiger, Atticus and Rozner, Josh and Kreiss, Elisa and Lu, Hanson and Icard, Thomas and Potts, Christopher and Goodman, Noah D.},
        year={2021},
        eprint={2112.02505},
        archivePrefix={arXiv},
        primaryClass={cs.CL}
  }

Requirements

  • Python 3.6 and 3.7 are supported.
  • PyTorch Version: 1.9.0
  • Transformers Version: 4.11.3
  • Datasets Version: 1.8.0
  • Since we build our codebase on top of the Huggingface Distillation interface, please review their documentation for additional requirements. A minimal install sketch follows this list.
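
As a minimal sketch (assuming a standard pip setup; the package names are the usual PyPI ones, and the torch build may need adjusting for your CUDA version), the pinned versions above can be installed with:

pip install torch==1.9.0 transformers==4.11.3 datasets==1.8.0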

Distillation

Here is an example of how to distill with or without our causal distillation objective:

python KD_training.py \
--task_name SST-2 \
--output_dir data/outputs/KD/SST-2/teacher_12layer/ \
--bert_model bert-base-uncased \
--max_seq_length 128 \
--train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 5 \
--eval_batch_size 32 \
--gradient_accumulation_steps 1 \
--log_interval 10 \
--checkpoint_interval 100 \
--do_train \
--fp16 False \
--student_hidden_layers 6 \
--fc_layer_idx 1,3,5,7,9 \
--kd_model kd \
--alpha 0.7 \
--T 20 \
--is_wandb \
--wandb_metadata wuzhengx:DIITO-XXS \
--neuron_mapping full \
--is_diito \
--interchange_prop 0.3
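
The --neuron_mapping, --is_diito, and --interchange_prop flags configure the causal (DIITO) objective, while the remaining flags follow the standard PKD-style knowledge distillation setup. As a rough sketch (assuming the flags behave as their names suggest), dropping --is_diito and --interchange_prop from the command above runs the plain KD baseline without the causal objective.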