PyTorch implementation of training 1-bit Wide ResNets from this paper:
Training wide residual networks for deployment using a single bit for each weight by Mark D. McDonnell at ICLR 2018
https://openreview.net/forum?id=rytNfI1AZ
https://arxiv.org/abs/1802.08530
The idea is very simple but surprisingly effective for training ResNets with binary weights. Here is the proposed weight parameterization as PyTorch autograd function:
class ForwardSign(torch.autograd.Function):
@staticmethod
def forward(ctx, w):
return math.sqrt(2. / (w.shape[1] * w.shape[2] * w.shape[3])) * w.sign()
@staticmethod
def backward(ctx, g):
return g
On forward, we take sign of the weights and scale it by He-init constant. On backward, we propagate gradient without changes. WRN-20-10 trained with such parameterization is only slightly off from it's full precision variant, here is what I got myself with this code on CIFAR-100:
network | accuracy (5 runs mean +- std) | checkpoint (Mb) |
---|---|---|
WRN-20-10 | 80.5 +- 0.24 | 205 Mb |
WRN-20-10-1bit | 80.0 +- 0.26 | 3.5 Mb |
Here are the differences with WRN code https://github.com/szagoruyko/wide-residual-networks:
- BatchNorm has no affine weight and bias parameters
- First layer has 16 * width channels
- Last fc layer is removed in favor of 1x1 conv + F.avg_pool2d
- Downsample is done by F.avg_pool2d + torch.cat instead of strided conv
- SGD with cosine annealing and warm restarts
I used PyTorch 0.4.1 and Python 3.6 to run the code.
Reproduce WRN-20-10 with 1-bit training on CIFAR-100:
python main.py --binarize --save ./logs/WRN-20-10-1bit_$RANDOM --width 10 --dataset CIFAR100
Convergence plot (train error in dash):
I've also put 3.5 Mb checkpoint with binary weights packed with np.packbits
, and a very short script to evaluate it:
python evaluate_packed.py --checkpoint wrn20-10-1bit-packed.pth.tar --width 10 --dataset CIFAR100
S3 url to checkpoint: https://s3.amazonaws.com/modelzoo-networks/wrn20-10-1bit-packed.pth.tar