Official implementation of Selective Token Generation for Few-shot Natural Language Generation (COLING'22)
- torch == 1.8.1
- nltk == 3.5
- rouge-score == 0.0.4
- transformers == 4.3.2
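- for example, assuming a Python 3 environment with pip available, the pinned versions can be installed with:
> pip install torch==1.8.1 nltk==3.5 rouge-score==0.0.4 transformers==4.3.2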
- clone the following repository into qa/src:
> git clone https://github.com/microsoft/MSMARCO-Question-Answering qa/src
- download MSMARCO QA dataset from https://microsoft.github.io/msmarco/
- preprocess the data using this script
- change root_dir of QADataset to the corresponding directory
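- a possible layout, assuming a hypothetical qa/data directory (any path works as long as root_dir of QADataset points to the preprocessed files):
> mkdir -p qa/data  # example location for the downloaded and preprocessed MSMARCO QA files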
- clone the following repository into summ/src:
> git clone https://github.com/SKRohit/Generating_Text_Summary_With_GPT2 summ/src
- download and preprocess the dataset following the Dataset Preparation section of that repo
- change root_dir of SummDataset to the corresponding directory
- specify -e [experiment_name] in the instructions below (e.g. -e qa)
- qa (Question Answering)
- summ (Summarization)
- specify -d [domain_name] in the instructions below (e.g. -e qa -d 1)
- qa: 1, 05, 01, 001, 005 (2,000, 1,000, 500, 100, and 50 shot respectively)
- summ: CNN, CNN05, CNN01, CNN003, CNN001 (3,000, 1,500, 300, 100, and 50 shot respectively)
- fine-tune GPT
> python [experiment_name]/train.py --domain [domain_name] --seed [seed_number]
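- for example, to fine-tune on the 2,000-shot QA split (the seed value here is arbitrary):
> python qa/train.py --domain 1 --seed 42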
- train
- NonSTG-MLE
> python -m torch.distributed.launch --nproc_per_node=1 --master_port [PORT] train.py --world_size 1 --num_workers 2 -e [experiment_name] -d [domain_name] --seed [seed_number] -m ftg --obj mle
- NonSTG-RL
> python -m torch.distributed.launch --nproc_per_node=1 --master_port [PORT] train.py --world_size 1 --num_workers 2 -e [experiment_name] -d [domain_name] --seed [seed_number] -m ftg --obj rl
- STG
> python -m torch.distributed.launch --nproc_per_node=1 --master_port [PORT] train.py --world_size 1 --num_workers 2 -e [experiment_name] -d [domain_name] --seed [seed_number] -m stg --obj rl
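- for example, to train STG on the 2,000-shot QA split (the port and seed values here are arbitrary):
> python -m torch.distributed.launch --nproc_per_node=1 --master_port 29500 train.py --world_size 1 --num_workers 2 -e qa -d 1 --seed 42 -m stg --obj rl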
- generation & evaluation
- PLM
> python -m torch.distributed.launch --nproc_per_node=8 --master_port [PORT] eval_[experiment_name].py --world_size 8 --num_workers 8 -m ft -d [domain_name] --seed [seed_number]
- NonSTGs
> (QA) python -m torch.distributed.launch --nproc_per_node=8 --master_port [PORT] eval_qa.py --world_size 8 --num_workers 8 -m rl -cp [checkpoint_path] --score sample -n 3
> (summ) python -m torch.distributed.launch --nproc_per_node=8 --master_port [PORT] eval_summ.py --world_size 8 --num_workers 8 -m rl -cp [checkpoint_path] --score beam -n 3
- STG
> (QA) python -m torch.distributed.launch --nproc_per_node=8 --master_port [PORT] eval_qa.py --world_size 8 --num_workers 8 -m rl -cp [checkpoint_path] --score sample -n 3 --scheme cat
> (summ) python -m torch.distributed.launch --nproc_per_node=8 --master_port [PORT] eval_summ.py --world_size 8 --num_workers 8 -m rl -cp [checkpoint_path] --score beam -n 3
- Naive Ensemble (available for NonSTG)
- specify --inj_scheme [max, mix, random] in the NonSTG evaluation commands above (see the example below)
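- for example, a max-ensemble evaluation on QA might look like the following (the port value is arbitrary; [checkpoint_path] should point to a NonSTG checkpoint):
> python -m torch.distributed.launch --nproc_per_node=8 --master_port 29500 eval_qa.py --world_size 8 --num_workers 8 -m rl -cp [checkpoint_path] --score sample -n 3 --inj_scheme max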
@inproceedings{jo-etal-2022-stg,
title = "Selective Token Generation for Few-shot Natural Language Generation",
author = "Jo, Daejin and Kwon, Taehwan and Kim, Eun-Sol and Kim, Sungwoong",
booktitle = "Proceedings of the 29th International Conference on Computational Linguistics",
publisher = "International Committee on Computational Linguistics",
url = "https://aclanthology.org/2022.coling-1.510",
pages = "5837--5856"
}
If you have any questions, feel free to contact me via email.