Abstractive text summarization using BERT

This repository implements abstractive text summarization, one of the NLP (Natural Language Processing) tasks, with a model built on BERT (see the paper "Pretraining-Based Natural Language Generation for Text Summarization").

Requirements

  • Python 3.6.5+
  • PyTorch 0.4.1+
  • TensorFlow
  • Pandas
  • tqdm
  • NumPy
  • MeCab
  • TensorboardX, and others

All packages used here can be installed with pip as follows:

pip install -r requirement.txt


If you train the model on a GPU, the easiest setup is a PyTorch Docker image from Docker Hub.

In this study, pytorch/pytorch:0.4.1-cuda9-cudnn7-devel (2.62 GB) was used.

Before using

Before using this code, follow the steps below.

  1. Create a directory named "data/checkpoint" under the workspace root and put the BERT model file, the vocabulary file and the BERT config file in it. These files can be downloaded here.

  2. Put the data file for training and validation under /workspace/data/. The format is as follows:

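import torch

# opt: the preprocessing options (e.g. an argparse namespace)
# text2token: the vocabulary, mapping each token to its index
# content / summary: tokenized source texts and their reference summaries,
#                    aligned by position
data = {
    'settings': opt,
    'dict': {
        'src': text2token,
        'tgt': text2token},          # source and target share one vocabulary
    'train': {                       # the first 100,000 pairs are for training
        'src': content[:100000],
        'tgt': summary[:100000]},
    'valid': {                       # the remaining pairs are for validation
        'src': content[100000:],
        'tgt': summary[100000:]}}
torch.save(data, opt.save_data)      # e.g. data/preprocessed_data.data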
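Before running train.py, it can help to sanity-check the saved dictionary. The function below is a minimal sketch that only checks the layout described above (the 'settings', 'dict', 'train' and 'valid' keys, and that every source has a matching summary). The name validate_dataset and the toy vocabulary are hypothetical and not part of this repository. Because it is pure Python, it can be applied to the object returned by torch.load(...).

```python
def validate_dataset(data):
    """Check that `data` follows the train/valid layout described above.

    Hypothetical helper (not part of this repository): raises ValueError on
    the first problem found and returns the number of pairs per split.
    """
    for key in ("settings", "dict", "train", "valid"):
        if key not in data:
            raise ValueError(f"missing top-level key: {key!r}")
    for split in ("dict", "train", "valid"):
        for side in ("src", "tgt"):
            if side not in data[split]:
                raise ValueError(f"missing key {side!r} in {split!r}")
    counts = {}
    for split in ("train", "valid"):
        src, tgt = data[split]["src"], data[split]["tgt"]
        # Each source text must be paired with exactly one reference summary.
        if len(src) != len(tgt):
            raise ValueError(
                f"{split}: {len(src)} sources but {len(tgt)} summaries")
        counts[split] = len(src)
    return counts


# Toy example with two pairs: the first goes to train, the second to valid.
vocab = {"[PAD]": 0, "hello": 1, "world": 2}   # hypothetical vocabulary
content = [[1, 2], [2, 1]]
summary = [[1], [2]]
toy = {
    "settings": None,
    "dict": {"src": vocab, "tgt": vocab},
    "train": {"src": content[:1], "tgt": summary[:1]},
    "valid": {"src": content[1:], "tgt": summary[1:]},
}
print(validate_dataset(toy))  # {'train': 1, 'valid': 1}
```

With the real data, the same check would be run as validate_dataset(torch.load("data/preprocessed_data.data")).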

The overall directory structure is as follows:

`-- data                        # under workspace 
    |-- checkpoint
    |   |-- bert_config.json    # BERT config file
    |   |-- pytorch_model.bin   # BERT model file
    |   `-- vocab.txt           # vocabulary file
    `-- preprocessed_data.data  # train and valid data file
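A quick way to confirm that the files are in place is to compare the data directory against the tree above. The sketch below assumes it is run from the workspace directory; the names EXPECTED_FILES and missing_files are hypothetical and not part of this repository.

```python
from pathlib import Path

# Files expected under data/, following the directory tree above.
EXPECTED_FILES = [
    "checkpoint/bert_config.json",
    "checkpoint/pytorch_model.bin",
    "checkpoint/vocab.txt",
    "preprocessed_data.data",
]


def missing_files(data_dir):
    """Return the expected files that do not exist under `data_dir`."""
    root = Path(data_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    # Assumes the current working directory is the workspace root.
    missing = missing_files("data")
    if missing:
        print("missing:", ", ".join(missing))
    else:
        print("data/ layout looks complete")
```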


Name                              Value
Encoder                           BERT
Decoder                           Transformer (decoder only)
Embedding dimension               768
Hidden (feed-forward) dimension   3072
Encoder layers                    12
Decoder layers                    8
Optimizer                         Adam
Initial learning rate             0.0001
Warmup steps                      4000
Input max length                  512
Batch size                        4
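The table gives an initial learning rate of 0.0001 and 4,000 warmup steps, but it does not say which schedule train.py applies. As an illustration only, the sketch below assumes a linear warmup to the initial rate followed by inverse-square-root decay. The actual code may use a different schedule, such as the original Transformer schedule, which also scales with the model dimension. The function name learning_rate is hypothetical.

```python
import math


def learning_rate(step, init_lr=1e-4, warmup_steps=4000):
    """Assumed schedule: linear warmup to init_lr, then inverse-sqrt decay.

    `step` is the 1-based optimizer step; init_lr and warmup_steps default
    to the values in the table above.
    """
    if step < 1:
        raise ValueError("step must be >= 1")
    if step <= warmup_steps:
        # Rise linearly from init_lr / warmup_steps up to init_lr.
        return init_lr * step / warmup_steps
    # After warmup, decay proportionally to 1 / sqrt(step).
    return init_lr * math.sqrt(warmup_steps / step)


for s in (1000, 4000, 16000):
    print(s, learning_rate(s))
```

To use such a schedule in PyTorch, torch.optim.lr_scheduler.LambdaLR expects a function that returns a multiplicative factor of the base rate (here learning_rate(step) / init_lr). Its step counter starts at 0, so the step would need to be offset by one.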


Train the model

python train.py -data data/preprocessed_data.data -bert_path data/checkpoint/ -proj_share_weight -label_smoothing -batch_size 4 -epoch 10 -save_model trained -save_mode best
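The -label_smoothing flag enables a label-smoothed cross-entropy loss. The smoothing weight used by train.py is not stated here. The sketch below assumes ε = 0.1, spread uniformly over the non-target vocabulary entries, and computes the loss for a single target position in pure Python. The function name label_smoothed_nll is hypothetical.

```python
import math


def label_smoothed_nll(logits, target, eps=0.1):
    """Label-smoothed cross-entropy for a single target position.

    The target class gets probability 1 - eps and the remaining eps is
    spread evenly over the other classes (eps=0.1 is an assumption).
    """
    n_class = len(logits)
    # Log-softmax computed stably by subtracting the largest logit.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    loss = 0.0
    for i, lp in enumerate(log_probs):
        # Smoothed target distribution q for class i.
        q = 1.0 - eps if i == target else eps / (n_class - 1)
        loss -= q * lp
    return loss


logits = [2.0, 0.5, -1.0]
print(label_smoothed_nll(logits, target=0, eps=0.0))  # plain cross-entropy
print(label_smoothed_nll(logits, target=0, eps=0.1))  # smoothed loss
```

With eps=0 the function reduces to ordinary cross-entropy. Smoothing penalizes the model for putting all probability mass on a single token.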

Generate summaries with the trained model

python summarize.py -model data/checkpoint/trained/trained.chkpt -src data/preprocessed_data.data -vocab data/checkpoint/vocab.txt -output pred.txt


[Figure: TensorboardX image]



TODO

  • Evaluate the model with scores such as ROUGE-N
  • Add some examples
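For the ROUGE-N evaluation listed above, the core computation is the clipped n-gram overlap between a generated summary and a reference. The sketch below is a simplified single-reference version that reports recall (the original ROUGE-N definition), precision and F1. It is not a replacement for an official ROUGE package. Because MeCab appears in the requirements, the text is presumably Japanese. If so, summaries would first need to be split into tokens (for example with MeCab) rather than on whitespace. The function names are hypothetical.

```python
from collections import Counter


def ngrams(tokens, n):
    """Multiset of the n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def rouge_n(candidate, reference, n=1):
    """ROUGE-N recall, precision and F1 between two token lists."""
    cand, ref = ngrams(candidate, n), ngrams(reference, n)
    # Clipped overlap: an n-gram counts at most as often as it appears
    # in both the candidate and the reference.
    overlap = sum((cand & ref).values())
    recall = overlap / max(sum(ref.values()), 1)
    precision = overlap / max(sum(cand.values()), 1)
    f1 = 0.0 if overlap == 0 else 2 * precision * recall / (precision + recall)
    return {"recall": recall, "precision": precision, "f1": f1}


reference = "the cat sat on the mat".split()
candidate = "the cat sat".split()
print(rouge_n(candidate, reference, n=1))  # recall 0.5, precision 1.0
print(rouge_n(candidate, reference, n=2))  # recall 0.4, precision 1.0
```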
