Simplification of pruned models for accelerated inference.
Simplify can be installed using pip:
```
pip3 install torch-simplify
```
Or, if you want to run the latest version of the code, you can install it from Git:
```
git clone https://github.com/EIDOSlab/simplify
cd simplify
pip3 install -r requirements.txt
```
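The fuse module is used to fold each BatchNorm layer into the convolution that precedes it, so that each pair becomes a single layer.
```python
from torchvision.models import resnet18
from simplify import fuse

model = resnet18()
model.eval()  # folding relies on the BatchNorm running statistics, so use eval mode
bn_folding = ...  # List of pairs (conv, bn) to fuse in a single layer
model = fuse(model, bn_folding)
```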
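As a sketch of what fusing a (conv, bn) pair means numerically, the snippet below folds an eval-mode BatchNorm into the preceding layer using the standard formulas W' = W * gamma / sqrt(var + eps) and b' = (b - mean) * gamma / sqrt(var + eps) + beta. To stay runnable without PyTorch it assumes a fully connected layer on plain Python lists (equivalent to a 1x1 convolution at one pixel); a k x k convolution is folded the same way, one output channel at a time. `fold_bn`, `layer` and `batchnorm` are hypothetical helpers, not simplify's implementation.

```python
import math


def fold_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold eval-mode BatchNorm parameters into the preceding layer.

    Hypothetical helper, not simplify's API. Lists hold one entry per output
    channel; weight[c] is the list of input weights of output channel c.
    """
    fused_w, fused_b = [], []
    for c in range(len(weight)):
        scale = gamma[c] / math.sqrt(var[c] + eps)
        # Scale every weight that feeds output channel c.
        fused_w.append([w * scale for w in weight[c]])
        # Subtract the running mean, rescale, then add BatchNorm's shift.
        fused_b.append((bias[c] - mean[c]) * scale + beta[c])
    return fused_w, fused_b


def layer(x, weight, bias):
    """y[c] = sum_i weight[c][i] * x[i] + bias[c], i.e. a 1x1 conv at one pixel."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    """Eval-mode BatchNorm: normalize with the running statistics, then scale and shift."""
    return [
        g * (v - m) / math.sqrt(s + eps) + b
        for v, g, b, m, s in zip(y, gamma, beta, mean, var)
    ]


# Two output channels, three input channels.
weight = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]
bias = [0.1, -0.2]
gamma, beta = [1.2, 0.8], [0.05, -0.3]
mean, var = [0.4, -0.1], [0.9, 2.0]
x = [1.0, 2.0, -1.0]

reference = batchnorm(layer(x, weight, bias), gamma, beta, mean, var)
fused = layer(x, *fold_bn(weight, bias, gamma, beta, mean, var))
assert all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(reference, fused))
```

The final assertion checks that the single fused layer reproduces layer-then-BatchNorm, which is why the fusion only holds in eval mode, where the statistics are fixed.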
The propagate module is used to remove the non-zero bias from zeroed-out neurons by propagating it to the following layer, so that those neurons can then be removed.
```python
import torch
from simplify import propagate_bias
from torchvision.models import resnet18

zeros = torch.zeros(1, 3, 224, 224)  # all-zeros input of shape (batch, channels, height, width)
model = resnet18()
pinned_out = ...  # List of layers for which the bias should not be propagated
propagate_bias(model, zeros, pinned_out)
```
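The idea behind bias propagation can be sketched on two fully connected layers with a ReLU in between, assuming plain Python lists instead of PyTorch modules. A neuron whose incoming weights were all pruned still emits the constant relu(b_i) for every input; that constant can be added to the next layer's bias (weighted by the next layer's weights for that neuron), after which b_i is set to zero without changing the network output. `propagate_zeroed_bias` is a hypothetical helper for illustration, not simplify's implementation.

```python
def relu(v):
    return max(0.0, v)


def linear(x, weight, bias):
    """y[j] = sum_i weight[j][i] * x[i] + bias[j]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def forward(x, w1, b1, w2, b2):
    """Two fully connected layers with a ReLU in between."""
    hidden = [relu(h) for h in linear(x, w1, b1)]
    return linear(hidden, w2, b2)


def propagate_zeroed_bias(w1, b1, w2, b2):
    """Move the constant output of zeroed neurons into the next layer's bias.

    Hypothetical helper: a neuron is "zeroed" when all its input weights are 0,
    so its post-ReLU output is the constant relu(b1[i]) for every input.
    """
    b1, b2 = list(b1), list(b2)
    for i, row in enumerate(w1):
        if all(w == 0 for w in row):
            constant = relu(b1[i])
            for j in range(len(b2)):
                b2[j] += w2[j][i] * constant
            b1[i] = 0.0  # the neuron now outputs exactly 0
    return b1, b2


w1 = [[0.3, -0.7], [0.0, 0.0], [1.1, 0.4]]  # neuron 1 has been pruned...
b1 = [0.2, 0.9, -0.1]  # ...but still has a non-zero bias
w2 = [[0.5, -2.0, 1.0], [1.5, 0.25, -0.5]]
b2 = [0.0, 0.3]

new_b1, new_b2 = propagate_zeroed_bias(w1, b1, w2, b2)
x = [1.0, -2.0]
before = forward(x, w1, b1, w2, b2)
after = forward(x, w1, new_b1, w2, new_b2)
assert all(abs(a - b) < 1e-12 for a, b in zip(before, after))
```

After propagation, neuron 1 has zero weights and zero bias, which is exactly the condition needed to delete it.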
The remove module is used to actually remove the zeroed neurons from the model architecture.
```python
import torch
from simplify import remove_zeroed
from torchvision.models import resnet18

zeros = torch.zeros(1, 3, 224, 224)  # all-zeros input of shape (batch, channels, height, width)
model = resnet18()
pinned_out = ...  # List of layers in which the output should not change shape
remove_zeroed(model, zeros, pinned_out)
```
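Once a zeroed neuron also has a zero bias, it can be deleted outright: its row disappears from the producing layer and the matching input column disappears from the consuming layer, so both weight matrices shrink. The sketch below assumes plain-Python fully connected layers and a hypothetical `remove_dead_neurons` helper (not simplify's code), and checks that the output is unchanged while the hidden width drops from 3 to 2. It does not model `pinned_out`: layers listed there must keep their output shape, so their neurons would not be removed.

```python
def linear(x, weight, bias):
    """y[j] = sum_i weight[j][i] * x[i] + bias[j]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def forward(x, w1, b1, w2, b2):
    """Two fully connected layers with a ReLU in between."""
    hidden = [max(0.0, h) for h in linear(x, w1, b1)]
    return linear(hidden, w2, b2)


def remove_dead_neurons(w1, b1, w2):
    """Drop hidden neurons whose incoming weights and bias are all zero.

    Hypothetical helper: such neurons always output 0 after ReLU, so deleting
    row i of w1/b1 and column i of w2 leaves the network output unchanged.
    """
    keep = [i for i, row in enumerate(w1) if any(w != 0 for w in row) or b1[i] != 0]
    new_w1 = [w1[i] for i in keep]
    new_b1 = [b1[i] for i in keep]
    new_w2 = [[row[i] for i in keep] for row in w2]
    return new_w1, new_b1, new_w2


w1 = [[0.3, -0.7], [0.0, 0.0], [1.1, 0.4]]
b1 = [0.2, 0.0, -0.1]  # neuron 1: zero weights and zero bias
w2 = [[0.5, -2.0, 1.0], [1.5, 0.25, -0.5]]
b2 = [-1.8, 0.525]

small_w1, small_b1, small_w2 = remove_dead_neurons(w1, b1, w2)
x = [1.0, -2.0]
full_out = forward(x, w1, b1, w2, b2)
small_out = forward(x, small_w1, small_b1, small_w2, b2)
assert len(small_w1) == 2 and len(small_w2[0]) == 2
assert all(abs(a - b) < 1e-12 for a, b in zip(full_out, small_out))
```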
We also provide a set of utilities to define `bn_folding` and `pinned_out` for standard PyTorch models.
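```python
from torchvision.models import resnet18
from utils import get_bn_folding, get_pinned_out

model = resnet18()
bn_folding = get_bn_folding(model)  # (conv, bn) pairs to pass to fuse
pinned_out = get_pinned_out(model)  # layers to pass to propagate_bias and remove_zeroed
```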
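For a purely sequential stack of layers, the (conv, bn) pairs that `bn_folding` lists can be found by looking for a BatchNorm that directly follows a convolution. The sketch below does this on an ordered list of (name, layer_type) tuples; `find_bn_pairs` and this tuple representation are assumptions for illustration and say nothing about how `get_bn_folding` works internally. Models with branches may need more care than a simple linear scan.

```python
def find_bn_pairs(layers):
    """Return (conv_name, bn_name) pairs where a BatchNorm directly follows a conv.

    Hypothetical helper; `layers` is an ordered list of (name, layer_type)
    tuples describing a purely sequential model.
    """
    pairs = []
    # Walk over consecutive layers two at a time.
    for (prev_name, prev_type), (name, layer_type) in zip(layers, layers[1:]):
        if prev_type == "Conv2d" and layer_type == "BatchNorm2d":
            pairs.append((prev_name, name))
    return pairs


layers = [
    ("conv1", "Conv2d"), ("bn1", "BatchNorm2d"), ("relu", "ReLU"),
    ("conv2", "Conv2d"), ("relu2", "ReLU"),  # no BatchNorm here: nothing to fold
    ("conv3", "Conv2d"), ("bn3", "BatchNorm2d"),
]
print(find_bn_pairs(layers))  # [('conv1', 'bn1'), ('conv3', 'bn3')]
```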
Tests
Update timestamp 08/10/2021 14:26:25
Random structured pruning amount = 50.0%
Architecture | Dense time | Pruned time | Simplified time |
---|---|---|---|
alexnet | 7.58ms ± 0.29 | 7.55ms ± 0.28 | 2.95ms ± 0.02 |
densenet121 | 36.41ms ± 4.88 | 34.31ms ± 3.85 | 21.87ms ± 1.45 |
googlenet | 15.44ms ± 3.19 | 13.68ms ± 0.09 | 10.31ms ± 0.82 |
inception_v3 | 25.29ms ± 7.31 | 21.68ms ± 2.90 | 13.22ms ± 2.23 |
mnasnet1_0 | 17.66ms ± 0.57 | 13.64ms ± 0.13 | 11.59ms ± 0.07 |
mobilenet_v3_large | 13.74ms ± 0.67 | 12.18ms ± 0.46 | 11.95ms ± 0.21 |
resnet50 | 24.39ms ± 4.48 | 26.19ms ± 5.84 | 18.21ms ± 1.98 |
resnext101_32x8d | 76.11ms ± 15.79 | 77.35ms ± 20.04 | 65.68ms ± 16.41 |
shufflenet_v2_x2_0 | 18.07ms ± 2.23 | 14.32ms ± 0.21 | 13.06ms ± 0.08 |
squeezenet1_1 | 4.50ms ± 0.06 | 4.39ms ± 0.05 | 4.09ms ± 0.50 |
vgg19_bn | 40.41ms ± 12.13 | 38.56ms ± 10.72 | 12.39ms ± 0.19 |
wide_resnet101_2 | 79.40ms ± 25.57 | 82.86ms ± 22.47 | 60.16ms ± 10.77 |
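The table can be read as speedups by dividing each dense time by the corresponding simplified time. The snippet below does this for the mean values only (it ignores the ± spread), using the numbers copied from the rows above.

```python
# Mean inference times in ms, copied from the table above: (dense, pruned, simplified).
timings = {
    "alexnet": (7.58, 7.55, 2.95),
    "densenet121": (36.41, 34.31, 21.87),
    "googlenet": (15.44, 13.68, 10.31),
    "inception_v3": (25.29, 21.68, 13.22),
    "mnasnet1_0": (17.66, 13.64, 11.59),
    "mobilenet_v3_large": (13.74, 12.18, 11.95),
    "resnet50": (24.39, 26.19, 18.21),
    "resnext101_32x8d": (76.11, 77.35, 65.68),
    "shufflenet_v2_x2_0": (18.07, 14.32, 13.06),
    "squeezenet1_1": (4.50, 4.39, 4.09),
    "vgg19_bn": (40.41, 38.56, 12.39),
    "wide_resnet101_2": (79.40, 82.86, 60.16),
}

for arch, (dense, pruned, simplified) in timings.items():
    # Speedup of the simplified model over the dense one.
    print(f"{arch:20s} {dense / simplified:.2f}x")
```

For example, alexnet comes out at about 2.57x and vgg19_bn at about 3.26x, while squeezenet1_1 gains only about 1.10x; the pruned column alone stays close to the dense one.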
✔️: all good
❌: gives different results
🤬: an exception occurred
🤷‍♂️: test skipped because the previous one failed
Fuse BatchNorm
Update timestamp 06/10/2021 20:26:15
Architecture | BatchNorm Folding | Bias Propagation | Simplification |
---|---|---|---|
alexnet | ✔️ | ✔️ | ✔️ |
densenet121 | ✔️ | ✔️ | ✔️ |
googlenet | ✔️ | ✔️ | ✔️ |
inception_v3 | ✔️ | ✔️ | ✔️ |
mnasnet1_0 | ✔️ | ✔️ | ✔️ |
mobilenet_v3_large | ✔️ | ✔️ | ✔️ |
resnet50 | ✔️ | ✔️ | ✔️ |
resnext101_32x8d | ✔️ | ✔️ | ✔️ |
shufflenet_v2_x2_0 | ✔️ | ✔️ | ✔️ |
squeezenet1_1 | ✔️ | ✔️ | ✔️ |
vgg19_bn | ✔️ | ✔️ | ✔️ |
wide_resnet101_2 | ✔️ | ✔️ | ✔️ |
Keep BatchNorm
Update timestamp 06/10/2021 20:36:11
Architecture | BatchNorm Folding | Bias Propagation | Simplification |
---|---|---|---|
alexnet | ✔️ | ✔️ | ✔️ |
densenet121 | ✔️ | ✔️ | ✔️ |
googlenet | ✔️ | ✔️ | ✔️ |
inception_v3 | ✔️ | ✔️ | ✔️ |
mnasnet1_0 | ✔️ | ✔️ | ✔️ |
mobilenet_v3_large | ✔️ | ✔️ | ✔️ |
resnet50 | ✔️ | ✔️ | ✔️ |
resnext101_32x8d | ✔️ | ✔️ | ✔️ |
shufflenet_v2_x2_0 | ✔️ | ✔️ | ✔️ |
squeezenet1_1 | ✔️ | ✔️ | ✔️ |
vgg19_bn | ✔️ | ✔️ | ✔️ |
wide_resnet101_2 | ✔️ | ✔️ | ✔️ |