
Simplification of pruned models for accelerated inference. To cite this Original Software Publication: https://www.sciencedirect.com/science/article/pii/S2352711021001576

Primary language: Python. License: BSD 3-Clause "New" or "Revised" License (BSD-3-Clause).





Simplify can be installed using pip:

pip3 install torch-simplify

or, to run the latest version of the code, clone the repository and install its requirements:

git clone https://github.com/EIDOSlab/simplify
cd simplify
pip3 install -r requirements.txt

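Example usage

The fuse module is used to fold each BatchNorm layer into the convolution that precedes it, so that each pair runs as a single layer.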

from torchvision.models import resnet18
from simplify import fuse

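model = resnet18()
model.eval()  # BatchNorm folding uses the running statistics, so it is meant for inference (evaluation mode)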
bn_folding = ...  # List of pairs (conv, bn) to fuse in a single layer
model = fuse(model, bn_folding)
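Folding works because, in evaluation mode, a BatchNorm layer applies a fixed per-channel affine map built from its running statistics, and that map can be absorbed into the weights and bias of the preceding convolution: w' = w · γ/√(σ² + ε) and b' = (b − μ) · γ/√(σ² + ε) + β. The sketch below checks this identity for a single output channel of a 1x1 convolution in plain Python. `conv_bn` and `fold_channel` are hypothetical helpers written for illustration, not part of simplify's API, which operates on whole PyTorch modules (where the scale is applied per output channel of the weight tensor).

```python
import math


def conv_bn(x, w, b, gamma, beta, mean, var, eps=1e-5):
    """Reference: one output channel of a 1x1 conv followed by eval-mode BatchNorm."""
    y = sum(wi * xi for wi, xi in zip(w, x)) + b
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_channel(w, b, gamma, beta, mean, var, eps=1e-5):
    """Return (w', b') so that the conv alone reproduces conv + BatchNorm."""
    scale = gamma / math.sqrt(var + eps)
    return [wi * scale for wi in w], (b - mean) * scale + beta


# One output channel with three input channels (arbitrary example values).
x = [0.5, -1.0, 2.0]
w, b = [0.2, 0.4, -0.1], 0.3
gamma, beta, mean, var = 1.5, -0.2, 0.1, 0.8  # BatchNorm affine params and running stats

w_f, b_f = fold_channel(w, b, gamma, beta, mean, var)
fused = sum(wi * xi for wi, xi in zip(w_f, x)) + b_f
assert math.isclose(fused, conv_bn(x, w, b, gamma, beta, mean, var))
```

Because the folded layer uses the running statistics rather than per-batch statistics, the fused model matches the original only in evaluation mode.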


The propagate module removes the non-zero bias from zeroed-out neurons, folding its contribution into the following layers, so that those neurons can then be removed.

import torch
from simplify import propagate_bias
from torchvision.models import resnet18

zeros = torch.zeros(1, 3, 224, 224)
model = resnet18()
pinned_out = ...  # List of layers for which the bias should not be propagated
propagate_bias(model, zeros, pinned_out)
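The idea behind bias propagation can be illustrated on two fully connected layers with a ReLU in between. A neuron whose incoming weights are all zero still outputs the constant ReLU(b), so that constant can be added to the next layer's bias (weighted by the next layer's weights) and the neuron's own bias set to zero without changing the network's output. This is a minimal plain-Python sketch under those assumptions; `propagate_zeroed_bias` and the other helpers are hypothetical, and the library's `propagate_bias` must also handle convolutional and branching architectures such as resnet18, which this sketch does not cover.

```python
import math


def relu(v):
    return [max(0.0, t) for t in v]


def dense(x, W, b):
    """Fully connected layer; W holds one row of weights per output neuron."""
    return [sum(w * xj for w, xj in zip(row, x)) + bi for row, bi in zip(W, b)]


def forward(x, W1, b1, W2, b2):
    """Two-layer network: dense -> ReLU -> dense."""
    return dense(relu(dense(x, W1, b1)), W2, b2)


def propagate_zeroed_bias(W1, b1, W2, b2):
    """Move the constant output of zeroed layer-1 neurons into layer 2's bias.

    Assumes the activation maps 0 to 0 (true for ReLU), so once the bias is
    moved, the zeroed neuron outputs nothing.
    """
    b1, b2 = list(b1), list(b2)
    for j, row in enumerate(W1):
        if all(w == 0 for w in row) and b1[j] != 0:
            constant = max(0.0, b1[j])  # this neuron always outputs ReLU(b1[j])
            for i in range(len(b2)):
                b2[i] += W2[i][j] * constant
            b1[j] = 0.0
    return b1, b2


# Neuron 1 of the first layer has been pruned (zero weights) but kept its bias.
W1 = [[0.5, -1.0], [0.0, 0.0], [1.0, 2.0]]
b1 = [0.1, 0.7, -0.3]
W2 = [[1.0, 2.0, -1.0], [0.5, -0.5, 0.25]]
b2 = [0.0, 0.2]
x = [1.0, 0.5]

new_b1, new_b2 = propagate_zeroed_bias(W1, b1, W2, b2)
before = forward(x, W1, b1, W2, b2)
after = forward(x, W1, new_b1, W2, new_b2)
assert all(math.isclose(p, q) for p, q in zip(before, after))
```

After propagation the pruned neuron has zero weights and zero bias, which is what makes it safe to delete.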


The remove module actually removes the zeroed neurons from the model architecture.

import torch
from simplify import remove_zeroed
from torchvision.models import resnet18

zeros = torch.zeros(1, 3, 224, 224)
model = resnet18()
pinned_out = ...  # List of layers in which the output should not change shape
remove_zeroed(model, zeros, pinned_out)
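Once a neuron has zero weights and zero bias, its output (ReLU(0) = 0) contributes nothing, so it can be deleted together with the matching input column of the next layer. The plain-Python sketch below shows this on two fully connected layers with a ReLU in between, where the middle neuron of the first layer has already been zeroed: both layers shrink while the output stays the same. The helper names are hypothetical, and the sketch ignores the constraint that some layers must keep their output shape (the layers passed as `pinned_out`).

```python
import math


def relu(v):
    return [max(0.0, t) for t in v]


def dense(x, W, b):
    """Fully connected layer; W holds one row of weights per output neuron."""
    return [sum(w * xj for w, xj in zip(row, x)) + bi for row, bi in zip(W, b)]


def forward(x, W1, b1, W2, b2):
    """Two-layer network: dense -> ReLU -> dense."""
    return dense(relu(dense(x, W1, b1)), W2, b2)


def remove_zeroed_neurons(W1, b1, W2):
    """Drop layer-1 neurons whose weights and bias are all zero.

    The matching input columns of layer 2 are dropped too, so both layers
    shrink while the network output is unchanged (ReLU(0) == 0).
    """
    keep = [j for j, (row, bj) in enumerate(zip(W1, b1))
            if bj != 0 or any(w != 0 for w in row)]
    W1_small = [W1[j] for j in keep]
    b1_small = [b1[j] for j in keep]
    W2_small = [[row[j] for j in keep] for row in W2]
    return W1_small, b1_small, W2_small


# Neuron 1 has zero weights and zero bias.
W1 = [[0.5, -1.0], [0.0, 0.0], [1.0, 2.0]]
b1 = [0.1, 0.0, -0.3]
W2 = [[1.0, 2.0, -1.0], [0.5, -0.5, 0.25]]
b2 = [1.4, -0.15]
x = [1.0, 0.5]

W1_s, b1_s, W2_s = remove_zeroed_neurons(W1, b1, W2)
before = forward(x, W1, b1, W2, b2)
after = forward(x, W1_s, b1_s, W2_s, b2)
print(len(W1), "->", len(W1_s), "neurons")  # 3 -> 2 neurons
assert all(math.isclose(p, q) for p, q in zip(before, after))
```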


We also provide a set of utilities used to define bn_folding and pinned_out for standard PyTorch models.

from torchvision.models import resnet18
from utils import get_bn_folding, get_pinned_out

model = resnet18()
bn_folding = get_bn_folding(model)
pinned_out = get_pinned_out(model)
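For a model that is a plain chain of layers, the pairs in bn_folding amount to "a BatchNorm that directly follows a convolution". The sketch below finds such pairs in a hypothetical flat list of (name, layer type) tuples named after the first layers of resnet18; it pairs names for readability. `find_bn_pairs` is not part of simplify, and `get_bn_folding` works on the real model, where branching architectures require following the actual data flow, so its implementation may differ.

```python
def find_bn_pairs(layers):
    """Pair each BatchNorm with the convolution that directly precedes it.

    `layers` is a flat list of (name, kind) tuples in execution order; this
    only works for strictly sequential models.
    """
    pairs = []
    for (prev_name, prev_kind), (name, kind) in zip(layers, layers[1:]):
        if kind == "BatchNorm2d" and prev_kind == "Conv2d":
            pairs.append((prev_name, name))
    return pairs


resnet_stem = [
    ("conv1", "Conv2d"), ("bn1", "BatchNorm2d"), ("relu", "ReLU"),
    ("maxpool", "MaxPool2d"),
    ("layer1.0.conv1", "Conv2d"), ("layer1.0.bn1", "BatchNorm2d"),
]
pairs = find_bn_pairs(resnet_stem)
print(pairs)  # [('conv1', 'bn1'), ('layer1.0.conv1', 'layer1.0.bn1')]
```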

Inference time benchmarks

Evaluation mode (fuses BatchNorm)

Update timestamp 08/10/2021 14:26:25

Random structured pruning amount = 50.0%

| Architecture | Dense time (ms) | Pruned time (ms) | Simplified time (ms) |
|---|---|---|---|
| alexnet | 7.58 ± 0.29 | 7.55 ± 0.28 | 2.95 ± 0.02 |
| densenet121 | 36.41 ± 4.88 | 34.31 ± 3.85 | 21.87 ± 1.45 |
| googlenet | 15.44 ± 3.19 | 13.68 ± 0.09 | 10.31 ± 0.82 |
| inception_v3 | 25.29 ± 7.31 | 21.68 ± 2.90 | 13.22 ± 2.23 |
| mnasnet1_0 | 17.66 ± 0.57 | 13.64 ± 0.13 | 11.59 ± 0.07 |
| mobilenet_v3_large | 13.74 ± 0.67 | 12.18 ± 0.46 | 11.95 ± 0.21 |
| resnet50 | 24.39 ± 4.48 | 26.19 ± 5.84 | 18.21 ± 1.98 |
| resnext101_32x8d | 76.11 ± 15.79 | 77.35 ± 20.04 | 65.68 ± 16.41 |
| shufflenet_v2_x2_0 | 18.07 ± 2.23 | 14.32 ± 0.21 | 13.06 ± 0.08 |
| squeezenet1_1 | 4.50 ± 0.06 | 4.39 ± 0.05 | 4.09 ± 0.50 |
| vgg19_bn | 40.41 ± 12.13 | 38.56 ± 10.72 | 12.39 ± 0.19 |
| wide_resnet101_2 | 79.40 ± 25.57 | 82.86 ± 22.47 | 60.16 ± 10.77 |
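Structured pruning zeroes whole neurons but leaves every tensor at its original shape, so dense kernels still perform the same number of multiply-accumulates; only removing the zeroed neurons reduces the work. The plain-Python sketch below makes this concrete for one layer pruned at the 50% amount used above. The helpers are hypothetical; for real PyTorch modules, `torch.nn.utils.prune.random_structured` provides random structured pruning, though whether the benchmark used it is an assumption here. Measured times also depend on hardware and kernel efficiency, not only on operation counts, and in a full network removing a neuron also shrinks the next layer's inputs.

```python
import random


def random_structured_prune(W, amount, seed=0):
    """Return a copy of W with a random `amount` fraction of its rows zeroed.

    Each row is one output neuron, so whole neurons are pruned (structured
    pruning), but the matrix keeps its original shape.
    """
    rng = random.Random(seed)
    pruned = set(rng.sample(range(len(W)), round(amount * len(W))))
    return [[0.0] * len(row) if i in pruned else list(row) for i, row in enumerate(W)]


def dense_macs(W):
    """Multiply-accumulate operations of a dense layer, zero weights included."""
    return sum(len(row) for row in W)


def remove_zero_rows(W):
    """Simplification step: physically drop the all-zero rows."""
    return [row for row in W if any(w != 0 for w in row)]


W = [[1.0] * 8 for _ in range(16)]  # 16 output neurons with 8 inputs each
W_pruned = random_structured_prune(W, amount=0.5)
W_simplified = remove_zero_rows(W_pruned)

print(dense_macs(W), dense_macs(W_pruned), dense_macs(W_simplified))  # 128 128 64
```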

Status of torchvision.models

✔️: all good

❌: gives different results

🤬: an exception occurred

🤷‍♂️: test skipped because the previous one failed
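One way statuses like these could be assigned is to run the original and transformed models on the same inputs, compare their outputs within a tolerance, and skip later stages once one fails. The sketch below assumes exactly that. It uses scalar functions as stand-ins for models and `math.isclose` in place of a tensor comparison such as `torch.allclose`; `check` and `run_stages` are hypothetical names, so the repository's actual test suite may work differently.

```python
import math


def check(reference, candidate, inputs, rel_tol=1e-5, abs_tol=1e-6):
    """Compare a transformed model with the original on the same inputs."""
    try:
        for x in inputs:
            if not math.isclose(reference(x), candidate(x),
                                rel_tol=rel_tol, abs_tol=abs_tol):
                return "different"  # ❌
    except Exception:
        return "exception"  # 🤬
    return "ok"  # ✔️


def run_stages(reference, stages, inputs):
    """Check each stage in order, skipping every stage after the first failure."""
    statuses, failed = [], False
    for candidate in stages:
        if failed:
            statuses.append("skipped")  # 🤷‍♂️
            continue
        status = check(reference, candidate, inputs)
        statuses.append(status)
        failed = status != "ok"
    return statuses


def original(x):
    return 2.0 * x + 1.0


def folded(x):  # equivalent to the original
    return 2.0 * x + 1.0


def propagated(x):  # hypothetical faulty stage: bias off by 0.5
    return 2.0 * x + 1.5


def simplified(x):
    return 2.0 * x + 1.0


statuses = run_stages(original, [folded, propagated, simplified], [0.0, 1.0, -3.0])
print(statuses)  # ['ok', 'different', 'skipped']
```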

Fuse BatchNorm

Update timestamp 06/10/2021 20:26:15

| Architecture | BatchNorm Folding | Bias Propagation | Simplification |
|---|---|---|---|
| alexnet | ✔️ | ✔️ | ✔️ |
| densenet121 | ✔️ | ✔️ | ✔️ |
| googlenet | ✔️ | ✔️ | ✔️ |
| inception_v3 | ✔️ | ✔️ | ✔️ |
| mnasnet1_0 | ✔️ | ✔️ | ✔️ |
| mobilenet_v3_large | ✔️ | ✔️ | ✔️ |
| resnet50 | ✔️ | ✔️ | ✔️ |
| resnext101_32x8d | ✔️ | ✔️ | ✔️ |
| shufflenet_v2_x2_0 | ✔️ | ✔️ | ✔️ |
| squeezenet1_1 | ✔️ | ✔️ | ✔️ |
| vgg19_bn | ✔️ | ✔️ | ✔️ |
| wide_resnet101_2 | ✔️ | ✔️ | ✔️ |

Keep BatchNorm

Update timestamp 06/10/2021 20:36:11

| Architecture | BatchNorm Folding | Bias Propagation | Simplification |
|---|---|---|---|
| alexnet | ✔️ | ✔️ | ✔️ |
| densenet121 | ✔️ | ✔️ | ✔️ |
| googlenet | ✔️ | ✔️ | ✔️ |
| inception_v3 | ✔️ | ✔️ | ✔️ |
| mnasnet1_0 | ✔️ | ✔️ | ✔️ |
| mobilenet_v3_large | ✔️ | ✔️ | ✔️ |
| resnet50 | ✔️ | ✔️ | ✔️ |
| resnext101_32x8d | ✔️ | ✔️ | ✔️ |
| shufflenet_v2_x2_0 | ✔️ | ✔️ | ✔️ |
| squeezenet1_1 | ✔️ | ✔️ | ✔️ |
| vgg19_bn | ✔️ | ✔️ | ✔️ |
| wide_resnet101_2 | ✔️ | ✔️ | ✔️ |