This folder contains the code for inducing mixed-precision quantization schemes with BSQ on the CIFAR-10 dataset. The ResNet models are implemented in bit representation to support BSQ training and to reproduce the results in the main paper.
The training and evaluation code and the model architectures are adapted from bearpaw/pytorch-classification.
Clone recursively:
git clone --recursive https://github.com/yanghr/BSQ.git
This code is tested with Python 3.6.8, PyTorch 1.2.0 and TorchVision 0.4.0. We recommend using the provided spec-file.txt file to replicate the Anaconda environment used for testing this code, which can be done with:
conda create --name myenv --file spec-file.txt
We suggest running this code on a GPU for best efficiency. Both single-GPU runs and parallel runs on multiple GPUs are supported.
As introduced in Appendix A.1, pretrained models are used to initialize BSQ training. The pretrained models are provided in the checkpoints/cifar10/ folder, where the checkpoint in resnet-20/ is the full-precision pretrained model and the checkpoint in resnet-20-8/ is the 8-bit quantized model in bit representation.
For more details on training the full-precision model, please see the training recipes provided by bearpaw/pytorch-classification. The quantized model is obtained with convert.py, which is introduced later.
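To make the term "bit representation" concrete, the sketch below splits a list of weights into binary planes and reconstructs them. It is only an illustration under assumed conventions (symmetric uniform quantization, one scale per layer, sign stored separately); the helper names to_bit_planes and from_bit_planes are hypothetical, and the layout actually produced by convert.py may differ.

```python
def to_bit_planes(weights, n_bits=8):
    """Quantize weight magnitudes to n_bits and split them into bit planes.

    Assumed convention (not taken from the repository): symmetric uniform
    quantization with a single scale per layer and a separate sign.
    """
    scale = max(abs(w) for w in weights) or 1.0
    levels = 2 ** n_bits - 1
    ints = [round(abs(w) / scale * levels) for w in weights]
    signs = [-1 if w < 0 else 1 for w in weights]
    # planes[b][i] is bit b (LSB first) of the i-th quantized magnitude.
    planes = [[(q >> b) & 1 for q in ints] for b in range(n_bits)]
    return planes, signs, scale


def from_bit_planes(planes, signs, scale):
    """Rebuild floating-point weights from bit planes, signs and scale."""
    n_bits = len(planes)
    levels = 2 ** n_bits - 1
    restored = []
    for i, sign in enumerate(signs):
        q = sum(planes[b][i] << b for b in range(n_bits))
        restored.append(sign * scale * q / levels)
    return restored


weights = [0.5, -0.25, 0.0, 1.0]
planes, signs, scale = to_bit_planes(weights)
restored = from_bit_planes(planes, signs, scale)
```

Each restored weight differs from the original by at most half a quantization step, which is the rounding error one would expect from an 8-bit uniform quantizer.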
Here we perform BSQ training on the ResNet-20 model on the CIFAR-10 dataset.
python cifar_prune_STE.py -a resnet --depth 20 --epochs 350 --lr 0.1 --schedule 250 --gamma 0.1 --wd 1e-4 --model checkpoints/cifar10/resnet-20-8/model_best.pth.tar --decay 0.01 --Prun_Int 100 --thre 0.0 --checkpoint checkpoints/cifar10/xxx --Nbits 8 --act 4 --bin --L1 >xxx.txt
Replace xxx in the command with the folder where the resulting model should be saved. The model is saved in bit representation. We suggest redirecting the printed output to a text file with >xxx.txt to avoid interfering with the progress bar display and to keep a record of the training process.
--decay sets the regularization strength.
--Prun_Int is the number of epochs between consecutive re-quantization and precision adjustment steps; a value of 100 is suggested. The effect of using other intervals is illustrated in Appendix B.1.
--act sets the quantization precision of the activations in the model. It should be kept the same in BSQ training and finetuning. The default value is 4.
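As a rough picture of how the training options interact, the following sketch lists the epochs at which a re-quantization step would occur and the learning rate in effect for the example command. It assumes that epochs are counted from 1, that the learning rate is multiplied by --gamma at each epoch listed in --schedule, and that re-quantization follows every --Prun_Int-th epoch; the function training_timeline is hypothetical and the exact epoch offsets in cifar_prune_STE.py may differ.

```python
def training_timeline(epochs, prun_int, lr, schedule, gamma):
    """Return a list of (epoch, lr, requantize) tuples for one run.

    Assumptions (not confirmed by the repository): 1-based epochs,
    step learning-rate decay at the epochs in `schedule`, and a
    re-quantization step after every `prun_int`-th epoch.
    """
    timeline = []
    for epoch in range(1, epochs + 1):
        if epoch in schedule:
            lr *= gamma
        timeline.append((epoch, lr, epoch % prun_int == 0))
    return timeline


# Values taken from the example BSQ training command.
timeline = training_timeline(epochs=350, prun_int=100, lr=0.1,
                             schedule=[250], gamma=0.1)
requant_epochs = [epoch for epoch, _, requantize in timeline if requantize]
print(requant_epochs)
```

Under these assumptions the run has three re-quantization steps, and the last 50 epochs train with the final precision assignment.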
The model obtained from BSQ training can be evaluated and finetuned with cifar_finetune.py.
For evaluation, run
python cifar_finetune.py -a resnet --depth 20 --model checkpoints/cifar10/xxx/checkpoint.pth.tar --Nbits 8 --act 4 --bin --evaluate
Replace xxx in the command with the folder used to save the BSQ-trained model. Note that only models in bit representation can be evaluated in this way. The output reports the testing accuracy, the percentage of 1s in each bit of each layer's weights, and the precision assigned to each layer.
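The per-bit statistics in the evaluation output can be pictured with the sketch below, which computes the percentage of 1s in each bit plane of a toy layer and derives a precision from it. It assumes that a layer's precision is the number of bit planes left after dropping all-zero planes from the MSB and LSB ends; the function bit_statistics and the toy data are hypothetical and are not the code used by cifar_finetune.py.

```python
def bit_statistics(planes):
    """Compute per-bit percentages of 1s and the resulting precision.

    planes[b] is a list of 0/1 values for bit b, LSB first.
    Assumption: all-zero planes at either end do not count toward
    the layer's precision.
    """
    ones_pct = [100.0 * sum(plane) / len(plane) for plane in planes]
    used = [b for b, pct in enumerate(ones_pct) if pct > 0]
    precision = used[-1] - used[0] + 1 if used else 0
    return ones_pct, precision


# Toy 4-bit layer with four weights.
planes = [
    [0, 0, 0, 0],  # bit 0 (LSB): all zeros
    [1, 0, 1, 0],  # bit 1
    [0, 0, 0, 1],  # bit 2
    [0, 0, 0, 0],  # bit 3 (MSB): all zeros
]
ones_pct, precision = bit_statistics(planes)
print(ones_pct, precision)
```

In this toy layer only the two middle planes carry information, so the layer would count as 2-bit under the stated assumption.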
To further finetune the obtained model, use
python cifar_finetune.py -a resnet --depth 20 --epochs 300 --lr 0.01 --schedule 150 250 --gamma 0.1 --wd 1e-4 --model checkpoints/cifar10/xxx/checkpoint.pth.tar --checkpoint checkpoints/cifar10/xxx-ft --Nbits 8 --act 4 --bin >xxx-ft.txt
The quantization scheme is fixed throughout the finetuning process. At the end of finetuning, the model with the highest testing accuracy is stored both in bit representation and as floating-point weights. The bit representation is saved in checkpoints/cifar10/xxx-ft/best_bin.pth.tar and the floating-point model is saved in checkpoints/cifar10/xxx-ft/best_float.pth.tar.
To convert a full-precision model, use
python convert.py -a resnet --depth 20 --model checkpoints/cifar10/resnet-20/model_best.pth.tar --dict checkpoints/cifar10/xxx/checkpoint.pth.tar --checkpoint checkpoints/cifar10/xxx-mp --Nbits 8 --act 4 >xxx-mp.txt
If a path is provided in --dict, the model is converted to the same quantization scheme as the model specified in --dict. Otherwise, the whole model is quantized to the precision specified in --Nbits. The converted model is in bit representation and is saved in the folder specified in --checkpoint. We use this code to obtain the 8-bit quantized model before BSQ training, and to obtain the "train from scratch" models that are further finetuned for the comparison in Table 1.
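The choice between a per-layer scheme and a uniform precision can be sketched as follows. This is a simplified illustration, not the logic of convert.py: the model is a plain dictionary of weight lists, the scheme is a hypothetical mapping from layer name to bit width standing in for what --dict provides, and the quantizer is an assumed symmetric uniform one.

```python
def quantize_layer(weights, n_bits):
    """Symmetric uniform quantization of one layer (assumed scheme)."""
    if n_bits == 0:
        return [0.0 for _ in weights]
    scale = max(abs(w) for w in weights) or 1.0
    levels = 2 ** n_bits - 1
    return [scale * (round(w / scale * levels) / levels) for w in weights]


def convert_model(model, scheme=None, default_bits=8):
    """Quantize every layer, using `scheme` when given.

    `scheme` plays the role of the precisions read from --dict;
    `default_bits` plays the role of --Nbits.
    """
    converted = {}
    for name, weights in model.items():
        bits = scheme.get(name, default_bits) if scheme else default_bits
        converted[name] = quantize_layer(weights, bits)
    return converted


model = {"conv1": [0.9, -0.3, 0.1], "fc": [0.5, -1.0, 0.25]}
scheme = {"conv1": 2, "fc": 4}  # hypothetical mixed-precision scheme

mixed = convert_model(model, scheme)  # per-layer precision
uniform = convert_model(model)        # 8 bits everywhere
```

With the 2-bit setting, the smallest weight of conv1 rounds to zero, while the 8-bit conversion keeps every weight within half a quantization step of its original value.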