This repository contains the training code for TBT, introduced in our work "Binary and Ternary Natural Language Generation", published at ACL 2023.
We approach the problem with a mix of statistics-based quantization for the weights and elastic quantization for the activations, and demonstrate the first ternary and binary transformer models on the downstream tasks of summarization and machine translation.
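To make the weight-quantization side concrete, below is a minimal NumPy sketch of the two statistics-based baselines named in the results table, TWN-style ternarization and BWN-style binarization. This is an illustration of those baseline quantizers, not the exact TBT weight quantizer from the paper.

```python
import numpy as np

def ternarize_twn(w):
    # TWN-style statistics-based ternarization (the "Baseline (TWN)" rows):
    # threshold delta = 0.7 * mean(|w|); weights below the threshold snap to 0,
    # the rest to +/- alpha, where alpha is the mean magnitude of the survivors.
    delta = 0.7 * np.abs(w).mean()
    mask = np.abs(w) > delta
    alpha = np.abs(w[mask]).mean() if mask.any() else 0.0
    return alpha * np.sign(w) * mask

def binarize_bwn(w):
    # BWN-style binarization (the "Baseline (BWN)" rows): sign(w) scaled by the
    # mean magnitude. Note np.sign(0) == 0, so exactly-zero weights stay zero.
    return np.abs(w).mean() * np.sign(w)
```

In both cases the scale is computed from weight statistics rather than learned, which is what "statistics-based quantization for the weights" refers to; the activations instead use learnable (elastic) quantization parameters trained end-to-end.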
If you find our code useful for your research, please consider citing:
@inproceedings{liu2023binary,
  title={Binary and Ternary Natural Language Generation},
  author={Liu, Zechun and Oguz, Barlas and Pappu, Aasish and Shi, Yangyang and Krishnamoorthi, Raghuraman},
  booktitle={Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics},
  year={2023}
}
Our previous work on binarizing BERT models:
- Python 3.9.12, PyTorch 1.12.1
- Download the pretrained models from the Hugging Face model zoo:
Dataset | Finetuned full-precision model
---|---
XSUM | bart-base-xsum
CNN/DailyMail | bart-base-cnn
- For the XSUM benchmark:

bash scripts/run_xsum.sh $w_bit $a_bit $lr

- For the CNN/DailyMail benchmark:

bash scripts/run_cnn.sh $w_bit $a_bit $lr

- Learning rate for each model:
Model | XSUM | CNN/DailyMail
---|---|---
W2A8 | 3e-4 | 1e-4
W2A2 | 3.5e-4 | 7e-4
W1A8 | 2.5e-4 | 1.5e-4
W1A1 | 5e-4 | 5e-4
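As a concrete example, a W2A8 run on XSUM with the learning rate from the table above would be launched as sketched below (the `echo` only prints the command rather than starting training; the script path and argument order follow the usage line above):

```shell
# Hypothetical W2A8 XSUM launch: weight bits, activation bits, learning rate
# taken from the learning-rate table above.
w_bit=2
a_bit=8
lr=3e-4
echo "bash scripts/run_xsum.sh $w_bit $a_bit $lr"
```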
Model | #Bits | Size (M) | FLOPs (relative) | XSUM R1 | XSUM R2 | XSUM RL | CNN R1 | CNN R2 | CNN RL
---|---|---|---|---|---|---|---|---|---
BART | 32-32-32 | 532.0 | 1x | 43.84 | 20.79 | 35.71 | 44.90 | 22.25 | 42.09
QuantBart | 8-8-8 | 138.1 | -- | 40.25 | 17.78 | 32.70 | -- | -- | --
DQ-BART | 8-8-8 | 138.1 | -- | 42.51 | 19.61 | 34.61 | 44.66 | 21.92 | 41.86
**Ternary** | | | | | | | | |
Baseline (TWN) | 2-2-8 | 39.6 | 0.25x | 39.99 | 17.13 | 31.99 | 42.99 | 20.05 | 40.18
QuantBart | 2-2-8 | 39.6 | 0.25x | 39.15 | 16.72 | 31.72 | -- | -- | --
DQ-BART | 2-2-8 | 39.6 | 0.25x | 40.06 | 17.34 | 32.46 | 42.94 | 20.07 | 40.13
TBT | 2-2-8 | 39.6 | 0.25x | 42.40 | 19.54 | 34.51 | 43.46 | 20.52 | 40.58
Baseline (TWN) | 2-2-2 | 39.6 | 0.0625x | 12.80 | 1.21 | 11.40 | 12.92 | 0.32 | 12.42
TBT | 2-2-2 | 39.6 | 0.0625x | 36.21 | 14.38 | 29.07 | 41.03 | 18.18 | 38.30
**Binary** | | | | | | | | |
Baseline (BWN) | 1-1-8 | 23.2 | 0.125x | 1.90 | 0.01 | 1.78 | 2.78 | 0.08 | 2.48
TBT | 1-1-8 | 23.2 | 0.125x | 40.96 | 18.37 | 33.30 | 42.66 | 19.72 | 39.80
Baseline (BWN) | 1-1-1 | 23.2 | 0.0156x | 1.90 | 0.01 | 1.78 | 2.78 | 0.08 | 2.48
TBT | 1-1-1 | 23.2 | 0.0156x | 31.68 | 11.19 | 25.29 | 35.56 | 11.71 | 33.23
Our code is adapted from the DQ-BART codebase.
Contact: Zechun Liu, Reality Labs, Meta Inc. (liuzechun0216 at gmail.com)
TBT is CC-BY-NC 4.0 licensed as of now.