
[ICLR 2020] Lite Transformer with Long-Short Range Attention


Lite Transformer with Long-Short Range Attention

@inproceedings{Wu2020LiteTransformer,
  title={Lite Transformer with Long-Short Range Attention},
  author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2020}
}

Overview

We release the PyTorch code for Lite Transformer. [Paper | Website | Slides]

Consistent Improvement by Tradeoff Curves


20,000x Lower Search Cost than the Evolved Transformer


Further Compress Transformer by 18.2x


How to Use

Prerequisite

  • Python version >= 3.6
  • PyTorch version >= 1.0.0
  • configargparse >= 0.14
  • For training new models, you'll also need an NVIDIA GPU and NCCL
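Before installing, it can help to confirm that the environment meets these minimums. The script below is a minimal sketch and is not part of this repo: it assumes the packages are installed under their PyPI names (`torch`, `configargparse`) and uses `importlib.metadata`, which itself needs Python 3.8 or newer. The helpers `parse_version` and `meets_minimum` are hypothetical names introduced here.

```python
"""Sketch: check the minimum versions from the Prerequisite list.

Assumption: packages are looked up by their PyPI distribution names.
This helper is hypothetical and not shipped with the repo.
"""
import sys
from importlib import metadata  # available in Python 3.8+


def parse_version(text):
    """Turn '1.4.0+cu92' or '0.14.1' into a tuple of ints such as (1, 4, 0)."""
    core = text.split("+")[0]  # drop local build tags like '+cu92'
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:  # keep only the leading digits ('0a1' -> '0')
            if not ch.isdigit():
                break
            digits += ch
        if not digits:  # stop at components like 'dev20200101'
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare versions numerically, padding the shorter one with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


# Minimum versions copied from the Prerequisite list.
REQUIREMENTS = {"torch": "1.0.0", "configargparse": "0.14"}


def check():
    """Print one status line per requirement and return True if all pass."""
    ok = sys.version_info >= (3, 6)
    print(f"python {sys.version.split()[0]}: {'ok' if ok else 'too old'}")
    for name, minimum in REQUIREMENTS.items():
        try:
            found = metadata.version(name)
        except metadata.PackageNotFoundError:
            print(f"{name}: missing (need >= {minimum})")
            ok = False
            continue
        good = meets_minimum(found, minimum)
        print(f"{name} {found}: {'ok' if good else 'need >= ' + minimum}")
        ok = ok and good
    return ok


if __name__ == "__main__":
    check()
```

Versions are compared as integer tuples rather than strings, so that, for example, 1.10 is correctly treated as newer than 1.9. The GPU and NCCL requirement is not checked here.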

Installation

  1. Codebase

    To install fairseq from source and develop locally:

    pip install --editable .
  2. Customized Modules

    We also need to build the lightconv and dynamicconv CUDA extensions for GPU support. Start each of the following steps from the repository root.

    Lightconv_layer

    cd fairseq/modules/lightconv_layer
    python cuda_function_gen.py
    python setup.py install
    cd ../../..  # return to the repository root before the next step

    Dynamicconv_layer

    cd fairseq/modules/dynamicconv_layer
    python cuda_function_gen.py
    python setup.py install
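Because both steps depend on the current directory, one way to avoid `cd` mistakes is to give each command its working directory explicitly. The sketch below is hypothetical (not part of the repo) and assumes it is run from the repository root; it replays exactly the commands listed above, using the current interpreter (`sys.executable`) in place of `python`.

```python
"""Sketch: build both CUDA extensions without relying on `cd` state.

Assumption: called from the repository root. build_plan() and build()
are hypothetical helpers that replay the commands listed above.
"""
import subprocess
import sys
from pathlib import Path

MODULES = ("lightconv_layer", "dynamicconv_layer")
STEPS = (["cuda_function_gen.py"], ["setup.py", "install"])


def build_plan(repo_root="."):
    """Return (working_directory, argv) pairs in the order they should run."""
    plan = []
    for module in MODULES:
        cwd = Path(repo_root) / "fairseq" / "modules" / module
        for step in STEPS:
            # Use the interpreter running this script so the extensions are
            # installed into the same environment.
            plan.append((cwd, [sys.executable, *step]))
    return plan


def build(repo_root="."):
    """Run every step, stopping at the first command that fails."""
    for cwd, argv in build_plan(repo_root):
        print(f"[{cwd}] {' '.join(argv)}")
        subprocess.run(argv, cwd=cwd, check=True)
```

Printing `build_plan()` first works as a dry run; calling `build()` then runs the four commands in order.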

Data Preparation

IWSLT'14 De-En

We follow the data preparation in fairseq. To download and preprocess the data, one can run

bash configs/iwslt14.de-en/prepare.sh

WMT'14 En-Fr

We follow the data pre-processing in fairseq. To download and preprocess the data, one can run

bash configs/wmt14.en-fr/prepare.sh

WMT'16 En-De

We follow the data pre-processing in fairseq. One should first download the preprocessed data from the Google Drive link provided by Google. To binarize the data, one can run

bash configs/wmt16.en-de/prepare.sh [path to the downloaded zip file]

WIKITEXT-103

Because the language modeling task requires a substantial amount of additional code, we place it in a separate branch, language-model. We follow the data pre-processing in fairseq. To download and preprocess the data, one can run

git checkout language-model
bash configs/wikitext-103/prepare.sh
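After a translation dataset has been prepared, a quick sanity check is to confirm that all binarized files exist. This is a sketch under assumptions: it presumes the prepare scripts write a fairseq-binarized translation dataset to a directory such as data/binary/wmt14_en_fr (the path used by the training commands in the Training section), with the source and target languages taken from the pair name. It covers translation pairs only, not the WIKITEXT-103 language-model data, and the helper names are hypothetical.

```python
"""Sketch: confirm a binarized translation dataset looks complete.

Assumptions: fairseq-style binarized output for a src-tgt pair, stored in
a directory such as data/binary/wmt14_en_fr. Helpers are hypothetical.
"""
from pathlib import Path

SPLITS = ("train", "valid", "test")


def expected_files(src, tgt, splits=SPLITS):
    """File names fairseq preprocessing writes for a src-tgt translation pair."""
    names = [f"dict.{src}.txt", f"dict.{tgt}.txt"]
    for split in splits:
        for lang in (src, tgt):
            for ext in ("bin", "idx"):  # binary data plus its index
                names.append(f"{split}.{src}-{tgt}.{lang}.{ext}")
    return names


def missing_files(data_dir, src, tgt):
    """Return the expected files absent from data_dir (an empty list means ok)."""
    root = Path(data_dir)
    return [name for name in expected_files(src, tgt) if not (root / name).is_file()]
```

For example, `missing_files("data/binary/wmt14_en_fr", "en", "fr")` should return an empty list once WMT'14 En-Fr preparation has finished.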

Testing

To test models on WMT'14 En-Fr, one can run

configs/wmt14.en-fr/test.sh [path to the model checkpoints] [gpu-id] [test|valid]

For instance, to evaluate the Lite Transformer checkpoint in embed496/ on GPU 0 and report its BLEU score on the WMT'14 En-Fr test set, one can run

configs/wmt14.en-fr/test.sh embed496/ 0 test
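To evaluate several downloaded checkpoints in one go, the test script can be wrapped in a small loop. This is a hypothetical helper, not part of the repo: it only assembles the three positional arguments documented for test.sh ([path to the model checkpoints] [gpu-id] [test|valid]) and invokes the script through `bash`, so the script does not need its executable bit set.

```python
"""Sketch: run the WMT'14 En-Fr test script on several checkpoints.

Assumption: checkpoint directories such as embed496/ are already
extracted. eval_command() and evaluate_all() are hypothetical helpers.
"""
import subprocess

SCRIPT = "configs/wmt14.en-fr/test.sh"


def eval_command(checkpoint_dir, gpu_id=0, split="test"):
    """Build argv for: test.sh [checkpoint path] [gpu-id] [test|valid]."""
    if split not in ("test", "valid"):
        raise ValueError(f"split must be 'test' or 'valid', got {split!r}")
    return ["bash", SCRIPT, str(checkpoint_dir), str(gpu_id), split]


def evaluate_all(checkpoint_dirs, gpu_id=0, split="test"):
    """Evaluate each checkpoint in turn, stopping if a run fails."""
    for ckpt in checkpoint_dirs:
        subprocess.run(eval_command(ckpt, gpu_id, split), check=True)
```

For example, `evaluate_all(["embed496/"], gpu_id=0)` runs the same evaluation as the command above.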

We provide several pretrained models in the Models section at the bottom. After downloading a model, extract the archive with

tar -xzvf [filename]
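If you prefer to stay in Python, the standard-library `tarfile` module can do the same extraction as `tar -xzvf`. This is a minimal sketch assuming the download is a gzip-compressed tarball; `extract_checkpoint` is a hypothetical helper, and the archive name in the usage line is a placeholder.

```python
"""Sketch: Python equivalent of `tar -xzvf [filename]`.

Assumption: the checkpoint archive is a .tar.gz file. extract_checkpoint()
is a hypothetical helper, not part of the repo.
"""
import tarfile
from pathlib import Path


def extract_checkpoint(archive, dest="."):
    """Extract a .tar.gz archive into dest and return the member names."""
    dest = Path(dest)
    root = dest.resolve()
    with tarfile.open(archive, "r:gz") as tar:
        # Refuse members that would land outside dest (e.g. '../x').
        for member in tar.getmembers():
            target = (dest / member.name).resolve()
            if target != root and root not in target.parents:
                raise ValueError(f"unsafe path in archive: {member.name}")
        if hasattr(tarfile, "data_filter"):  # newer Pythons offer safe filters
            tar.extractall(dest, filter="data")
        else:
            tar.extractall(dest)
        return tar.getnames()
```

Usage: `extract_checkpoint("model.tar.gz", ".")`, where `model.tar.gz` stands for whichever file you downloaded.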

Training

We provide several examples of training Lite Transformer with this repo.

To train Lite Transformer on WMT'14 En-Fr (with 8 GPUs), one can run

python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml

To train Lite Transformer with fewer GPUs, e.g., 4 GPUs, one can run

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32

In general, to train a model, one can run

python train.py [path to the data binary] --configs [path to config file] [override options]

Note that --update-freq should be adjusted to the number of GPUs (16 for 8 GPUs, 32 for 4 GPUs) so that the number of GPUs times --update-freq, and hence the effective batch size, stays the same.
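The GPU and --update-freq pairs given in this README (8 GPUs with 16, 4 GPUs with 32, and 16 GPUs with 8 in the distributed example) all multiply to 128. The helper below is a hypothetical sketch that derives --update-freq from that inferred product; it is not part of the repo and assumes the per-GPU batch settings in the config file are left unchanged.

```python
"""Sketch: choose --update-freq so that GPUs x update-freq stays at 128.

The target of 128 is inferred from the settings in this README;
update_freq() is a hypothetical helper.
"""
TARGET_PRODUCT = 128  # 8 x 16 = 4 x 32 = 16 x 8


def update_freq(num_gpus, target=TARGET_PRODUCT):
    """Return the --update-freq value for num_gpus (must divide target evenly)."""
    if num_gpus <= 0 or target % num_gpus:
        raise ValueError(f"{num_gpus} GPUs does not evenly divide {target}")
    return target // num_gpus
```

For example, `update_freq(2)` gives 64 for a two-GPU machine. Requiring an exact divisor avoids silently changing the effective batch size through rounding.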

Distributed Training (optional)

Lite Transformer can also be trained in a distributed manner, for example on two GPU nodes with 16 GPUs in total:

# On host1
python -m torch.distributed.launch \
        --nproc_per_node=8 \
        --nnodes=2 --node_rank=0 \
        --master_addr=host1 --master_port=8080 \
        train.py data/binary/wmt14_en_fr \
        --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
        --distributed-no-spawn \
        --update-freq 8
# On host2
python -m torch.distributed.launch \
        --nproc_per_node=8 \
        --nnodes=2 --node_rank=1 \
        --master_addr=host1 --master_port=8080 \
        train.py data/binary/wmt14_en_fr \
        --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
        --distributed-no-spawn \
        --update-freq 8
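The two commands differ only in --node_rank, so they can also be generated. This is a sketch under assumptions: every node has the same number of GPUs, node 0 (host1) acts as the master, and all flags are copied from the commands above. `launch_command` is a hypothetical helper, and `shlex.join` requires Python 3.8 or newer.

```python
"""Sketch: generate the per-node launch commands shown above.

Assumptions: identical nodes, node 0 is the master. launch_command() is
a hypothetical helper; its flags are copied from the commands above.
"""
import shlex

CONFIG = "configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml"


def launch_command(node_rank, nnodes=2, gpus_per_node=8,
                   master_addr="host1", master_port=8080, update_freq=8):
    """Return the shell command to run on the node with the given rank."""
    if not 0 <= node_rank < nnodes:
        raise ValueError("node_rank must be in [0, nnodes)")
    argv = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={gpus_per_node}",
        f"--nnodes={nnodes}", f"--node_rank={node_rank}",
        f"--master_addr={master_addr}", f"--master_port={master_port}",
        "train.py", "data/binary/wmt14_en_fr",
        "--configs", CONFIG,
        "--distributed-no-spawn",
        "--update-freq", str(update_freq),
    ]
    return shlex.join(argv)  # quotes arguments only when needed


if __name__ == "__main__":
    for rank in range(2):
        print(f"# On node {rank}")
        print(launch_command(rank))
```

If you change the number of nodes or GPUs, adjust update_freq accordingly, following the note on --update-freq in the Training section.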

Models

We provide the checkpoints of the Lite Transformer models reported in the paper:

Dataset          #Mult-Adds  Test Score (BLEU unless noted)  Model and Test Set
WMT'14 En-Fr     90M         35.3                            download
WMT'14 En-Fr     360M        39.1                            download
WMT'14 En-Fr     527M        39.6                            download
WMT'16 En-De     90M         22.5                            download
WMT'16 En-De     360M        25.6                            download
WMT'16 En-De     527M        26.5                            download
CNN / DailyMail  800M        38.3 (ROUGE-L)                  download
WIKITEXT-103     1147M       22.2 (perplexity)               download