PyTorch implementation of training 1-bit Wide ResNets from this paper:
Training wide residual networks for deployment using a single bit for each weight by Mark D. McDonnell at ICLR 2018
The idea is very simple but surprisingly effective for training ResNets with binary weights. Here is the proposed weight parameterization as a PyTorch autograd function:
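import math
import torch

class ForwardSign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w):
        # sign of the weights, scaled by the He-init constant sqrt(2 / fan_in)
        return math.sqrt(2. / (w.shape[1] * w.shape[2] * w.shape[3])) * w.sign()

    @staticmethod
    def backward(ctx, g):
        return g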
On the forward pass, we take the sign of the weights and scale it by the He-init constant. On the backward pass, we propagate the gradient unchanged. WRN-20-10 trained with this parameterization is only slightly behind its full-precision variant. Here is what I got with this code on CIFAR-100:
| network | accuracy (5 runs mean ± std) | checkpoint (MB) |
|:---|:---:|:---:|
| WRN-20-10 | 80.5 ± 0.24 | 205 |
| WRN-20-10-1bit | 80.0 ± 0.26 | 3.5 |
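To make the parameterization concrete, here is a minimal sketch that applies `ForwardSign` to a single conv weight and checks that the gradient reaches the underlying real-valued weights. The tensor shapes, the seed, and the names `w_bin` and `n_levels` are made up for illustration; this is not taken from the training script.

```python
import math
import torch
import torch.nn.functional as F


class ForwardSign(torch.autograd.Function):
    """Sign of the weights scaled by the He-init constant; identity gradient."""

    @staticmethod
    def forward(ctx, w):
        fan_in = w.shape[1] * w.shape[2] * w.shape[3]
        return math.sqrt(2.0 / fan_in) * w.sign()

    @staticmethod
    def backward(ctx, g):
        # Straight-through: pass the incoming gradient along unchanged.
        return g


# Hypothetical 3x3 conv weight: 8 output channels, 4 input channels.
torch.manual_seed(0)
w = torch.randn(8, 4, 3, 3, requires_grad=True)
x = torch.randn(2, 4, 16, 16)

w_bin = ForwardSign.apply(w)           # values are +scale or -scale
out = F.conv2d(x, w_bin, padding=1)    # the conv only ever sees binary weights
out.sum().backward()                   # gradient lands on the real-valued w

scale = math.sqrt(2.0 / (4 * 3 * 3))
n_levels = torch.unique(w_bin.detach()).numel()
```

The optimizer keeps updating the real-valued `w`, while the network only uses its scaled sign, so at deployment one bit per weight plus the per-layer constant is enough.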
Here are the differences from the original WRN code:
- BatchNorm has no affine weight and bias parameters
- First layer has 16 * width channels
- Last fc layer is removed in favor of 1x1 conv + F.avg_pool2d
- Downsample is done by F.avg_pool2d + instead of strided conv
- SGD with cosine annealing and warm restarts
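For readers unfamiliar with the last item, the schedule can be sketched in a few lines of plain Python. The function name `sgdr_lr` and every default value here (base learning rate, first cycle length, cycle multiplier) are assumptions for illustration; the values actually used by the training script may differ.

```python
import math


def sgdr_lr(epoch, base_lr=0.1, min_lr=0.0, first_cycle=10, cycle_mult=2):
    """Cosine-annealed learning rate with warm restarts (hypothetical defaults)."""
    t, length = epoch, first_cycle
    # Find the position inside the current cycle; each cycle is
    # cycle_mult times longer than the previous one.
    while t >= length:
        t -= length
        length *= cycle_mult
    # Cosine decay from base_lr down to min_lr within the cycle.
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t / length))


# Learning rate for the first 30 epochs: restarts happen at epochs 10 and 30.
schedule = [sgdr_lr(e) for e in range(31)]
```

With these defaults the rate decays over 10 epochs, jumps back to `base_lr`, then decays over 20 epochs, and so on.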
I used PyTorch 0.4.1 and Python 3.6 to run the code.
Reproduce WRN-20-10 with 1-bit training on CIFAR-100:
python --binarize --save ./logs/WRN-20-10-1bit_$RANDOM --width 10 --dataset CIFAR100
Convergence plot (train error dashed):
I've also uploaded a 3.5 MB checkpoint with binary weights packed with np.packbits, and a very short script to evaluate it:
python --checkpoint wrn20-10-1bit-packed.pth.tar --width 10 --dataset CIFAR100
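As a rough illustration of how `np.packbits` gets the checkpoint down to one bit per weight, here is a small round-trip sketch. The array shape and variable names are made up, and the layout used in the released checkpoint may differ from this.

```python
import numpy as np

# Hypothetical float32 conv weight: 16 x 8 x 3 x 3 = 1152 values.
rng = np.random.default_rng(0)
w = rng.standard_normal((16, 8, 3, 3)).astype(np.float32)

# Keep only the sign: True stands for +1, False for -1.
# (An exact zero would map to -1 in this sketch.)
bits = w > 0
packed = np.packbits(bits.ravel())  # 8 weights per uint8 byte

# Unpack: drop any padding bits, restore the shape, map {0, 1} -> {-1, +1}.
unpacked = np.unpackbits(packed)[: w.size].reshape(w.shape)
signs = unpacked.astype(np.float32) * 2 - 1

ratio = w.nbytes / packed.nbytes  # float32 storage vs. packed bits
```

At evaluation time the unpacked signs would still need to be multiplied by the per-layer He-init constant from `ForwardSign`.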
S3 url to checkpoint: