VainF/Torch-Pruning

Add functionality to round the number of filters when pruning

Serjio42 opened this issue · 8 comments

Hi. I think it would be useful to add an option that rounds the channel counts produced by pruning to a multiple of a given number (32 or 16, for example). I've implemented this locally in the prune/strategy.py script, and it really speeds up inference!
If this is useful to others, I can try to open a pull request with this functionality this week. Any thoughts?
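To make the idea concrete, here is a small sketch of what rounding does to per-layer widths. It assumes that pruning removes `int(amount * n)` filters from a layer with `n` filters and that rounding moves the number of *remaining* filters up to the next multiple of `round_to`. Both are illustrative assumptions, not necessarily how prune/strategy.py does it, and `remaining_channels` is a hypothetical helper. The widths are some of the channel counts used in torchvision's MobileNetV2.

```python
import math

def remaining_channels(n, amount, round_to=1):
    """Hypothetical helper: channels left after pruning `amount` of `n` filters,
    optionally rounded up to a multiple of `round_to`."""
    remain = n - int(amount * n)
    if round_to > 1:
        remain = math.ceil(remain / round_to) * round_to
    return min(remain, n)  # never keep more channels than the layer has

widths = [32, 96, 144, 192, 1280]  # some MobileNetV2 layer widths
for n in widths:
    print(n, remaining_channels(n, 0.2), remaining_channels(n, 0.2, round_to=16))
```

Without rounding, 20% pruning leaves widths such as 26, 77, 116 and 154; with `round_to=16` they become 32, 80, 128 and 160. For small layers, rounding up can cancel the pruning entirely (32 stays 32).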

VainF commented

Hi @Serjio42. Sounds great. Please open a new pull request for it.

Hi @VainF, I've opened a pull request. Please take a look and let me know whether it's OK.

VainF commented

Hello, I'm a little confused about how rounding speeds things up. Does it mean that a round channel count is more hardware-friendly than a non-round one?

Yes. That's not surprising: most well-known architectures use channel counts such as 32, 64, 128, 144, 192, etc.
In fact, I ran a couple of experiments pruning MobileNet_v2 with your repo. Inference time tripled (!) after running the default pruning process. But when I rounded the number of channels to a multiple of 16, inference time no longer degraded. Inference was run with ONNX Runtime on my computer's CPU.
I think this is due to the internal design of binary hardware, which favors sizes that are multiples of powers of two.

VainF commented

Thanks! I will try your code and conduct some experiments to verify the benefits of rounding.

VainF commented

Hi @Serjio42, I have tried the strategy with rounding. Here are the inference times of MobileNetV2 with [16 x 3 x 32 x 32] inputs.

GPU:

before pruning: inference time=0.014485 s, parameters=3504872
w/o rounding: inference time=0.007839 s, parameters=1969470
w/ rounding: inference time=0.008662 s, parameters=1967864

CPU:

before pruning: inference time=0.267591 s, parameters=3504872
w/o rounding: inference time=0.154733 s, parameters=1969470
w/ rounding: inference time=0.149170 s, parameters=1967864
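The effect of rounding is easier to compare as ratios. This snippet only recomputes them from the timings reported above; it does not re-run any benchmark.

```python
# Timings in seconds, copied from the measurements above.
results = {
    "GPU": {"before": 0.014485, "w/o rounding": 0.007839, "w/ rounding": 0.008662},
    "CPU": {"before": 0.267591, "w/o rounding": 0.154733, "w/ rounding": 0.149170},
}

def summarize(timings):
    """Speed-ups over the unpruned model and the rounded/unrounded time ratio."""
    before = timings["before"]
    return {
        "speedup w/o rounding": before / timings["w/o rounding"],
        "speedup w/ rounding": before / timings["w/ rounding"],
        # > 1 means rounding made inference slower, < 1 means faster
        "rounded / unrounded": timings["w/ rounding"] / timings["w/o rounding"],
    }

for device, timings in results.items():
    stats = summarize(timings)
    print(device, {k: round(v, 3) for k, v in stats.items()})
```

On these numbers, rounding is about 10% slower than plain pruning on the GPU and about 4% faster on the CPU, while both pruned models are roughly 1.7 to 1.8 times faster than the original.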

It seems that whether rounding can improve the inference time depends on the hardware.
Here is my test script:

```python
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

import time
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2
import torch_pruning as tp

@torch.no_grad()  # disable autograd bookkeeping while timing
def measure_inference_time(net, input, repeat=100):
    # torch.cuda.synchronize() fails when CUDA is unavailable, so only sync on GPU
    sync = torch.cuda.synchronize if input.device.type == 'cuda' else (lambda: None)
    net(input)  # one untimed warm-up pass
    sync()
    start = time.perf_counter()
    for _ in range(repeat):
        net(input)  # time the model passed in, not a global
        sync()
    end = time.perf_counter()
    return (end - start) / repeat

device = torch.device('cpu')  # use torch.device('cuda') to time on the GPU
repeat = 100

# before pruning
model = mobilenet_v2(pretrained=True).eval()
fake_input = torch.randn(16, 3, 224, 224)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_before_pruning = measure_inference_time(model, fake_input, repeat)
print("before pruning: inference time=%f s, parameters=%d"%(inference_time_before_pruning, tp.utils.count_params(model)))

# w/o rounding: prune 20% of the filters of every Conv2d by L1 norm
model = mobilenet_v2(pretrained=True).eval()
strategy = tp.strategy.L1Strategy()
DG = tp.DependencyGraph()
fake_input = fake_input.cpu()  # the freshly loaded model is on CPU, so trace with a CPU input
DG.build_dependency(model, example_inputs=fake_input)
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        pruning_idxs = strategy(m.weight, amount=0.2)
        pruning_plan = DG.get_pruning_plan(m, tp.prune_conv, idxs=pruning_idxs)
        pruning_plan.exec()
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_without_rounding = measure_inference_time(model, fake_input, repeat)
print("w/o rounding: inference time=%f s, parameters=%d"%(inference_time_without_rounding, tp.utils.count_params(model)))

# w/ rounding: same as above, but with round_to=8
model = mobilenet_v2(pretrained=True).eval()
strategy = tp.strategy.L1Strategy()
DG = tp.DependencyGraph()
fake_input = fake_input.cpu()
DG.build_dependency(model, example_inputs=fake_input)
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        pruning_idxs = strategy(m.weight, amount=0.2, round_to=8)
        pruning_plan = DG.get_pruning_plan(m, tp.prune_conv, idxs=pruning_idxs)
        pruning_plan.exec()
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_with_rounding = measure_inference_time(model, fake_input, repeat)
print("w/ rounding: inference time=%f s, parameters=%d"%(inference_time_with_rounding, tp.utils.count_params(model)))
```
VainF commented

I think rounding is useful for CPU deployment. Maybe we should implement the rounding operation as a separate function to make the strategy clearer. Let's merge the PR first and then move the rounding into its own function.
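One way to factor this out, sketched under assumptions: the rounding lives in its own function, and the L1 strategy only asks it how many filters to drop. The names `round_prune_count` and `l1_prune_indices` are hypothetical, not Torch-Pruning's API. To keep the sketch runnable without PyTorch, the strategy takes a plain list of per-filter L1 norms instead of a weight tensor.

```python
import math
from typing import List

def round_prune_count(n_total: int, n_to_prune: int, round_to: int = 1) -> int:
    """Hypothetical helper: shrink n_to_prune so the kept count is a multiple of round_to."""
    if round_to <= 1:
        return n_to_prune
    # round the kept count up to the next multiple, capped at the layer width
    n_keep = math.ceil((n_total - n_to_prune) / round_to) * round_to
    return n_total - min(n_keep, n_total)

def l1_prune_indices(filter_l1_norms: List[float], amount: float, round_to: int = 1) -> List[int]:
    """Hypothetical L1 strategy: indices of the filters with the smallest L1 norms."""
    n_total = len(filter_l1_norms)
    n_to_prune = round_prune_count(n_total, int(amount * n_total), round_to)
    # rank filter indices by ascending L1 norm and take the weakest ones
    order = sorted(range(n_total), key=lambda i: filter_l1_norms[i])
    return sorted(order[:n_to_prune])

norms = [0.9, 0.1, 0.5, 0.3, 0.7, 0.2, 0.8, 0.4, 0.6, 0.05]
print(l1_prune_indices(norms, 0.5))              # drops five filters
print(l1_prune_indices(norms, 0.5, round_to=4))  # keeps eight, drops two
```

With ten filters and `amount=0.5`, the unrounded strategy drops five filters, while `round_to=4` keeps eight and drops only the two weakest. Rounding the kept count up means the strategy never prunes more than `amount` asks for; rounding to the nearest multiple would be an equally reasonable design choice.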

I had been thinking about refactoring the rounding too :) The function looks good, thanks!
On a single 3.0 GHz CPU core, I get about 5 ms per sample with the default mobilenet_v2.
What if you try a single image per batch?
Maybe the difference is that you aren't measuring with ONNX Runtime, or aren't restricting it to a single core/thread...
Here are the steps I used, if you want to reproduce:

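```python
import time

import onnx
import onnxruntime as rt
import torch
from torchvision.models import mobilenet_v2

device = torch.device('cpu')
batch_size = 1
save_onnx_path = 'model.onnx'  # any writable path
model = mobilenet_v2().eval()  # replace with the (pruned) model to benchmark

# Export with a dynamic batch dimension
x = torch.randn(batch_size, 3, 224, 224).to(device)
torch.onnx.export(model, x, save_onnx_path,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}},
                  opset_version=11)
onnx_model = onnx.load(save_onnx_path)
onnx.checker.check_model(onnx_model)  # sanity-check the exported graph

# Restrict ONNX Runtime to a single intra-op thread (one core)
so = rt.SessionOptions()
so.intra_op_num_threads = 1
sess = rt.InferenceSession(save_onnx_path, sess_options=so, providers=['CPUExecutionProvider'])
inputNames = [k.name for k in sess.get_inputs()]
outputNames = [k.name for k in sess.get_outputs()]

# One image per run; random stand-ins for real inputs (ONNX Runtime expects NumPy arrays)
samples = [torch.randn(1, 3, 224, 224).numpy() for _ in range(100)]
all_iters = 0
all_times = 0.0
for img in samples:
    all_iters += 1
    start = time.perf_counter()
    result = sess.run(outputNames, {inputNames[0]: img})[0]
    end = time.perf_counter()
    all_times += end - start

print('Averaged inference time=', all_times / all_iters)
```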

Here are the pruning steps and the corresponding inference times, shown on a graph:
[Figure: inference time after each pruning step]