SentBS: Sentence-level Beam Search for Controllable Summarization

This repository contains code and related resources for our paper "SentBS: Sentence-level Beam Search for Controllable Summarization".


If you find our paper and resources useful, please kindly leave a star and cite our papers. Thanks!

@article{shen2022sentbs,
  title={SentBS: Sentence-level Beam Search for Controllable Summarization},
  author={Shen, Chenhui and Cheng, Liying and Bing, Lidong and You, Yang and Si, Luo},
  journal={EMNLP 2022},
  year={2022}
}

@article{shen2022mred,
  title={MReD: A Meta-Review Dataset for Structure-Controllable Text Generation},
  author={Shen, Chenhui and Cheng, Liying and Zhou, Ran and Bing, Lidong and You, Yang and Si, Luo},
  journal={Findings of ACL},
  year={2022}
}

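Catalogue:

  • 1. Introduction
  • 2. Running our Code
      • 2.1. Pre-requisites
      • 2.2. Commands to reproduce our results
          • 2.2.1. Reproduce Sent-Ctrl (Table 1 upper section)
          • 2.2.2. Train Classifier
          • 2.2.3. Reproduce Sent-Ctrl + SentBS (Table 1 upper section)
          • 2.2.4. Reproduce Seg-Ctrl and Seg-Ctrl + SentBS (Table 1 bottom section)
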
1. Introduction:

A wide range of control perspectives has been explored in controllable text generation. Structure-controlled summarization has recently been proposed as a useful and interesting research direction. However, current structure-controlling methods have limited effectiveness in enforcing the desired structure. To address this limitation, we propose a sentence-level beam search generation method (SentBS), where evaluation is conducted throughout the generation process to select suitable sentences for subsequent generations. We experiment with different combinations of decoding methods to be used as subcomponents by SentBS and evaluate results on the structure-controlled dataset MReD. Experiments show that all explored combinations for SentBS improve the agreement between the generated text and the desired structure, with the best method reducing the structural discrepancies suffered by the existing model by approximately 68%.
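
To make the idea concrete, here is a minimal, self-contained sketch of sentence-level beam search. It is not the code in beam_search_sent.py: the candidate generator propose and the classifier label_prob are hypothetical stand-ins for the BART generator and the sentence classifier, and scoring each option as its log-likelihood plus a weighted log classifier probability is an assumption made for illustration.

```python
import math
from dataclasses import dataclass, field
from typing import Callable, List, Sequence, Tuple

# Hypothetical interfaces (assumptions, not the repository's API):
#   propose(prefix, label, k) -> k (sentence, log_likelihood) options for the next sentence
#   label_prob(sentence, label) -> classifier probability that `sentence` has category `label`
Proposer = Callable[[List[str], str, int], List[Tuple[str, float]]]
LabelProb = Callable[[str, str], float]


@dataclass
class Hypothesis:
    sentences: List[str] = field(default_factory=list)
    score: float = 0.0  # cumulative score of all sentences chosen so far


def sentence_level_beam_search(control: Sequence[str], propose: Proposer, label_prob: LabelProb,
                               k: int = 8, beam_size: int = 4, weight: float = 1.0) -> Hypothesis:
    beams = [Hypothesis()]
    for label in control:  # one sentence per label in the desired structure
        candidates = []
        for hyp in beams:
            for sentence, loglik in propose(hyp.sentences, label, k):
                # Assumed scoring: model likelihood plus weighted log classifier agreement.
                step = loglik + weight * math.log(max(label_prob(sentence, label), 1e-12))
                candidates.append(Hypothesis(hyp.sentences + [sentence], hyp.score + step))
        # Evaluation during generation: keep only the best partial summaries.
        candidates.sort(key=lambda h: h.score, reverse=True)
        beams = candidates[:beam_size]
    return beams[0]


# Toy stand-ins: even-indexed options follow the requested label, odd ones drift to "misc".
def toy_propose(prefix: List[str], label: str, k: int) -> List[Tuple[str, float]]:
    return [(f"{label if i % 2 == 0 else 'misc'} sentence {len(prefix)}.{i}", -1.0 - 0.1 * i)
            for i in range(k)]


def toy_label_prob(sentence: str, label: str) -> float:
    return 0.9 if sentence.startswith(label) else 0.1


best = sentence_level_beam_search(["abstract", "strength", "decision"],
                                  toy_propose, toy_label_prob, k=4, beam_size=2)
print(best.sentences)
# → ['abstract sentence 0.0', 'strength sentence 1.0', 'decision sentence 2.0']
```

Because the classifier term enters the score at every sentence, off-structure options (the "misc" ones in the toy) are pruned as soon as they appear, rather than only being detected after the whole summary has been generated.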


2. Running our Code

2.1. Pre-requisites:

For our code, we use Hugging Face Transformers version 4.16.2. To install a specific version of Transformers, see the instructions here.
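
If you want to confirm the pinned version before running anything, a small standard-library check like the following should work on Python 3.8+ (which provides importlib.metadata). The function name transformers_version_ok is our own and is not part of this repository.

```python
from importlib.metadata import PackageNotFoundError, version

REQUIRED_TRANSFORMERS = "4.16.2"  # version used for the experiments in this README


def transformers_version_ok(required: str = REQUIRED_TRANSFORMERS) -> bool:
    """Return True only if the installed `transformers` package is exactly `required`."""
    try:
        installed = version("transformers")
    except PackageNotFoundError:  # transformers is not installed at all
        return False
    return installed == required


if __name__ == "__main__":
    if transformers_version_ok():
        print(f"transformers {REQUIRED_TRANSFORMERS} found")
    else:
        print(f"Please run: pip install transformers=={REQUIRED_TRANSFORMERS}")
```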

For the BERTScore metric, we use the implementation from this repository.

We use the public checkpoint of BART-Large fine-tuned on CNN/DM (facebook/bart-large-cnn) as our base architecture.

2.2. Commands to reproduce our results:

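For all experiments below, please download our processed data from here. Unzip the downloaded archive and place all data folders under a folder named data/ in the repository root, since the commands below use relative paths such as data/original_clean/test_rate_concat_sent-ctrl.csv.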
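
Before launching any run, you can check that the unzipped data sits where the commands expect it. The sketch below only knows the six CSV files referenced by the commands in this README (the archive may contain more), and missing_data_files is a hypothetical helper name.

```python
from pathlib import Path
from typing import List

# File names copied from the commands in this README; other files may also be present.
EXPECTED_FILES = [
    "original_clean/train_rate_concat_sent-ctrl.csv",
    "original_clean/val_rate_concat_sent-ctrl.csv",
    "original_clean/test_rate_concat_sent-ctrl.csv",
    "original_seg_clean/train.csv",
    "original_seg_clean/val.csv",
    "original_seg_clean/test.csv",
]


def missing_data_files(data_root: str = "data") -> List[str]:
    """Return the expected files (relative to data_root) that do not exist."""
    root = Path(data_root)
    return [rel for rel in EXPECTED_FILES if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = missing_data_files()
    if missing:
        print("Missing under data/:", *missing, sep="\n  ")
    else:
        print("All data files used by the reproduction commands were found.")
```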

2.2.1. Reproduce Sent-Ctrl (Table 1 upper section):

We include the reformatted data used for our experiments. The original data can also be obtained from here.

To reproduce the sent-ctrl baseline, run:

CUDA_VISIBLE_DEVICES=0 python ctrl_transformer.py --model_name_or_path facebook/bart-large-cnn --do_train --do_eval --do_predict --train_file data/original_clean/train_rate_concat_sent-ctrl.csv --validation_file data/original_clean/val_rate_concat_sent-ctrl.csv --test_file data/original_clean/test_rate_concat_sent-ctrl.csv --output_dir ./results/sentctrl_reproduced  --seed 0 --save_total_limit 3 --gen_target_max 800 --predict_with_generate --eval_steps 500 --max_source_length 2048

2.2.2. Train Classifier:

For the MReD dataset, we additionally train a sentence classifier so that, during generation, sentence options are selected based on both the category classification score and the sequence likelihood.
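
This README does not spell out how the two scores are combined. The sketch below assumes one simple scheme (length-normalized token log-likelihood plus a weighted log-probability of the requested category); option_score and pick_best are hypothetical names, and the actual weighting in beam_search_sent.py may differ.

```python
import math
from typing import List, Sequence, Tuple

# (sentence, per-token log-probs from the generator, classifier probabilities per category)
Option = Tuple[str, Sequence[float], Sequence[float]]


def option_score(token_logprobs: Sequence[float], class_probs: Sequence[float],
                 target: int, weight: float = 1.0) -> float:
    # Length-normalized sequence log-likelihood (assumption: raw sums favour short sentences).
    likelihood = sum(token_logprobs) / max(len(token_logprobs), 1)
    # Log-probability that the sentence belongs to the category requested at this position.
    agreement = math.log(max(class_probs[target], 1e-12))
    return likelihood + weight * agreement


def pick_best(options: List[Option], target: int, weight: float = 1.0) -> str:
    best = max(options, key=lambda o: option_score(o[1], o[2], target, weight))
    return best[0]


options = [
    ("fluent but off-category", [-0.2, -0.3], [0.1, 0.9]),
    ("less fluent, on-category", [-0.5, -0.6], [0.95, 0.05]),
]
print(pick_best(options, target=0))              # → less fluent, on-category
print(pick_best(options, target=0, weight=0.0))  # → fluent but off-category
```

With weight=0.0 the selection falls back to likelihood alone and picks the more fluent but off-structure sentence, which is the failure mode the classifier score is meant to counter.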

The classifier is trained on the LSTM-labelled training data split.

The classifier uses the Hugging Face RoBERTa-Large model (roberta-large) as its base architecture. To train it, run:

CUDA_VISIBLE_DEVICES=0 python train_sent_classifier.py --model_path roberta-large

2.2.3. Reproduce Sent-Ctrl + SentBS (Table 1 upper section):

For the following commands, you can adjust the value of k (the number of sentence options generated at each step) with the flag --gen_size.

  • For nucleus sampling:
CUDA_VISIBLE_DEVICES=0 python beam_search_sent.py --gen_size 8 --beam_size 4 --top_p 0.9 --res_dir results/sampling --generation_model_path results/sentctrl_reproduced --test_file data/original_clean/test_rate_concat_sent-ctrl.csv --gen_mode sample --write --eval_rouge --load_classifier --classification_model_path <path_to_classification_model>
  • For beam sampling:
CUDA_VISIBLE_DEVICES=0 python beam_search_sent.py --gen_size 8 --beam_size 4 --top_p 0.9 --res_dir results/beam_sampling --generation_model_path results/sentctrl_reproduced --test_file data/original_clean/test_rate_concat_sent-ctrl.csv --gen_mode beam_sample --write --eval_rouge --load_classifier --classification_model_path <path_to_classification_model>
  • For beam search + nucleus sampling:
CUDA_VISIBLE_DEVICES=0 python beam_search_sent.py --gen_size 8 --beam_size 4 --top_p 0.9 --res_dir results/mixed_bs_ns --generation_model_path results/sentctrl_reproduced --test_file data/original_clean/test_rate_concat_sent-ctrl.csv --gen_mode beam_search_sent --write --eval_rouge --load_classifier --classification_model_path <path_to_classification_model>
  • For beam search + beam sampling + nucleus sampling:
CUDA_VISIBLE_DEVICES=0 python beam_search_sent.py --gen_size 8 --beam_size 4 --top_p 0.9 --res_dir results/mixed_all --generation_model_path results/sentctrl_reproduced --test_file data/original_clean/test_rate_concat_sent-ctrl.csv --gen_mode beam_search_sent  --beam_sample --write --eval_rouge --load_classifier --num_beam_sample_gen 4 --classification_model_path <path_to_classification_model>

You can use the flag --num_beam_sample_gen to control the number of sentences generated by beam sampling.
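
The four Sent-Ctrl + SentBS commands share most of their flags and differ only in --res_dir, --gen_mode and the beam-sampling flags. The sketch below rebuilds them from one table so that, for example, sweeping --gen_size is a one-line change. COMMON, VARIANTS and build_command are our own names; the flag values are copied from the commands above, and the script only prints the commands rather than running them.

```python
import shlex
from typing import Dict, List

# Flags shared by the four Sent-Ctrl + SentBS commands (values copied from them).
COMMON = [
    "--gen_size", "8", "--beam_size", "4", "--top_p", "0.9",
    "--generation_model_path", "results/sentctrl_reproduced",
    "--test_file", "data/original_clean/test_rate_concat_sent-ctrl.csv",
    "--write", "--eval_rouge", "--load_classifier",
]

# Per-configuration flags; each key doubles as the results/ sub-folder name.
VARIANTS: Dict[str, List[str]] = {
    "sampling": ["--gen_mode", "sample"],
    "beam_sampling": ["--gen_mode", "beam_sample"],
    "mixed_bs_ns": ["--gen_mode", "beam_search_sent"],
    "mixed_all": ["--gen_mode", "beam_search_sent", "--beam_sample", "--num_beam_sample_gen", "4"],
}


def build_command(variant: str, classifier_path: str, gen_size: int = 8) -> List[str]:
    args = ["python", "beam_search_sent.py", *COMMON,
            "--res_dir", f"results/{variant}", *VARIANTS[variant],
            "--classification_model_path", classifier_path]
    args[args.index("--gen_size") + 1] = str(gen_size)  # --gen_size sets the k value
    return args


for name in VARIANTS:
    cmd = build_command(name, "<path_to_classification_model>")
    print("CUDA_VISIBLE_DEVICES=0", shlex.join(cmd))
```

Keeping the shared flags in one list makes it harder for the four runs to drift apart by accident (for instance, one run silently using a different test file).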

2.2.4. Reproduce Seg-Ctrl and Seg-Ctrl + SentBS (Table 1 bottom section):

To reproduce the seg-ctrl baseline, run:

CUDA_VISIBLE_DEVICES=0 python ctrl_transformer.py --model_name_or_path facebook/bart-large-cnn --do_train --do_eval --do_predict --train_file data/original_seg_clean/train.csv --validation_file data/original_seg_clean/val.csv --test_file data/original_seg_clean/test.csv --output_dir results/segctrl_reproduced  --seed 0 --save_total_limit 3 --gen_target_max 800 --predict_with_generate --eval_steps 500 --max_source_length 2048

For Seg-Ctrl + SentBS, run:

CUDA_VISIBLE_DEVICES=0 python segctrl_sentbs.py --res_dir results/segctrl_sentbs --generation_model_path results/segctrl_reproduced --test_file data/original_seg_clean/test.csv --gen_mode beam_search_sent --load_classifier --classification_model_path <path_to_classification_model> --gen_size 8 --beam_size 4 --beam_sample --eval_rouge --run_num 0 --write