Ternary_Binary_Transformer

This repository contains the training code of TBT, introduced in our work "Binary and Ternary Natural Language Generation" (ACL 2023).

We approach the problem with a mix of statistics-based quantization for the weights and elastic quantization of the activations and demonstrate the first ternary and binary transformer models on the downstream tasks of summarization and machine translation.
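
To make the baselines in the results table below concrete, here is a minimal PyTorch sketch of statistics-based weight quantization in the style of TWN (ternary) and BWN (binary), together with a hypothetical learnable-scale activation quantizer in the spirit of elastic quantization. This illustrates the general technique only; the ElasticActQuant module and its parameterization are assumptions, not the exact TBT implementation.

    import torch

    def twn_quantize(w: torch.Tensor) -> torch.Tensor:
        # Statistics-based ternary quantization (TWN-style): threshold and
        # scale are derived from weight statistics rather than learned.
        delta = 0.7 * w.abs().mean()                  # ternarization threshold
        mask = (w.abs() > delta).float()              # entries kept as +/-1
        alpha = (w.abs() * mask).sum() / mask.sum().clamp(min=1.0)
        return alpha * torch.sign(w) * mask           # values in {-alpha, 0, +alpha}

    def bwn_quantize(w: torch.Tensor) -> torch.Tensor:
        # Statistics-based binary quantization (BWN-style).
        alpha = w.abs().mean()                        # per-tensor scale
        return alpha * torch.sign(w)                  # values in {-alpha, +alpha}

    class ElasticActQuant(torch.nn.Module):
        # Hypothetical activation quantizer with a learnable ("elastic") scale,
        # shown for a_bit >= 2; the actual TBT modules may differ.
        def __init__(self, a_bit: int = 8):
            super().__init__()
            self.qmax = 2 ** (a_bit - 1) - 1
            self.scale = torch.nn.Parameter(torch.tensor(1.0))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            q = torch.clamp(x / self.scale, -self.qmax - 1, self.qmax)
            q = (q.round() - q).detach() + q          # straight-through estimator
            return q * self.scale

In quantization-aware training, quantizers like these are typically applied in each linear layer's forward pass while full-precision latent weights are kept for the optimizer updates.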

Citation

If you find our code useful for your research, please consider citing:

@inproceedings{liu2023binary,
  title={Binary and Ternary Natural Language Generation},
  author={Liu, Zechun and Oguz, Barlas and Pappu, Aasish and Shi, Yangyang and Krishnamoorthi, Raghuraman},
  booktitle={Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics},
  year={2023}
}

Our previous work on binarizing BERT models:

  • BiT: Robustly Binarized Multi-distilled Transformer (NeurIPS 2022) [code] [paper]

Run

1. Requirements:

  • Python 3.9.12, PyTorch 1.12.1
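
A minimal environment setup, assuming a pip-based install (the repository may pin additional dependencies, e.g. transformers, that are not listed here):

    pip install torch==1.12.1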

2. Pretrained models:

  • Download the finetuned full-precision models from the Hugging Face model zoo (a loading example follows below):

    Dataset          Finetuned full-precision model
    XSUM             bart-base-xsum
    CNN/DailyMail    bart-base-cnn
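
Once a checkpoint is downloaded, it can be loaded with the Hugging Face transformers API; the local path below is hypothetical:

    from transformers import BartForConditionalGeneration, BartTokenizer

    # Hypothetical local directory holding the downloaded bart-base-xsum checkpoint.
    model_dir = "./pretrained/bart-base-xsum"

    tokenizer = BartTokenizer.from_pretrained(model_dir)
    model = BartForConditionalGeneration.from_pretrained(model_dir)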

3. Steps to run:

  • For the XSUM benchmark: bash scripts/run_xsum.sh $w_bit $a_bit $lr
  • For the CNN/DailyMail benchmark: bash scripts/run_cnn.sh $w_bit $a_bit $lr
  • Learning rate for each model (an example invocation follows the table):

             XSUM      CNN/DailyMail
    W2A8     3e-4      1e-4
    W2A2     3.5e-4    7e-4
    W1A8     2.5e-4    1.5e-4
    W1A1     5e-4      5e-4
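
For example, to train a ternary-weight, 8-bit-activation (W2A8) model on XSUM with the learning rate from the table above:

    bash scripts/run_xsum.sh 2 8 3e-4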

Models

Method            #Bits      Size (M)   FLOPs     XSUM                    CNN/DailyMail
                                                  R1     R2     RL       R1     R2     RL
BART              32-32-32   532.0      1x        43.84  20.79  35.71    44.90  22.25  42.09
QuantBart         8-8-8      138.1      --        40.25  17.78  32.70    --     --     --
DQ-BART           8-8-8      138.1      --        42.51  19.61  34.61    44.66  21.92  41.86

Ternary
Baseline (TWN)    2-2-8      39.6       0.25x     39.99  17.13  31.99    42.99  20.05  40.18
QuantBart         2-2-8      39.6       0.25x     39.15  16.72  31.72    --     --     --
DQ-BART           2-2-8      39.6       0.25x     40.06  17.34  32.46    42.94  20.07  40.13
TBT               2-2-8      39.6       0.25x     42.40  19.54  34.51    43.46  20.52  40.58
Baseline (TWN)    2-2-2      39.6       0.0625x   12.80  1.21   11.40    12.92  0.32   12.42
TBT               2-2-2      39.6       0.0625x   36.21  14.38  29.07    41.03  18.18  38.30

Binary
Baseline (BWN)    1-1-8      23.2       0.125x    1.90   0.01   1.78     2.78   0.08   2.48
TBT               1-1-8      23.2       0.125x    40.96  18.37  33.30    42.66  19.72  39.80
Baseline (BWN)    1-1-1      23.2       0.0156x   1.90   0.01   1.78     2.78   0.08   2.48
TBT               1-1-1      23.2       0.0156x   31.68  11.19  25.29    35.56  11.71  33.23

FLOPs are relative to the full-precision BART model; R1/R2/RL denote ROUGE-1/ROUGE-2/ROUGE-L.

Acknowledgement

Our code is adapted from DQ-BART.

Contact

Zechun Liu, Reality Labs, Meta Inc. (liuzechun0216 at gmail.com)

License

TBT is CC-BY-NC 4.0 licensed as of now.