Add functionality to round the number of filters when pruning
Serjio42 opened this issue · 8 comments
Hi. I think it would be useful to add functionality that rounds the number of channels left after pruning to a multiple of a given number (32 or 16, for example). I've implemented it locally in the prune/ script, and it really speeds up inference!
If it would be useful to others, I can try to open a pull request with this functionality this week. Any thoughts?
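Here is a minimal sketch of the proposed rounding. It assumes two things: the goal is for the number of channels kept after pruning to be a multiple of `round_to`, and "rounding" means rounding to the nearest multiple. `rounded_prune_count` is a hypothetical helper, not the code from the prune/ script:

```python
def rounded_prune_count(n_channels: int, amount: float, round_to: int = 16) -> int:
    """Return how many channels to prune so that the number of channels kept
    is rounded to the nearest multiple of `round_to` (hypothetical helper)."""
    n_keep = n_channels - int(amount * n_channels)  # channels kept without rounding
    if round_to > 1:
        # nearest multiple of round_to (halves round up), but keep at least one group
        n_keep = max(round_to, round_to * ((n_keep + round_to // 2) // round_to))
        # never keep more channels than the layer has (the result may then
        # not be a multiple, e.g. for layers narrower than round_to)
        n_keep = min(n_keep, n_channels)
    return n_channels - n_keep


if __name__ == "__main__":
    for n in (32, 96, 1280):
        print(n, "channels -> prune", rounded_prune_count(n, 0.2, round_to=16))
```

For example, a 96-channel layer at `amount=0.2` would lose 19 channels without rounding (77 kept) but 16 with `round_to=16` (80 kept). Note that nearest rounding can also prune nothing: a 32-channel layer keeps all 32.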
Hello, I am a little confused about how rounding speeds things up. Does it mean that a round channel count is more hardware-friendly than a non-round one?
Yes. That is not surprising: if you look at the well-known architectures, all of them use channel counts such as 32, 64, 128, 144, 192, etc.
In fact, I've run a couple of experiments pruning MobileNet_v2 with your repo. Inference time tripled (!) after applying the default pruning. But when I rounded the number of channels to a multiple of 16, inference time no longer degraded. Inference was done with ONNX Runtime on my computer's CPU.
I think this comes down to the internal design of binary hardware.
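One way to check this claim on your own machine is to time the same convolution with a round and a non-round channel count. This is a rough sketch that uses plain PyTorch timing with a single warm-up call. The helper `time_conv` and the counts 128 vs 125 are arbitrary illustrative choices, and the results depend on the CPU and backend:

```python
import time

import torch
import torch.nn as nn


def time_conv(channels: int, hw: int = 28, batch: int = 1, repeat: int = 20) -> float:
    """Average forward time (seconds) of a 3x3 conv with `channels` in/out channels."""
    conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1).eval()
    x = torch.randn(batch, channels, hw, hw)
    with torch.no_grad():
        conv(x)  # warm-up call: the first run can include one-time setup costs
        start = time.perf_counter()
        for _ in range(repeat):
            conv(x)
        end = time.perf_counter()
    return (end - start) / repeat


if __name__ == "__main__":
    torch.set_num_threads(1)  # single thread to reduce scheduling noise
    for c in (128, 125):
        print(f"{c} channels: {time_conv(c) * 1e3:.3f} ms")
```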
Thanks! I will try your code and conduct some experiments to verify the benefits of rounding.
Hi @Serjio42, I have tried the rounding strategy. Here are the inference times of MobileNetV2 with [16 x 3 x 32 x 32] inputs:
before pruning: inference time=0.014485 s, parameters=3504872
w/o rounding: inference time=0.007839 s, parameters=1969470
w/ rounding: inference time=0.008662 s, parameters=1967864

before pruning: inference time=0.267591 s, parameters=3504872
w/o rounding: inference time=0.154733 s, parameters=1969470
w/ rounding: inference time=0.149170 s, parameters=1967864
It seems that whether rounding can improve the inference time depends on the hardware.
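To make the comparison concrete, the speedup over the unpruned model can be computed from the timings reported above. The input size of the second set is not stated, so it is simply labeled "run 2":

```python
# Timings (seconds) copied from the two sets of results above
runs = {
    "run 1 [16 x 3 x 32 x 32]": {"before": 0.014485, "w/o rounding": 0.007839, "w/ rounding": 0.008662},
    "run 2": {"before": 0.267591, "w/o rounding": 0.154733, "w/ rounding": 0.149170},
}


def speedups(times: dict) -> dict:
    """Speedup of each pruned variant relative to the unpruned model."""
    base = times["before"]
    return {k: base / v for k, v in times.items() if k != "before"}


if __name__ == "__main__":
    for name, times in runs.items():
        for variant, s in speedups(times).items():
            print(f"{name}: {variant}: {s:.2f}x faster than unpruned")
```

This gives roughly 1.85x (w/o rounding) vs 1.67x (w/ rounding) in the first run and 1.73x vs 1.79x in the second, so rounding helped in one setting and hurt in the other.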
Here is my test script:
```python
import time

import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2
import torch_pruning as tp

def measure_inference_time(net, input, repeat=100):
    # Average wall-clock time of one forward pass over `repeat` runs
    with torch.no_grad():
        start = time.perf_counter()
        for _ in range(repeat):
            net(input)
        end = time.perf_counter()
    return (end - start) / repeat

device = torch.device('cpu')
repeat = 100

# before pruning
model = mobilenet_v2(pretrained=True).eval()
fake_input = torch.randn(16, 3, 224, 224)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_before_pruning = measure_inference_time(model, fake_input, repeat)
print("before pruning: inference time=%f s, parameters=%d" % (inference_time_before_pruning, tp.utils.count_params(model)))

# w/o rounding
model = mobilenet_v2(pretrained=True).eval()
strategy = tp.strategy.L1Strategy()
DG = tp.DependencyGraph()
fake_input = fake_input.cpu()  # build the dependency graph on CPU
DG.build_dependency(model, example_inputs=fake_input)
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        # select 20% of the output filters with the smallest L1 norm
        pruning_idxs = strategy(m.weight, amount=0.2)
        pruning_plan = DG.get_pruning_plan(m, tp.prune_conv, idxs=pruning_idxs)
        pruning_plan.exec()  # remove the filters and the channels that depend on them
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_without_rounding = measure_inference_time(model, fake_input, repeat)
print("w/o rounding: inference time=%f s, parameters=%d" % (inference_time_without_rounding, tp.utils.count_params(model)))

# w/ rounding
model = mobilenet_v2(pretrained=True).eval()
strategy = tp.strategy.L1Strategy()
DG = tp.DependencyGraph()
fake_input = fake_input.cpu()
DG.build_dependency(model, example_inputs=fake_input)
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        # same as above, but round the number of remaining filters to a multiple of 8
        pruning_idxs = strategy(m.weight, amount=0.2, round_to=8)
        pruning_plan = DG.get_pruning_plan(m, tp.prune_conv, idxs=pruning_idxs)
        pruning_plan.exec()
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_with_rounding = measure_inference_time(model, fake_input, repeat)
print("w/ rounding: inference time=%f s, parameters=%d" % (inference_time_with_rounding, tp.utils.count_params(model)))
```
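For readers who don't have the library at hand, here is a rough plain-PyTorch sketch of what an L1 strategy with rounding computes for one conv layer. It ranks the output filters by the L1 norm of their weights and returns the indices of the weakest ones, adjusting the count so that the number of remaining filters is rounded to a multiple of `round_to`. `l1_prune_idxs` is a hypothetical function, not the Torch-Pruning implementation, whose exact rounding rule may differ:

```python
import torch


def l1_prune_idxs(weight: torch.Tensor, amount: float = 0.2, round_to: int = 1) -> list:
    """Indices of output filters to prune, smallest L1 norm first (hypothetical sketch)."""
    n = weight.shape[0]                                   # number of output filters
    n_keep = n - int(amount * n)
    if round_to > 1:
        # round the kept count to the nearest multiple of round_to (at least one group)
        n_keep = min(n, max(round_to, round_to * ((n_keep + round_to // 2) // round_to)))
    n_prune = n - n_keep
    l1 = weight.detach().abs().flatten(1).sum(dim=1)      # L1 norm of each filter
    return torch.argsort(l1)[:n_prune].tolist()           # weakest filters


if __name__ == "__main__":
    w = torch.randn(32, 16, 3, 3)                         # a conv weight with 32 filters
    print(len(l1_prune_idxs(w, 0.2)))                     # 6 pruned, 26 kept
    print(len(l1_prune_idxs(w, 0.2, round_to=8)))         # 8 pruned, 24 kept
```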
I think rounding is useful for CPU deployment. Maybe we should implement the rounding as a separate function to make the strategy clearer. Let's merge the PR first and then move the rounding into a new function.
I had been thinking about refactoring the rounding too :) The function looks good, thanks!
My result on a single 3.0 GHz CPU core with the default mobilenet_v2 is about 5 ms per sample.
What if you try a single image per batch?
Maybe the difference is that you are not measuring with ONNX Runtime, or not on a single core/thread...
Here are the steps I used, if you want to reproduce:
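```python
import time

import numpy as np
import onnx
import onnxruntime as rt
import torch
from torchvision.models import mobilenet_v2

device = torch.device('cpu')
batch_size = 1                        # single image per batch
save_onnx_path = 'mobilenet_v2.onnx'  # example output path
model = mobilenet_v2(pretrained=True).eval().to(device)  # replace with the (pruned) model to benchmark

# Export to ONNX with a dynamic batch dimension
x = torch.randn(batch_size, 3, 224, 224).to(device)
torch.onnx.export(model, x, save_onnx_path,
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})
onnx_model = onnx.load(save_onnx_path)
onnx.checker.check_model(onnx_model)

# Run ONNX Runtime on a single CPU thread
so = rt.SessionOptions()
so.intra_op_num_threads = 1
sess = rt.InferenceSession(save_onnx_path, providers=['CPUExecutionProvider'], sess_options=so)
inputNames = [k.name for k in sess.get_inputs()]
outputNames = [k.name for k in sess.get_outputs()]

# Placeholder inputs: replace with real preprocessed images (float32, NCHW)
samples = [np.random.rand(batch_size, 3, 224, 224).astype(np.float32) for _ in range(100)]

all_iters = 0
all_times = 0.0
for tensor in samples:
    all_iters += 1
    start = time.perf_counter()
    result = sess.run(outputNames, {inputNames[0]: tensor})[0]
    end = time.perf_counter()
    all_times += end - start
print('Averaged inference time=', all_times / all_iters)
```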
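For a like-for-like check without ONNX, a similar single-thread, batch-of-1 measurement can be approximated directly in PyTorch. This is a sketch under a few assumptions. `measure_single_thread` is a hypothetical helper. `torch.set_num_threads(1)` only limits PyTorch's intra-op threads, and the setting is process-wide. The numbers will not match ONNX Runtime exactly:

```python
import time

import torch
import torch.nn as nn


def measure_single_thread(model: nn.Module, input_shape=(1, 3, 224, 224),
                          warmup: int = 10, repeat: int = 100) -> float:
    """Average seconds per forward pass on one CPU thread (hypothetical helper)."""
    torch.set_num_threads(1)        # process-wide: limits PyTorch intra-op parallelism
    model = model.eval().cpu()
    x = torch.randn(*input_shape)
    with torch.no_grad():
        for _ in range(warmup):     # discard the first runs (allocation, caching)
            model(x)
        start = time.perf_counter()
        for _ in range(repeat):
            model(x)
        end = time.perf_counter()
    return (end - start) / repeat


if __name__ == "__main__":
    # Tiny stand-in model; swap in the pruned MobileNetV2 to compare with the ONNX numbers
    net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1))
    print(f"{measure_single_thread(net, repeat=20) * 1e3:.3f} ms per image")
```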
Here are the pruning steps and the corresponding inference times in a graph: