
IEEE AIKE'2020: Knowledge Distillation on Extractive Summarization


DistilExt

Python version: This code is written in Python 3.6

Package Requirements: torch==1.5.0 pytorch_transformers tensorboardX multiprocess pyrouge
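
If the packages are not already available, they can be installed with pip, for example:

pip install torch==1.5.0 pytorch_transformers tensorboardX multiprocess pyrouge

Note that pyrouge is only a wrapper; computing ROUGE scores additionally requires a local ROUGE-1.5.5 installation.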

Some code is borrowed from ONMT (https://github.com/OpenNMT/OpenNMT-py)

Trained Teacher Models

CNN/DM BertExt

XSum BertExt

Trained Student Models

CNN/DM DistilExt (8-layer Transformer)

XSum DistilExt (6-layer Transformer)

Data Preparation

For the data preprocessing steps, please see PreSumm (https://github.com/nlpyang/PreSumm) for more information.
We also provide our pre-processed data via the links below.

CNN/DailyMail

CNN/DM

CNN/DM (soft_targets)

XSum

XSum

XSum (soft_targets)

Model Training

First run: For the first time, use a single GPU so the code can download the BERT model. Set -visible_gpus -1; after the download completes, you can kill the process and rerun the code with multiple GPUs. The scripts below are located in the src folder.
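
A minimal sketch of this first-run workflow (it assumes the *_train.sh scripts forward -visible_gpus to the underlying trainer, as in PreSumm):

cd src
# first run on a single GPU so the BERT weights can be downloaded:
# edit the training script to pass -visible_gpus -1, then launch it
bash cnndm_train.sh
# once the download finishes, kill the process and rerun with all GPUs,
# e.g. -visible_gpus 0,1,2,3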

CNN/DM

To train a student

bash cnndm_train.sh

XSum

To train a teacher

bash xsum_train.sh

To train a student

bash xsum_train_stu.sh

Model Evaluation

CNN/DM

# this shell script validates every model checkpoint saved during training
bash cnndm_val.sh

XSum

# this shell script validates every model checkpoint saved during training
bash xsum_val.sh

Testing a single checkpoint

CNN/DM

bash cnndm_test_single.sh

XSum

# test teacher
bash xsum_test_teacher.sh
# test student
bash xsum_test_single.sh
  • -mode can be {validate, test}: validate inspects the model directory and evaluates each newly saved checkpoint; test must be used with -test_from, which indicates the checkpoint you want to use
  • MODEL_PATH is the directory of saved checkpoints
  • if you use -mode validate with -test_all, the system will load all saved checkpoints and select the top ones to generate summaries (this will take a while)
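
Putting these options together, validation and single-checkpoint testing reduce to flag combinations like the following. This is only a sketch: the train.py entry point, the -model_path flag, and the checkpoint filename pattern are assumptions carried over from the PreSumm codebase this repository builds on, and the data-path and batch-size flags are omitted; see the shell scripts above for the exact invocations.

MODEL_PATH=../models/xsum_student                  # placeholder checkpoint directory
# validate: evaluate every checkpoint saved in MODEL_PATH and keep the top ones (slow)
python train.py -mode validate -test_all -model_path $MODEL_PATH
# test: evaluate one specific checkpoint (filename pattern assumed from PreSumm)
python train.py -mode test -test_from $MODEL_PATH/model_step_30000.pt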