
Code for "Democratizing Reasoning Ability: Tailored Learning from Large Language Model", EMNLP 2023


Democratizing Reasoning Ability

Code for the paper "Democratizing Reasoning Ability: Tailored Learning from Large Language Model", accepted at EMNLP 2023.



Requirements

  1. Python 3.8.16
  2. PyTorch 1.13.1
  3. transformers 4.28.1
  4. accelerate 0.18.0
  5. datasets 2.10.1
  6. DeepSpeed 0.9.1
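A quick way to confirm that an environment matches these pins is to compare them with what pip has installed. The following is a minimal sketch that is not part of the repository; it assumes the pip distribution names torch, transformers, accelerate, datasets and deepspeed, and it ignores local build suffixes such as +cu117.

```python
"""Compare installed package versions with the versions pinned in this README."""
import sys
from importlib import metadata  # available from Python 3.8

# Pinned versions; keys are assumed pip distribution names ("torch" is PyTorch).
PINNED = {
    "torch": "1.13.1",
    "transformers": "4.28.1",
    "accelerate": "0.18.0",
    "datasets": "2.10.1",
    "deepspeed": "0.9.1",
}


def find_mismatches(pinned, lookup=metadata.version):
    """Return {name: (expected, installed)} for packages that differ.

    `installed` is None when the package is missing. Local build suffixes
    such as "+cu117" are stripped before comparing.
    """
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = lookup(name).split("+")[0]
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    print("python", sys.version.split()[0], "(README pins 3.8.16)")
    for name, (expected, installed) in find_mismatches(PINNED).items():
        print(f"{name}: expected {expected}, found {installed or 'not installed'}")
```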


Training

Run the following command to train the student LM, GPT-J-6B, on the GSM8K dataset for the initial round of learning. We have released the collected rationales on the GitHub release page for research use. If your device supports bf16, replace --fp16 with --bf16 for more stable training.

python run_trainer.py \
    --device "0,1,2,3,4,5,6,7" \
    --task "gsm8k" \
    --load "EleutherAI/gpt-j-6B" \
    --save_best '0' \
    --epoch "10" \
    --teacher_data "localdataset/gsm8k/gsm8k.train.round1.json" \
    --fewshot "yes" \
    --train_bsz "2" \
    --fp16 "yes" \
    --contrastive "yes" \
    --cl_pos_path "localdataset/gsm8k/gsm8k.train.round1.json" \
    --cl_neg_path "localdataset/gsm8k/gsm8k.student.train.round0.neg4.json" \
    --cl_ratio "0.5" \
    --merge_losses "yes" \
    --do_eval "yes" \
    --weight_decay "0.01" \
    --lr "7e-6" \
    --deepspeed "ds_stage3_config.json"
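Whether bf16 is available can be checked before launching the command above. The sketch below is a hypothetical helper, not part of the repository: it relies on torch.cuda.is_bf16_supported(), which inspects the current CUDA device, and it assumes that --bf16 takes the same "yes" value as --fp16, which you should confirm in args.py.

```python
"""Choose the mixed-precision flag for run_trainer.py (hypothetical helper)."""
import shlex


def precision_args(bf16_supported):
    # ASSUMPTION: --bf16 accepts "yes" just like --fp16 does in the README command.
    return ["--bf16", "yes"] if bf16_supported else ["--fp16", "yes"]


def detect_bf16():
    """True if torch is installed, CUDA is available and the current device supports bf16."""
    try:
        import torch
    except ImportError:
        return False
    return bool(torch.cuda.is_available() and torch.cuda.is_bf16_supported())


if __name__ == "__main__":
    # Only a few flags are shown; append the remaining ones from the README command.
    cmd = ["python", "run_trainer.py", "--task", "gsm8k", "--load", "EleutherAI/gpt-j-6B"]
    cmd += precision_args(detect_bf16())
    print(shlex.join(cmd))
```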


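Evaluation

Evaluate the trained student on the GSM8K test set. Here $(seq -s , 0 7) expands to 0,1,2,3,4,5,6,7, and saved_ckpt_path should be replaced with the checkpoint saved during training.

python infer.py --task gsm8k --test_on test --device $(seq -s , 0 7) --fewshot --rationale --checkpoint checkpoints/student_train/saved_ckpt_path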
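For a rough sanity check of generated rationales outside infer.py: GSM8K reference solutions end with "#### <answer>", so a prediction can be scored by comparing that value with the last number in the student's rationale. This is an illustrative heuristic with assumed inputs (plain lists of rationale and solution strings); it is not necessarily how infer.py computes accuracy.

```python
"""Heuristic GSM8K scoring: last number in the rationale vs. the '####' answer."""
import re

# Integers or decimals, optionally negative, with thousands separators.
NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def gold_answer(solution):
    """GSM8K reference solutions end with '#### <answer>'."""
    return solution.split("####")[-1].strip().replace(",", "")


def predicted_answer(rationale):
    """Take the last number in the generated rationale, or None if there is none."""
    numbers = NUMBER.findall(rationale)
    return numbers[-1].replace(",", "") if numbers else None


def accuracy(rationales, solutions):
    """Fraction of rationales whose final number equals the gold answer."""
    if not solutions:
        return 0.0
    hits = 0
    for rationale, solution in zip(rationales, solutions):
        pred = predicted_answer(rationale)
        if pred is not None and float(pred) == float(gold_answer(solution)):
            hits += 1
    return hits / len(solutions)
```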

For the Next Round

First, let the student take an "exam" on the training set to collect its own mistakes.

python infer_student_wrong.py \
    --task gsm8k \
    --rationale --fewshot \
    --batch 4 --test_on train \
    --note "round1" --device $(seq -s , 0 7) \
    --checkpoint checkpoints/student_train/saved_ckpt_path 
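Conceptually, this step keeps the training questions that the student answers incorrectly. The sketch below illustrates the idea with a hypothetical record format ('question', 'rationale', 'answer'); the actual output format of infer_student_wrong.py may differ, so check the script before relying on these field names.

```python
"""Keep the training examples the student got wrong (illustrative sketch)."""
import json
import re


def last_number(text):
    """Last integer/decimal in the text with commas removed, or None."""
    nums = re.findall(r"-?\d[\d,]*(?:\.\d+)?", text)
    return nums[-1].replace(",", "") if nums else None


def collect_mistakes(records):
    """Return records whose rationale's final number differs from the gold answer.

    HYPOTHETICAL record fields: 'question', 'rationale' (student output) and
    'answer' (GSM8K solution ending in '#### <n>').
    """
    wrong = []
    for rec in records:
        gold = rec["answer"].split("####")[-1].strip().replace(",", "")
        pred = last_number(rec["rationale"])
        if pred is None or float(pred) != float(gold):
            wrong.append(rec)
    return wrong


if __name__ == "__main__":
    sample = [
        {"question": "q1", "rationale": "The answer is 7.", "answer": "#### 7"},
        {"question": "q2", "rationale": "The answer is 12.", "answer": "#### 10"},
    ]
    print(json.dumps(collect_mistakes(sample), indent=2))
```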

Then use the prompt template described in the paper to collect the teacher's (ChatGPT's) feedback on these mistakes. To start the next round of training, replace the values of --teacher_data, --cl_pos_path and --cl_neg_path in the training command above, and set --checkpoint path_to_ckpt_of_round1. Please refer to args.py for detailed usage.
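Switching to the next round amounts to swapping a few flag values in the round-1 training command. A minimal sketch, assuming every flag is followed by exactly one value; the round-2 paths below are hypothetical placeholders for the files you generate, not names used by the repository.

```python
"""Derive the next-round training command from the round-1 one."""
import shlex


def override_flags(argv, updates):
    """Return a copy of argv with the value after each flag in `updates` replaced.

    Flags missing from argv are appended. Assumes every flag is followed by
    exactly one value, as in the README's training command.
    """
    new_argv = list(argv)
    for flag, value in updates.items():
        if flag in new_argv:
            new_argv[new_argv.index(flag) + 1] = value
        else:
            new_argv += [flag, value]
    return new_argv


# Subset of the round-1 command from this README.
ROUND1 = [
    "python", "run_trainer.py", "--task", "gsm8k",
    "--teacher_data", "localdataset/gsm8k/gsm8k.train.round1.json",
    "--cl_pos_path", "localdataset/gsm8k/gsm8k.train.round1.json",
    "--cl_neg_path", "localdataset/gsm8k/gsm8k.student.train.round0.neg4.json",
]

# HYPOTHETICAL paths: replace them with the files you actually collected.
ROUND2_UPDATES = {
    "--teacher_data": "path/to/round2_teacher_data.json",
    "--cl_pos_path": "path/to/round2_pos.json",
    "--cl_neg_path": "path/to/round2_neg.json",
    "--checkpoint": "path_to_ckpt_of_round1",
}

if __name__ == "__main__":
    print(shlex.join(override_flags(ROUND1, ROUND2_UPDATES)))
```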


