BeamRecursionFamily

Primary language: Python. License: Apache-2.0.

Official Repository For:

1. Efficient Beam Tree Recursion - Jishnu Ray Chowdhury, Cornelia Caragea (NeurIPS 2023)

2. Recursion in Recursion: Two-Level Nested Recursion for Length Generalization with Scalability - Jishnu Ray Chowdhury, Cornelia Caragea (NeurIPS 2023)

Credits:

Requirements

  • pytorch 1.10.0
  • pytorch-lightning 1.9.3
  • tqdm 4.62.3
  • tensorflow-datasets 4.5.2
  • typing_extensions 4.5.0
  • pykeops 2.1.1
  • jsonlines 2.0.0
  • einops 0.6.0
  • torchtext 0.8.1
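
A minimal install sketch, assuming pip and the pinned versions listed above (the exact torch build for your CUDA/CPU setup may differ):

```bash
# Hedged sketch: install the pinned requirements listed above with pip.
# Adjust the torch build (CUDA vs. CPU) for your machine.
pip install torch==1.10.0 pytorch-lightning==1.9.3 tqdm==4.62.3 \
    tensorflow-datasets==4.5.2 typing_extensions==4.5.0 pykeops==2.1.1 \
    jsonlines==2.0.0 einops==0.6.0 torchtext==0.8.1
```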

Data Setup

Processing Data

  • Go to preprocess/ and run each of the preprocessing scripts to preprocess the corresponding data (process_SNLI_addon.py must be run after process_SNLI.py; otherwise there is no order requirement). A runnable sketch is shown below.
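
A hedged sketch of the preprocessing step. Only process_SNLI.py and process_SNLI_addon.py are named in this README, so the process_*.py glob below is an assumption about how the other scripts are named:

```bash
# Hedged sketch: run every preprocessing script in preprocess/.
# The process_*.py naming pattern is assumed; only the two SNLI scripts
# are explicitly named in this README.
cd preprocess
for script in process_*.py; do
    [ "$script" = "process_SNLI_addon.py" ] && continue   # must run after process_SNLI.py
    python "$script"
done
python process_SNLI_addon.py   # run last, after process_SNLI.py
```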

You can verify that the data is set up properly against the directory tree below (a sketch for fetching the GloVe embeddings follows the tree).

├───data
│   ├───AAN
│   │       new_aan_pairs.eval.tsv
│   │       new_aan_pairs.test.tsv
│   │       new_aan_pairs.train.tsv
│   │
│   ├───IMDB
│   │   │   dev_contrast.tsv
│   │   │   dev_original.tsv
│   │   │   test_contrast.tsv
│   │   │   test_counterfactual.tsv
│   │   │
│   │   └───aclImdb
│   │       │   imdb.vocab
│   │       │   imdbEr.txt
│   │       │   README
│   │       │
│   │       ├───test
│   │       │   │   labeledBow.feat
│   │       │   │   urls_neg.txt
│   │       │   │   urls_pos.txt
│   │       │   │
│   │       │   ├───neg
│   │       │   └───pos
│   │       └───train
│   │           │   labeledBow.feat
│   │           │   unsupBow.feat
│   │           │   urls_neg.txt
│   │           │   urls_pos.txt
│   │           │   urls_unsup.txt
│   │           │
│   │           ├───neg
│   │           ├───pos
│   │           └───unsup
│   ├───listops
│   │       base.py
│   │       basic_test.tsv
│   │       dev_d7s.tsv
│   │       load_listops_data.py
│   │       make_depth_dev_data.py
│   │       make_depth_ndr_data.py
│   │       make_depth_test_data.py
│   │       make_depth_train_data.py
│   │       make_depth_train_data_extra.py
│   │       make_iid_data.py
│   │       make_odd_25depth_data.py
│   │       make_ood_10arg_data.py
│   │       make_ood_15arg_data.py
│   │       test_200_300.tsv
│   │       test_300_400.tsv
│   │       test_400_500.tsv
│   │       test_500_600.tsv
│   │       test_600_700.tsv
│   │       test_700_800.tsv
│   │       test_800_900.tsv
│   │       test_900_1000.tsv
│   │       test_d20s.tsv
│   │       test_dg8s.tsv
│   │       test_iid_arg.tsv
│   │       test_ood_10arg.tsv
│   │       test_ood_15arg.tsv
│   │       train_d20s.tsv
│   │       train_d6s.tsv
│   │       __init__.py
│   │
│   ├───listops_lra
│   │       basic_test.tsv
│   │       basic_train.tsv
│   │       basic_val.tsv
│   │
│   ├───MNLI
│   │   │   conj_dev.tsv
│   │   │   multinli_0.9_test_matched_unlabeled.jsonl
│   │   │   multinli_0.9_test_matched_unlabeled_hard.jsonl
│   │   │   multinli_0.9_test_mismatched_unlabeled.jsonl
│   │   │   multinli_0.9_test_mismatched_unlabeled.jsonl.zip
│   │   │   multinli_0.9_test_mismatched_unlabeled_hard.jsonl
│   │   │   multinli_1.0_dev_matched.jsonl
│   │   │   multinli_1.0_dev_matched.txt
│   │   │   multinli_1.0_dev_mismatched.jsonl
│   │   │   multinli_1.0_dev_mismatched.txt
│   │   │   multinli_1.0_train.jsonl
│   │   │   multinli_1.0_train.txt
│   │   │   paper.pdf
│   │   │   README.txt
│   │   │
│   │   ├───Antonym
│   │   │       multinli_0.9_antonym_matched.jsonl
│   │   │       multinli_0.9_antonym_matched.txt
│   │   │       multinli_0.9_antonym_mismatched.jsonl
│   │   │       multinli_0.9_antonym_mismatched.txt
│   │   │
│   │   ├───Length_Mismatch
│   │   │       multinli_0.9_length_mismatch_matched.jsonl
│   │   │       multinli_0.9_length_mismatch_matched.txt
│   │   │       multinli_0.9_length_mismatch_mismatched.jsonl
│   │   │       multinli_0.9_length_mismatch_mismatched.txt
│   │   │
│   │   ├───Negation
│   │   │       multinli_0.9_negation_matched.jsonl
│   │   │       multinli_0.9_negation_matched.txt
│   │   │       multinli_0.9_negation_mismatched.jsonl
│   │   │       multinli_0.9_negation_mismatched.txt
│   │   │
│   │   ├───Numerical_Reasoning
│   │   │       .DS_Store
│   │   │       multinli_0.9_quant_hard.jsonl
│   │   │       multinli_0.9_quant_hard.txt
│   │   │
│   │   ├───Spelling_Error
│   │   │       multinli_0.9_dev_gram_contentword_swap_perturbed_matched.jsonl
│   │   │       multinli_0.9_dev_gram_contentword_swap_perturbed_matched.txt
│   │   │       multinli_0.9_dev_gram_contentword_swap_perturbed_mismatched.jsonl
│   │   │       multinli_0.9_dev_gram_contentword_swap_perturbed_mismatched.txt
│   │   │       multinli_0.9_dev_gram_functionword_swap_perturbed_matched.jsonl
│   │   │       multinli_0.9_dev_gram_functionword_swap_perturbed_matched.txt
│   │   │       multinli_0.9_dev_gram_functionword_swap_perturbed_mismatched.jsonl
│   │   │       multinli_0.9_dev_gram_functionword_swap_perturbed_mismatched.txt
│   │   │       multinli_0.9_dev_gram_keyboard_matched.jsonl
│   │   │       multinli_0.9_dev_gram_keyboard_matched.txt
│   │   │       multinli_0.9_dev_gram_keyboard_mismatched.jsonl
│   │   │       multinli_0.9_dev_gram_keyboard_mismatched.txt
│   │   │       multinli_0.9_dev_gram_swap_matched.jsonl
│   │   │       multinli_0.9_dev_gram_swap_matched.txt
│   │   │       multinli_0.9_dev_gram_swap_mismatched.jsonl
│   │   │       multinli_0.9_dev_gram_swap_mismatched.txt
│   │   │
│   │   └───Word_Overlap
│   │           multinli_0.9_taut2_matched.jsonl
│   │           multinli_0.9_taut2_matched.txt
│   │           multinli_0.9_taut2_mismatched.jsonl
│   │           multinli_0.9_taut2_mismatched.txt
│   │
│   ├───proplogic
│   │       generate_neg_set_data.py
│   │       test0
│   │       test1
│   │       test10
│   │       test11
│   │       test12
│   │       test2
│   │       test3
│   │       test4
│   │       test5
│   │       test6
│   │       test7
│   │       test8
│   │       test9
│   │       train0
│   │       train1
│   │       train10
│   │       train11
│   │       train12
│   │       train2
│   │       train3
│   │       train4
│   │       train5
│   │       train6
│   │       train7
│   │       train8
│   │       train9
│   │       __init__.py
│   │
│   ├───QQP
│   │   │   quora_duplicate_questions.tsv
│   │   │
│   │   ├───PAWS_QQP
│   │   │       dev_and_test.tsv
│   │   │
│   │   └───PAWS_WIKI
│   │           test.tsv
│   │
│   └───SNLI
│       │   dataset.jsonl
│       │   README.txt
│       │   snli_1.0_dev.jsonl
│       │   snli_1.0_dev.txt
│       │   snli_1.0_test.jsonl
│       │   snli_1.0_test.txt
│       │   snli_1.0_test_hard.jsonl
│       │   snli_1.0_train.jsonl
│       │   snli_1.0_train.txt
│       │
│       └───revised_combined
│               test.tsv
├───embeddings
│   └───glove
│           glove.840B.300d.txt
├───processed_data
│   ├───AAN_lra
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_normal.jsonl
│   │       train.jsonl
│   │
│   ├───IMDB
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_contrast.jsonl
│   │       test_counterfactual.jsonl
│   │       test_normal.jsonl
│   │       test_original_of_contrast.jsonl
│   │       test_original_of_counterfactual.jsonl
│   │       train.jsonl
│   │
│   ├───IMDB_lra
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_normal.jsonl
│   │       train.jsonl
│   │
│   ├───listops200speed
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_normal.jsonl
│   │       train.jsonl
│   │
│   ├───listops500speed
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_normal.jsonl
│   │       train.jsonl
│   │
│   ├───listops900speed
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_normal.jsonl
│   │       train.jsonl
│   │
│   ├───listopsc
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_200_300.jsonl
│   │       test_300_400.jsonl
│   │       test_400_500.jsonl
│   │       test_500_600.jsonl
│   │       test_600_700.jsonl
│   │       test_700_800.jsonl
│   │       test_800_900.jsonl
│   │       test_900_1000.jsonl
│   │       test_iid_arg.jsonl
│   │       test_LRA.jsonl
│   │       test_normal.jsonl
│   │       test_ood_10arg.jsonl
│   │       test_ood_15arg.jsonl
│   │       train.jsonl
│   │
│   ├───listopsmix
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_200_300.jsonl
│   │       test_300_400.jsonl
│   │       test_400_500.jsonl
│   │       test_500_600.jsonl
│   │       test_600_700.jsonl
│   │       test_700_800.jsonl
│   │       test_800_900.jsonl
│   │       test_900_1000.jsonl
│   │       test_iid_arg.jsonl
│   │       test_LRA.jsonl
│   │       test_normal.jsonl
│   │       test_ood_10arg.jsonl
│   │       test_ood_15arg.jsonl
│   │       train.jsonl
│   │
│   ├───listops_lra
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_normal.jsonl
│   │       train.jsonl
│   │
│   ├───listops_lra_speed
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_normal.jsonl
│   │       train.jsonl
│   │
│   ├───MNLIdev
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_antonym_matched.jsonl
│   │       test_antonym_mismatched.jsonl
│   │       test_conj_nli.jsonl
│   │       test_content_word_swap_matched.jsonl
│   │       test_content_word_swap_mismatched.jsonl
│   │       test_function_word_swap_matched.jsonl
│   │       test_function_word_swap_mismatched.jsonl
│   │       test_keyboard_swap_matched.jsonl
│   │       test_keyboard_swap_mismatched.jsonl
│   │       test_length_matched.jsonl
│   │       test_length_mismatched.jsonl
│   │       test_matched.jsonl
│   │       test_mismatched.jsonl
│   │       test_negation_matched.jsonl
│   │       test_negation_mismatched.jsonl
│   │       test_numerical.jsonl
│   │       test_swap_matched.jsonl
│   │       test_swap_mismatched.jsonl
│   │       test_word_overlap_matched.jsonl
│   │       test_word_overlap_mismatched.jsonl
│   │       train.jsonl
│   │
│   ├───proplogic
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_0.jsonl
│   │       test_1.jsonl
│   │       test_10.jsonl
│   │       test_11.jsonl
│   │       test_12.jsonl
│   │       test_2.jsonl
│   │       test_3.jsonl
│   │       test_4.jsonl
│   │       test_5.jsonl
│   │       test_6.jsonl
│   │       test_7.jsonl
│   │       test_8.jsonl
│   │       test_9.jsonl
│   │       train.jsonl
│   │
│   ├───QQP
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_normal.jsonl
│   │       test_PAWS_QQP.jsonl
│   │       test_PAWS_WIKI.jsonl
│   │       train.jsonl
│   │
│   ├───SNLI
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       SNLI_dev_normal.jsonl
│   │       SNLI_metadata.pkl
│   │       SNLI_test_hard.jsonl
│   │       SNLI_test_normal.jsonl
│   │       SNLI_train.jsonl
│   │       test_break.jsonl
│   │       test_counterfactual.jsonl
│   │       test_hard.jsonl
│   │       test_normal.jsonl
│   │       train.jsonl
│   │
│   ├───SST2
│   │       dev_normal.jsonl
│   │       metadata.pkl
│   │       test_normal.jsonl
│   │       train.jsonl
│   │
│   └───SST5
│           dev_normal.jsonl
│           metadata.pkl
│           test_normal.jsonl
│           train.jsonl
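
As the tree above shows, pretrained GloVe 840B 300d embeddings are expected under embeddings/glove/. A hedged fetch sketch; the download URL is the standard Stanford NLP distribution and is not specified in this README:

```bash
# Hedged sketch: place glove.840B.300d.txt under embeddings/glove/ as the tree expects.
# The URL below is the standard Stanford NLP mirror, not given in this README.
mkdir -p embeddings/glove
wget -P embeddings/glove/ https://nlp.stanford.edu/data/glove.840B.300d.zip
unzip embeddings/glove/glove.840B.300d.zip -d embeddings/glove/
```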

How to train

Train: python train.py --model=[insert model name] --dataset=[insert dataset name] --times=[insert total runs] --device=[insert device name] --model_type=[classifier/sentence_pair/sentence_pair2]

  • Check argparser.py for exact options.
  • sentence_pair2 is used as the model_type for sequence-interaction models on sequence-matching tasks (NLI, paraphrase detection); otherwise, sentence_pair is used for those tasks (if a paper does not explicitly mention sequence interaction for a model, that model is not a sequence-interaction model).
  • Generally we set the total number of runs (--times) to 3; for LRA we use 2. An example invocation is sketched after this list.
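
An illustrative invocation. This is a sketch only: the model name comes from the nomenclature list below, while the dataset name and device string are assumptions; check argparser.py for the exact accepted values.

```bash
# Illustrative only: the dataset and device values are assumptions;
# see argparser.py for the exact option names and accepted values.
python train.py --model=EBT_GRC --dataset=listopsc --times=3 --device=cuda:0 --model_type=classifier
```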

Model Nomenclature

The nomenclature in the codebase and in the papers differs slightly. We provide a mapping here of the form [codebase model name] == [paper model name]:

  • CRvNN == CRvNN
  • CRvNN_nohalt == CRvNN (during stress test)
  • OM == OM
  • GT_GRC == GT-GRC
  • EGT_GRC == EGT-GRC
  • BT_GRC == BT-GRC
  • BT_GRC_OS == BT-GRC OS (also BT-GRC OneSoft)
  • EBT_GRC == EBT-GRC
  • EBT_GRC_noslice == EBT-GRC (-slice)
  • EBT_GRC512 == EBT-GRC (512)
  • EBT_GRC512_noslice == EBT-GRC (-slice,512)
  • GAU_IN == GAU
  • EGT_GAU_IN == EBT-GAU
  • EBT_GAU_IN == EBT-GAU
  • S4DStack == S4D
  • BalancedTreeGRC == BBT-GRC
  • HGRC == RIR-GRC
  • HCRvNN == RIR-CRvNN
  • HOM == RIR-OM
  • HEBT_GRC == RIR-EBT-GRC
  • HEBT_GRC_noSSM == RIR-EBT-GRC (-S4D)
  • HEBT_GRC_noRBA == RIR-EBT-GRC (-Beam Align)
  • HEBT_GRC_random == RIR-EBT-GRC (+Random Align)
  • HEBT_GRC_small == RIR-EBT-GRC (beam 5)
  • HEBT_GRC_chunk20 == RIR-EBT-GRC (chunk 20)
  • HEBT_GRC_chunk10 == RIR-EBT-GRC (chunk 10)
  • MEGA == MEGA

Citations:

@InProceedings{pmlr-v139-chowdhury21a,
  title = 	 {Modeling Hierarchical Structures with Continuous Recursive Neural Networks},
  author =       {Chowdhury, Jishnu Ray and Caragea, Cornelia},
  booktitle = 	 {Proceedings of the 38th International Conference on Machine Learning},
  pages = 	 {1975--1988},
  year = 	 {2021},
  editor = 	 {Meila, Marina and Zhang, Tong},
  volume = 	 {139},
  series = 	 {Proceedings of Machine Learning Research},
  month = 	 {18--24 Jul},
  publisher =    {PMLR},
  pdf = 	 {http://proceedings.mlr.press/v139/chowdhury21a/chowdhury21a.pdf},
  url = 	 {http://proceedings.mlr.press/v139/chowdhury21a.html},
  abstract = 	 {Recursive Neural Networks (RvNNs), which compose sequences according to their underlying hierarchical syntactic structure, have performed well in several natural language processing tasks compared to similar models without structural biases. However, traditional RvNNs are incapable of inducing the latent structure in a plain text sequence on their own. Several extensions have been proposed to overcome this limitation. Nevertheless, these extensions tend to rely on surrogate gradients or reinforcement learning at the cost of higher bias or variance. In this work, we propose Continuous Recursive Neural Network (CRvNN) as a backpropagation-friendly alternative to address the aforementioned limitations. This is done by incorporating a continuous relaxation to the induced structure. We demonstrate that CRvNN achieves strong performance in challenging synthetic tasks such as logical inference (Bowman et al., 2015b) and ListOps (Nangia & Bowman, 2018). We also show that CRvNN performs comparably or better than prior latent structure models on real-world tasks such as sentiment analysis and natural language inference.}
}
@InProceedings{Chowdhury2023beam,
  title = 	 {Beam Tree Recursive Cells},
  author =       {Ray Chowdhury, Jishnu and Caragea, Cornelia},
  booktitle = 	 {Proceedings of the 40th International Conference on Machine Learning},
  year = 	 {2023}
}
@inproceedings{chowdhury2023efficient,
  title = {Efficient Beam Tree Recursion},
  author = {Jishnu Ray Chowdhury and Cornelia Caragea},
  booktitle = {Thirty-seventh Conference on Neural Information Processing Systems},
  year = {2023},
  url = {https://openreview.net/forum?id=PR5znB6BZ2}
}
@inproceedings{chowdhury2023recursion,
  title = {Recursion in Recursion: Two-Level Nested Recursion for Length Generalization with Scalability},
  author = {Jishnu Ray Chowdhury and Cornelia Caragea},
  booktitle = {Thirty-seventh Conference on Neural Information Processing Systems},
  year = {2023},
  url = {https://openreview.net/forum?id=o6yTKfdnbA}
}

For any issues, contact: jishnu.ray.c@gmail.com