This repo contains the code and data of the following paper, accepted by ICLR 2020:
Additive Powers-of-Two Quantization: An Efficient Non-uniform Discretization for Neural Networks
Training code will be open-sourced soon.
@inproceedings{Li2020Additive,
title={Additive Powers-of-Two Quantization: An Efficient Non-uniform Discretization for Neural Networks},
author={Yuhang Li and Xin Dong and Wei Wang},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://openreview.net/forum?id=BkgXT24tDS}
}
PyTorch 1.1.0 with CUDA
- Please prepare the ImageNet training and validation datasets. We use the official example code to provide the dataloader.
- The CIFAR10 dataset can be downloaded automatically (update soon).
models.quant_layer.py contains the configuration for quantization. In particular, you can specify the settings in the class QuantConv2d:
class QuantConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        super(QuantConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.layer_type = 'QuantConv2d'
        self.bit = 4
        self.weight_quant = weight_quantize_fn(w_bit=self.bit, power=True)
        self.act_grid = build_power_value(self.bit, additive=True)
        self.act_alq = act_quantization(self.bit, self.act_grid, power=True)
        self.act_alpha = torch.nn.Parameter(torch.tensor(8.0))
Here, self.bit controls the bitwidth; weight_quantize_fn controls the quantization scheme, where power=True means using PoT or APoT quantization. build_power_value constructs the set of levels Q^a(1, b) with the parameters bit and additive.
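To make the level set more concrete, here is a minimal plain-Python sketch of how unsigned PoT and APoT grids can be built. It is not the repository's build_power_value: the function names pot_levels and apot_levels are hypothetical, and it assumes a base bitwidth k = 2, a total bitwidth divisible by k, and the additive form described in the paper, where each of the n = b / k terms is either 0 or a power of two and the sums are rescaled so the largest level is 1.

```python
from itertools import product


def pot_levels(bits):
    """Unsigned PoT grid: 0 plus 2**0, 2**-1, ... (2**bits values in total)."""
    return sorted([0.0] + [2.0 ** -i for i in range(2 ** bits - 1)])


def apot_levels(bits, k=2):
    """Unsigned APoT grid: sums of n = bits // k PoT terms, rescaled to [0, 1]."""
    if bits % k != 0:
        raise ValueError("this sketch assumes bits is divisible by k")
    n = bits // k
    # Term i is either 0 or 2**-(i + j*n) for j = 0 .. 2**k - 2.
    terms = [
        [0.0] + [2.0 ** -(i + j * n) for j in range(2 ** k - 1)]
        for i in range(n)
    ]
    # Every combination of one choice per term gives one level.
    sums = sorted({sum(choice) for choice in product(*terms)})
    return [s / sums[-1] for s in sums]


if __name__ == "__main__":
    print(pot_levels(4))
    print(apot_levels(4))
```

Both grids contain 16 levels for 4 bits. The PoT grid packs most of its levels close to zero, while the additive grid spreads them more evenly over [0, 1].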
To train a 5-bit model, just run main.py:
python main.py -a resnet18 --bit 5
Progressive initialization requires a checkpoint of a higher bitwidth. For example:
python main.py -a resnet18 --bit 4 --pretrained checkpoint/res18_5best.pth.tar
We provide a function show_params() to print the clipping parameters of both weights and activations.
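As a rough illustration of what a clipping parameter does, the sketch below clips a scalar to [0, alpha] and snaps it to the nearest level of a normalized grid. This is a plain-Python sketch under stated assumptions, not the code in quant_layer.py: the function name quantize_unsigned and the 2-bit grid are hypothetical, and only the unsigned (activation-style) case is shown.

```python
def quantize_unsigned(x, alpha, levels):
    """Clip x to [0, alpha], then snap x / alpha to the nearest level.

    `levels` is a grid normalized to [0, 1]; the result is rescaled by alpha.
    """
    clipped = min(max(x, 0.0), alpha)
    nearest = min(levels, key=lambda q: abs(q - clipped / alpha))
    return alpha * nearest


if __name__ == "__main__":
    grid = [0.0, 0.25, 0.5, 1.0]  # illustrative 2-bit grid (assumption)
    for value in (-1.0, 0.9, 3.0):
        print(value, "->", quantize_unsigned(value, alpha=2.0, levels=grid))
```

A larger alpha keeps more of the input range but makes the grid coarser, which is why the threshold is worth inspecting per layer.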
The training code is inspired by pytorch-cifar-code from junyuseu.
The dataset can be downloaded automatically using torchvision. We provide shell scripts to progressively train full-precision, 4-bit, 3-bit, and 2-bit models. For example, train_res20.sh:
#!/usr/bin/env bash
python main.py --arch res20 --bit 32 -id 0,1 --wd 5e-4
python main.py --arch res20 --bit 4 -id 0,1 --wd 1e-4 --lr 4e-2 \
--init result/res20_32bit/model_best.pth.tar
python main.py --arch res20 --bit 3 -id 0,1 --wd 1e-4 --lr 4e-2 \
--init result/res20_4bit/model_best.pth.tar
python main.py --arch res20 --bit 2 -id 0,1 --wd 3e-5 --lr 4e-2 \
--init result/res20_3bit/model_best.pth.tar
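The script above follows a simple pattern: each stage after the first is initialized from the best checkpoint of the previous, higher-bitwidth stage. The following Python sketch only builds the same command lines and prints them. The helper name progressive_commands and the schedule structure are hypothetical, and the checkpoint path pattern is assumed from the paths visible in the script.

```python
def progressive_commands(arch, schedule, gpus="0,1"):
    """Build one main.py command per stage.

    Every stage after the first passes --init pointing at the previous
    stage's best checkpoint (path pattern assumed from train_res20.sh).
    """
    commands = []
    previous_bit = None
    for stage in schedule:
        cmd = ["python", "main.py", "--arch", arch,
               "--bit", str(stage["bit"]), "-id", gpus, "--wd", stage["wd"]]
        if "lr" in stage:
            cmd += ["--lr", stage["lr"]]
        if previous_bit is not None:
            cmd += ["--init",
                    f"result/{arch}_{previous_bit}bit/model_best.pth.tar"]
        commands.append(cmd)
        previous_bit = stage["bit"]
    return commands


# Hyperparameters copied from the shell script above.
schedule = [
    {"bit": 32, "wd": "5e-4"},
    {"bit": 4, "wd": "1e-4", "lr": "4e-2"},
    {"bit": 3, "wd": "1e-4", "lr": "4e-2"},
    {"bit": 2, "wd": "3e-5", "lr": "4e-2"},
]

if __name__ == "__main__":
    for cmd in progressive_commands("res20", schedule):
        print(" ".join(cmd))
```

Each printed line can be run as is, or the lists can be passed to subprocess.run to execute the stages in order.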
The checkpoint models for CIFAR10 are released:
Model | Precision | Accuracy (%) | Checkpoints |
---|---|---|---|
Res20 | Full Precision | 92.96 | Res20_32bit |
Res20 | 4-bit | 92.45 | Res20_4bit |
Res20 | 3-bit | 92.49 | Res20_3bit |
Res20 | 2-bit | 90.96 | Res20_2bit |
Res56 | Full Precision | 94.46 | Res56_32bit |
Res56 | 4-bit | 93.93 | Res56_4bit |
Res56 | 3-bit | 93.77 | Res56_3bit |
Res56 | 2-bit | 93.05 | Res56_2bit |
To evaluate the models, you can run:
python main.py -e --init result/res20_3bit/model_best.pth.tar -id 0 --bit 3
You will get the accuracy and the clipping thresholds of the weights and activations:
Test: [0/100] Time 0.221 (0.221) Loss 0.2144 (0.2144) Prec 96.000% (96.000%)
* Prec 92.510%
clipping threshold weight alpha: 1.569000, activation alpha: 1.438000
clipping threshold weight alpha: 1.278000, activation alpha: 0.966000
clipping threshold weight alpha: 1.607000, activation alpha: 1.293000
clipping threshold weight alpha: 1.426000, activation alpha: 1.055000
clipping threshold weight alpha: 1.364000, activation alpha: 1.720000
clipping threshold weight alpha: 1.511000, activation alpha: 1.434000
clipping threshold weight alpha: 1.600000, activation alpha: 2.204000
clipping threshold weight alpha: 1.552000, activation alpha: 1.530000
clipping threshold weight alpha: 0.934000, activation alpha: 1.939000
clipping threshold weight alpha: 1.427000, activation alpha: 2.232000
clipping threshold weight alpha: 1.463000, activation alpha: 1.371000
clipping threshold weight alpha: 1.440000, activation alpha: 2.432000
clipping threshold weight alpha: 1.560000, activation alpha: 1.475000
clipping threshold weight alpha: 1.605000, activation alpha: 2.462000
clipping threshold weight alpha: 1.436000, activation alpha: 1.619000
clipping threshold weight alpha: 1.292000, activation alpha: 2.147000
clipping threshold weight alpha: 1.423000, activation alpha: 2.329000
clipping threshold weight alpha: 1.428000, activation alpha: 1.551000
clipping threshold weight alpha: 1.322000, activation alpha: 2.574000
clipping threshold weight alpha: 1.687000, activation alpha: 1.314000
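If you want to inspect the thresholds programmatically, the log lines above are regular enough to parse with a short regular expression. This is a minimal sketch that assumes the exact line format shown in the output; the function name parse_alphas is hypothetical and is not part of the repository.

```python
import re

# Matches lines such as:
# clipping threshold weight alpha: 1.569000, activation alpha: 1.438000
ALPHA_LINE = re.compile(
    r"clipping threshold weight alpha: ([0-9.]+), activation alpha: ([0-9.]+)"
)


def parse_alphas(log_text):
    """Return a list of (weight_alpha, activation_alpha) float pairs."""
    return [(float(w), float(a)) for w, a in ALPHA_LINE.findall(log_text)]


if __name__ == "__main__":
    sample = (
        " * Prec 92.510%\n"
        "clipping threshold weight alpha: 1.569000, activation alpha: 1.438000\n"
        "clipping threshold weight alpha: 1.278000, activation alpha: 0.966000\n"
    )
    for weight_alpha, act_alpha in parse_alphas(sample):
        print(weight_alpha, act_alpha)
```

Lines that do not match the pattern, such as the accuracy summary, are ignored.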
- checkpoints for ImageNet models