hans-forgetting

Primary language: Python. License: MIT.

This is the code for our paper 'Increasing Robustness to Spurious Correlations using Forgettable Examples'.

Reproducing MNLI -> MNLI & HANS results in the paper (one seed)

# download mnli data
$ sh getdata.sh mnli && export MNLI_PATH=mnli/MNLI/

# fine-tune a bert base model on mnli 
$ python exp_cli.py train_mnli_bert_base --output_dir mnli_bert_base/

# fine-tune the model on bow forgettables
$ python exp_cli.py finetune_hard_examples mnli_bert_base/checkpoint-epoch-3/ mnli_bert_base_fbow/ --training-examples-ids example_ids/mnli/bow/balanced_forg.ids --task mnli 

# fine-tune the model on lstm forgettables
$ python exp_cli.py finetune_hard_examples mnli_bert_base/checkpoint-epoch-3/ mnli_bert_base_flstm/ --training-examples-ids example_ids/mnli/lstm/balanced_forg.ids --task mnli 

# fine-tune the model on bert forgettables
$ python exp_cli.py finetune_hard_examples mnli_bert_base/checkpoint-epoch-3/ mnli_bert_base_fbert/ --training-examples-ids example_ids/mnli/bert/balanced_forg.ids --task mnli 
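The three fine-tuning commands above differ only in the weak model name (bow, lstm, bert). A small driver along the following lines could generate them. This is only a sketch and not part of the repository: the function name `finetune_command` and the constants are hypothetical, and the script only prints the commands rather than running them.

```python
import shlex

# Paths copied from the commands above.
CHECKPOINT = "mnli_bert_base/checkpoint-epoch-3/"
WEAK_MODELS = ["bow", "lstm", "bert"]


def finetune_command(weak_model: str) -> list:
    """Build the argument list for one finetune_hard_examples run (hypothetical helper)."""
    return [
        "python", "exp_cli.py", "finetune_hard_examples",
        CHECKPOINT,
        f"mnli_bert_base_f{weak_model}/",
        "--training-examples-ids", f"example_ids/mnli/{weak_model}/balanced_forg.ids",
        "--task", "mnli",
    ]


# Dry run: print each command instead of executing it.
for name in WEAK_MODELS:
    print(shlex.join(finetune_command(name)))
```

To actually launch the runs, each argument list could be passed to `subprocess.run(cmd, check=True)` from the repository root.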

To generate the BoW forgettables for MNLI, you can run:

# download glove
$ sh getdata.sh glove

# create embeddings for base weak models (bow, lstm)
$ python exp_cli.py extract_subset_from_glove glove.42B.300d.txt config/dictionary.txt config/

# train bow model
$ python exp_cli.py train_mnli_bow --output_dir mnli_bow

# extract forgettables from a bow model
$ python exp_cli.py extract_hard_examples mnli_bow/ --train_path $MNLI_PATH/train.tsv --task mnli 

You can then fine-tune your mnli_bert_base checkpoint on your BoW forgettables using:

$ python exp_cli.py finetune_hard_examples mnli_bert_base/checkpoint-epoch-3/ mnli_bert_base_fbow/ --hard_path mnli_bow/hard_examples.pkl
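For readers unfamiliar with the term, the sketch below illustrates what "forgettable" means, assuming the definition used in the paper: an example is forgettable if it was classified correctly at some point during training and later misclassified, or if it was never classified correctly at all. This is not the code behind `extract_hard_examples`. The function name `forgettable_ids` and the per-epoch correctness layout are assumptions made for illustration.

```python
from typing import Dict, List, Sequence


def forgettable_ids(correct_by_epoch: Dict[str, Sequence[bool]]) -> List[str]:
    """Return ids of examples that were forgotten at least once or never learned.

    correct_by_epoch maps an example id to its correctness after each epoch
    (hypothetical layout, True = classified correctly).
    """
    result = []
    for example_id, history in correct_by_epoch.items():
        # Never learned: not correct after any epoch.
        never_learned = not any(history)
        # Forgotten: correct at one epoch, incorrect at the next.
        forgotten = any(prev and not cur for prev, cur in zip(history, history[1:]))
        if never_learned or forgotten:
            result.append(example_id)
    return result


history = {
    "a": [False, True, True],    # learned and retained: not forgettable
    "b": [True, False, True],    # forgotten once
    "c": [False, False, False],  # never learned
}
print(forgettable_ids(history))  # ['b', 'c']
```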