StackingBERT

Source code for "Efficient Training of BERT by Progressively Stacking"


Introduction

This repository contains the code to reproduce the results of "Efficient Training of BERT by Progressively Stacking" (Gong et al., ICML 2019). The code is based on Fairseq.

Requirements and Installation

  • PyTorch >= 1.0.0
  • For training new models, you'll also need an NVIDIA GPU and NCCL
  • Python 3.7

After PyTorch is installed, you can install requirements with:

pip install -r requirements.txt
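To confirm that an environment matches the requirements above before launching a long run, a small check like the following may help. It is a minimal sketch, not part of this repository: the function names are hypothetical, and it only reports the Python version, the PyTorch version (against the >= 1.0.0 minimum), whether a CUDA GPU is visible, and whether PyTorch reports NCCL support.

```python
import sys
from typing import Dict, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Leading numeric components of a version string, e.g. "1.4.0+cu100" -> (1, 4, 0)."""
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def at_least(found: str, required: str) -> bool:
    """True if version `found` >= version `required` (missing components count as 0)."""
    f, r = parse_version(found), parse_version(required)
    width = max(len(f), len(r))
    return f + (0,) * (width - len(f)) >= r + (0,) * (width - len(r))


def check_environment() -> Dict[str, object]:
    """Report the Python / PyTorch / GPU facts listed in the requirements."""
    report: Dict[str, object] = {
        "python": "%d.%d" % sys.version_info[:2],
        "python_is_3_7": sys.version_info[:2] == (3, 7),
    }
    try:
        import torch
    except ImportError:
        report["torch"] = None
        report["torch_ok"] = False
        return report
    report["torch"] = str(torch.__version__)
    report["torch_ok"] = at_least(str(torch.__version__), "1.0.0")
    report["cuda_gpu"] = torch.cuda.is_available()
    dist = torch.distributed
    # is_nccl_available may be missing on some builds, so look it up defensively.
    is_nccl = getattr(dist, "is_nccl_available", None)
    report["nccl"] = bool(dist.is_available() and is_nccl is not None and is_nccl())
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```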

Getting Started

Step 1:

bash install.sh

This script downloads:

  1. Moses Decoder
  2. Subword NMT
  3. Fast BPE (the following steps use Subword NMT rather than Fast BPE; Fast BPE is recommended if you want to generate your own dictionary from a large-scale dataset)

These libraries handle cleaning, tokenization, and BPE encoding of the GLUE data in Step 3. They are also useful if you want to build your own corpus for BERT training or evaluate our model on your own tasks.
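For readers unfamiliar with BPE encoding, the toy example below illustrates the idea behind the merge-learning step that Subword NMT performs: repeatedly merge the most frequent adjacent symbol pair, then segment words by replaying those merges in order. This is a simplified re-implementation for illustration only. It is not Subword NMT's code or command-line interface, and the "</w>" end-of-word marker and the tiny word-frequency table are assumptions made for the example.

```python
import collections
import re
from typing import Dict, List, Tuple

Pair = Tuple[str, str]


def pair_counts(vocab: Dict[str, int]) -> Dict[Pair, int]:
    """Count adjacent symbol pairs, weighted by word frequency."""
    counts: Dict[Pair, int] = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for a, b in zip(symbols, symbols[1:]):
            counts[a, b] += freq
    return counts


def merge_pair(pair: Pair, vocab: Dict[str, int]) -> Dict[str, int]:
    """Replace every standalone occurrence of "a b" with the merged symbol "ab"."""
    pattern = re.compile(r"(?<!\S)" + re.escape(" ".join(pair)) + r"(?!\S)")
    return {pattern.sub("".join(pair), word): freq for word, freq in vocab.items()}


def learn_bpe(vocab: Dict[str, int], num_merges: int) -> List[Pair]:
    """Greedily learn merge operations, most frequent pair first."""
    merges: List[Pair] = []
    for _ in range(num_merges):
        counts = pair_counts(vocab)
        if not counts:
            break
        best = max(counts, key=counts.get)
        vocab = merge_pair(best, vocab)
        merges.append(best)
    return merges


def apply_bpe(word: str, merges: List[Pair]) -> List[str]:
    """Segment one word by replaying the learned merges in order."""
    symbols = list(word) + ["</w>"]
    for a, b in merges:
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and symbols[i] == a and symbols[i + 1] == b:
                out.append(a + b)
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        symbols = out
    return symbols


if __name__ == "__main__":
    # Words are pre-split into characters plus an end-of-word marker.
    corpus = {"l o w </w>": 5, "l o w e r </w>": 2,
              "n e w e s t </w>": 6, "w i d e s t </w>": 3}
    merges = learn_bpe(corpus, 10)
    print(merges[:3])  # [('e', 's'), ('es', 't'), ('est', '</w>')]
    print(apply_bpe("newest", merges[:3]))  # ['n', 'e', 'w', 'est</w>']
```

Frequent suffixes such as "est" become single units after a few merges, which is why BPE keeps the vocabulary small while still covering rare words.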

Step 2:

bash reproduce_bert.sh

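This script runs progressive stacking to train a BERT model. The code has been tested on 4 Tesla P40 GPUs (24 GB of memory each). On different hardware, you will probably need to change the number of tokens per batch by adjusting max-tokens and update-freq: max-tokens caps the tokens each GPU processes per forward/backward pass, and update-freq accumulates gradients over that many passes before each parameter update, so the effective batch is roughly max-tokens × update-freq × number of GPUs.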
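The core operation behind progressive stacking is doubling the depth of a trained encoder by copying its layers on top of themselves, so that layer L + i of the deeper model starts from the trained weights of layer i. The sketch below shows that schedule in plain Python. It is not the repository's implementation: representing a layer as a dict of parameter lists is a hypothetical stand-in for real Transformer layers, `train` is a placeholder callback, and the 3 → 6 → 12 depths in the usage example are illustrative.

```python
import copy
from typing import Callable, Dict, List

# Hypothetical representation: one encoder layer = {parameter name: values}.
Layer = Dict[str, List[float]]


def stack(layers: List[Layer]) -> List[Layer]:
    """Double the depth: new layer L + i starts as a copy of trained layer i."""
    return layers + [copy.deepcopy(layer) for layer in layers]


def progressive_stacking(shallow: List[Layer], target_depth: int,
                         train: Callable[[List[Layer]], None]) -> List[Layer]:
    """Train, stack, and repeat until the model has `target_depth` layers.

    Assumes target_depth = len(shallow) * 2**k. `train` is a placeholder that
    trains the given layers in place; embeddings and the output head are
    assumed to be carried over unchanged and are not modelled here.
    """
    layers = shallow
    train(layers)
    while len(layers) < target_depth:
        layers = stack(layers)
        train(layers)
    if len(layers) != target_depth:
        raise ValueError("target_depth must be the shallow depth times a power of two")
    return layers


if __name__ == "__main__":
    depths_seen: List[int] = []
    model = progressive_stacking([{"w": [float(i)]} for i in range(3)], 12,
                                 lambda ls: depths_seen.append(len(ls)))
    print(depths_seen)  # [3, 6, 12]
    print([layer["w"][0] for layer in model])
```

Deep-copying rather than sharing the layer objects matters: after stacking, the copied layers must be free to diverge from their sources during further training.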
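When moving to hardware with less GPU memory or a different GPU count, one reasonable approach is to lower max-tokens until batches fit and then raise update-freq so the effective batch (about max-tokens × update-freq × number of GPUs tokens per update in Fairseq) stays close to the original. The helper below does that arithmetic. The function names are hypothetical, and the reference values in the example are placeholders; read the actual settings from reproduce_bert.sh.

```python
def effective_tokens(max_tokens: int, update_freq: int, num_gpus: int) -> int:
    """Upper bound on the number of tokens contributing to one optimizer update."""
    return max_tokens * update_freq * num_gpus


def matching_update_freq(ref_max_tokens: int, ref_update_freq: int, ref_gpus: int,
                         new_max_tokens: int, new_gpus: int) -> int:
    """Smallest update-freq whose effective batch is at least the reference one."""
    target = effective_tokens(ref_max_tokens, ref_update_freq, ref_gpus)
    per_pass = new_max_tokens * new_gpus
    return max(1, -(-target // per_pass))  # integer ceiling division


if __name__ == "__main__":
    # Placeholder reference: 4 GPUs, max-tokens 4096, update-freq 1.
    # Target: 2 GPUs that only fit max-tokens 2048.
    print(matching_update_freq(4096, 1, 4, 2048, 2))
```

Because Fairseq fills each batch only up to max-tokens, the real effective batch is somewhat smaller than this bound, so matching the bound keeps the two setups comparable rather than identical.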

Step 3:

bash reproduce_glue.sh

This script fine-tunes the BERT model trained in Step 2 on the GLUE tasks. It uses the checkpoint saved after 400K training steps, which corresponds to the stacking model reported in our paper.
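If you want to fine-tune a checkpoint from a different step, or confirm which file corresponds to 400K steps, a helper like the one below can locate it by update count. It assumes Fairseq's checkpoint_<epoch>_<updates>.pt naming, which Fairseq uses when --save-interval-updates is set; this is an assumption about your run, so check the file names it actually produced. The function name is hypothetical.

```python
import re
from typing import Iterable, Optional

# Assumed naming: checkpoint_<epoch>_<updates>.pt (Fairseq with --save-interval-updates).
CHECKPOINT_RE = re.compile(r"^checkpoint_(\d+)_(\d+)\.pt$")


def checkpoint_at(filenames: Iterable[str], updates: int) -> Optional[str]:
    """Return the checkpoint file saved at exactly `updates` steps, or None."""
    for name in filenames:
        match = CHECKPOINT_RE.match(name)
        if match and int(match.group(2)) == updates:
            return name
    return None
```

For example, `checkpoint_at(os.listdir(save_dir), 400_000)`, where `save_dir` is whatever checkpoint directory your training run wrote to.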

Cite

@InProceedings{pmlr-v97-gong19a,
  title = 	 {Efficient Training of {BERT} by Progressively Stacking},
  author = 	 {Gong, Linyuan and He, Di and Li, Zhuohan and Qin, Tao and Wang, Liwei and Liu, Tieyan},
  booktitle = 	 {Proceedings of the 36th International Conference on Machine Learning},
  pages = 	 {2337--2346},
  year = 	 {2019},
  editor = 	 {Chaudhuri, Kamalika and Salakhutdinov, Ruslan},
  volume = 	 {97},
  series = 	 {Proceedings of Machine Learning Research},
  address = 	 {Long Beach, California, USA},
  month = 	 {09--15 Jun},
  publisher = 	 {PMLR},
  pdf = 	 {http://proceedings.mlr.press/v97/gong19a/gong19a.pdf},
  url = 	 {http://proceedings.mlr.press/v97/gong19a.html},
}