EIDOSLAB/simplify

How to speed up the inference time?

zeyuanyin opened this issue · 2 comments

My code is:

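import copy
import time

import torch
from torchvision import models
from tqdm import tqdm
from simplify import simplify

model = models.resnet18(pretrained=True)
model.eval()
dummy_input = torch.zeros(64, 3, 224, 224)  # Tensor shape is that of a standard input for the given model
# simplify() modifies the model it is given, so simplify a copy to keep the original for comparison
simplified_model = simplify(copy.deepcopy(model), dummy_input)

def run_model(model, dummy_input):
    start = time.time()
    for _ in tqdm(range(100)):
        model(dummy_input)
    end = time.time()
    print("Time taken: ", end - start)

print("Original model")
run_model(model, dummy_input)

print("Simplified model")
run_model(simplified_model, dummy_input)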

The output is:

Original model
100%|██████████| 100/100 [01:37<00:00,  1.03it/s]
Time taken:  97.2786557674408
Simplified model
100%|██████████| 100/100 [01:39<00:00,  1.01it/s]
Time taken:  99.11441326141357

It seems there is no inference speed-up. Maybe the pre-trained model has no zeroed channels to prune.
I have a question about how to speed up inference.
Should I first prune the pre-trained model with prune.ln_structured from torch.nn.utils.prune?
I think this project is a good fit for the step that comes after torch pruning.
Can you provide a complete example of accelerated inference?
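One way to check the hypothesis that the pre-trained model has no zeroed channels is to count, for every Conv2d, the output filters whose weights are all exactly zero. The sketch below uses only plain PyTorch; the helper name count_zero_filters is hypothetical and not part of simplify.

```python
import torch
import torch.nn as nn


def count_zero_filters(model: nn.Module) -> dict:
    """Return {layer_name: (zeroed_filters, total_filters)} for every Conv2d.

    A filter (output channel) counts as zeroed when all of its weights are 0,
    which is what structured pruning along dim=0 produces.
    """
    stats = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            w = module.weight.detach()
            # Flatten each output filter and check whether it is entirely zero.
            zero_mask = (w.flatten(start_dim=1) == 0).all(dim=1)
            stats[name] = (int(zero_mask.sum()), w.shape[0])
    return stats


if __name__ == "__main__":
    torch.manual_seed(0)
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
    with torch.no_grad():
        net[0].weight[:3] = 0  # zero three filters by hand
    print(count_zero_filters(net))  # {'0': (3, 8), '2': (0, 4)}
```

Running the same helper on models.resnet18(pretrained=True) would show whether any layer actually has zeroed filters; if every count is 0, there is nothing for simplify to remove.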

@zeyuanyin Hi, indeed there may be no zeroed channels. You can test this by applying a structured pruning method such as prune.ln_structured (or prune.random_structured, as below) to a newly initialized network. Here is some sample code:

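import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torchvision.models import resnet18
from simplify import simplify


def test_simplify(x, fuse_bn=True):
    model = resnet18(pretrained=False)
    model.eval()

    # Zero 80% of the output channels (dim=0) of every Conv2d, then make the pruning permanent
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.random_structured(module, 'weight', amount=0.8, dim=0)
            prune.remove(module, 'weight')

    y_src = model(x)  # output of the pruned, not yet simplified, model
    zeros = torch.zeros(1, *x.shape[1:])

    simplify(model, zeros, fuse_bn=fuse_bn, training=True)  # modifies model in place
    y_prop = model(x)

    # The simplified model should predict the same classes as the pruned one
    return torch.equal(y_src.argmax(dim=1), y_prop.argmax(dim=1))


print(test_simplify(torch.randn(8, 3, 224, 224)))  # x: any batch of 3x224x224 images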
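To see what structured pruning leaves behind before simplify runs, the sketch below applies prune.ln_structured with n=1 (L1 norm) along dim=0 to a single, arbitrarily sized Conv2d and counts the zeroed filters. With 10 output channels and amount=0.8, PyTorch prunes round(0.8 * 10) = 8 channels. Note that only the weight is pruned: the bias is untouched, so a pruned channel of a conv with bias still outputs a constant rather than zero.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(16, 10, kernel_size=3)  # arbitrary sizes for illustration

# Zero the 80% of output channels (dim=0) with the smallest L1 norm (n=1).
prune.ln_structured(conv, name="weight", amount=0.8, n=1, dim=0)
# Make the pruning permanent: removes weight_orig / weight_mask and leaves
# a plain `weight` parameter that contains the zeros.
prune.remove(conv, "weight")

zeroed = (conv.weight.detach().flatten(1) == 0).all(dim=1)
print(f"{int(zeroed.sum())} of {conv.out_channels} filters are zero")  # 8 of 10
```

The dense layer still has the same shape and cost at this point; the zeros only become a speed-up once the channels are physically removed.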

At https://github.com/EIDOSLAB/simplify/tree/18ec7921dc04bb4fb604d994e257f087fde096bf/benchmark you can find some old benchmark functions that we wrote during development (we should port them to the newer version as soon as we have time); they should still work.
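The linked benchmark functions are not reproduced here. As a rough sketch of a fairer timing loop for comparing a model before and after simplify, the hypothetical helper below (its name and parameters are not from simplify) disables autograd, does a few warm-up passes before starting the clock, and synchronizes CUDA when the input lives on a GPU, since CUDA kernels run asynchronously.

```python
import time

import torch
import torch.nn as nn


def benchmark(model: nn.Module, x: torch.Tensor, n_warmup: int = 3, n_runs: int = 10) -> float:
    """Return the mean wall-clock seconds per forward pass."""
    model.eval()
    with torch.inference_mode():  # no autograd graph is recorded
        for _ in range(n_warmup):  # warm-up: allocator, lazy init, caches
            model(x)
        if x.is_cuda:
            torch.cuda.synchronize()  # wait for queued kernels before timing
        start = time.perf_counter()
        for _ in range(n_runs):
            model(x)
        if x.is_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    return elapsed / n_runs


if __name__ == "__main__":
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
    print(f"{benchmark(net, torch.zeros(4, 3, 32, 32)) * 1e3:.2f} ms per forward")
```

Calling it once on the pruned model and once after simplify, with the same input, gives a per-forward comparison that is less sensitive to first-call overhead than timing a single cold loop.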

Let me know if you have any more questions.

@AndreaBrg Thank you so much!

It works well with structured pruning (prune.random_structured and prune.ln_structured).
Here is my code for other users' reference:

import time

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torchvision import models
from tqdm import tqdm
from simplify import simplify


def run_model(model, dummy_input):
    start = time.time()
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for _ in tqdm(range(10)):
            y = model(dummy_input)
    end = time.time()
    print("Time taken: ", end - start)
    return y


dummy_input = torch.zeros(64, 3, 224, 224)  # Tensor shape is that of a standard input for the given model
model = models.resnet18(pretrained=True)
model.eval()


# for name, module in model.named_modules():
#     if isinstance(module, nn.Conv2d):
#         prune.random_structured(module, 'weight', amount=0.8, dim=0)
#         prune.remove(module, 'weight')


def prune_model_l1_structured(model, layer_type, proportion):
    # Zero `proportion` of the output channels (dim=0) with the smallest L1 norm, then make it permanent
    for module in model.modules():
        if isinstance(module, layer_type):
            prune.ln_structured(module, 'weight', proportion, n=1, dim=0)
            prune.remove(module, 'weight')
    return model

prune_model_l1_structured(model, nn.Conv2d, 0.8)
# prune_model_l1_structured(model, nn.Linear, 0.3)  # skipped: ResNet's only Linear layer is the classification layer


print("Pruned model (before simplify)")
y_src = run_model(model, dummy_input)
# y_src = model(dummy_input)

zeros = torch.zeros(1, *dummy_input.shape[1:])
simplify(model, zeros)  # removes the zeroed channels, modifying model in place


print("Pruned and simplified model")
y_prop = run_model(model, dummy_input)
# y_prop = model(dummy_input)

# check that the predicted classes are the same
print(torch.equal(y_src.argmax(dim=1), y_prop.argmax(dim=1)))

Terminal output: (two screenshots, not preserved)
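Why the simplified model can be faster while the pruned one is not: torch.nn.utils.prune only writes zeros into the weights, so the dense conv still does the same amount of work, which matches the unchanged timings in the original question. The speed-up comes from physically shrinking the layers. The sketch below illustrates the idea for two stacked convs joined by a ReLU, with no bias and no BatchNorm; the helper drop_zero_filters is hypothetical and this is not how simplify is implemented (the library also has to deal with biases, BatchNorm and residual connections, which this sketch ignores).

```python
from typing import Tuple

import torch
import torch.nn as nn


def drop_zero_filters(conv_a: nn.Conv2d, conv_b: nn.Conv2d) -> Tuple[nn.Conv2d, nn.Conv2d]:
    """Remove the all-zero output filters of conv_a and the matching input
    channels of conv_b. Assumes bias=False, groups=1, and that conv_a feeds
    conv_b through an elementwise op f with f(0) == 0 (e.g. ReLU).
    """
    keep = conv_a.weight.detach().flatten(1).abs().sum(dim=1) != 0
    idx = keep.nonzero().flatten()  # indices of surviving filters

    new_a = nn.Conv2d(conv_a.in_channels, len(idx), conv_a.kernel_size,
                      stride=conv_a.stride, padding=conv_a.padding,
                      dilation=conv_a.dilation, bias=False)
    new_b = nn.Conv2d(len(idx), conv_b.out_channels, conv_b.kernel_size,
                      stride=conv_b.stride, padding=conv_b.padding,
                      dilation=conv_b.dilation, bias=False)
    with torch.no_grad():
        new_a.weight.copy_(conv_a.weight[idx])     # keep surviving filters
        new_b.weight.copy_(conv_b.weight[:, idx])  # keep matching input channels
    return new_a, new_b


if __name__ == "__main__":
    torch.manual_seed(0)
    a = nn.Conv2d(3, 16, 3, padding=1, bias=False)
    b = nn.Conv2d(16, 8, 3, padding=1, bias=False)
    with torch.no_grad():
        a.weight[::2] = 0  # zero 8 of the 16 filters
    x = torch.randn(2, 3, 16, 16)
    a2, b2 = drop_zero_filters(a, b)
    same = torch.allclose(b(torch.relu(a(x))), b2(torch.relu(a2(x))), atol=1e-5)
    print(a2.out_channels, b2.in_channels, same)
```

Zeroed filters of conv_a always produce zero feature maps, and ReLU keeps them at zero, so their contribution to conv_b is exactly zero and both layers can shrink without changing the output.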