
Simplify your ONNX model.

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

ONNX Simplifier


ONNX is great, but sometimes too complicated.

Background

One day I wanted to export the following simple reshape operation to ONNX:

import torch


class OnlyReshape(torch.nn.Module):
    def __init__(self):
        super(OnlyReshape, self).__init__()

    def forward(self, x):
        return x.view((x.shape[0], x.shape[1], x.shape[3], x.shape[2]))


net = OnlyReshape()
model_name = 'only_reshape.onnx'
dummy_input = torch.randn(2, 3, 4, 5)
torch.onnx.export(net, dummy_input, model_name, input_names=['input'], output_names=['output'])

Since the input shape in the exported ONNX model is static, I expected the model to contain nothing but a single Reshape:

[Figure: simple_reshape]

However, I got the following complicated model even after polishing it:

[Figure: complicated_reshape]

Moreover, some ONNX models also contain operations performed on weights (e.g., this). As pointed out in onnx/onnx#1758 and JDAI-CV/DNNLibrary#26, all of these operations can be eliminated by computing them offline.
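As a minimal illustration of folding a weight operation offline, the sketch below assumes a hypothetical model that stores a constant weight `W` and applies a `Transpose` to it at runtime before a `MatMul`. Because `W` never changes, the transpose can be computed once ahead of time. The helpers (`transpose`, `matmul`, `runtime_forward`, `folded_forward`) and the values are made up for illustration and use plain nested lists rather than any ONNX API.

```python
from typing import List

Matrix = List[List[int]]


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # Dot each row of a with each column of b.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# Hypothetical constant weight stored as (out_features, in_features) = (2, 3).
W: Matrix = [[1, 2, 3], [4, 5, 6]]


def runtime_forward(x: Matrix) -> Matrix:
    # Recomputes Transpose(W) on every call even though W never changes.
    return matmul(x, transpose(W))


# Offline computation: fold Transpose(W) once into a new constant.
W_T: Matrix = transpose(W)


def folded_forward(x: Matrix) -> Matrix:
    # Same result, but the Transpose step is gone from the forward pass.
    return matmul(x, W_T)


x = [[1, 1, 1]]
print(runtime_forward(x), folded_forward(x))  # [[6, 15]] [[6, 15]]
```

The folded version produces the same output while doing strictly less work per inference, which is why such weight-only subgraphs are safe to precompute.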

Our solution

ONNX Simplifier simplifies ONNX models. It infers the whole computation graph and then replaces the redundant operators with their constant outputs, a technique commonly known as constant folding.
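To make the idea concrete, here is a toy constant-folding pass in plain Python. It is a sketch under stated assumptions, not onnxsim's actual implementation: the tuple-based graph format, the hypothetical `fold_constants` helper, and the small op set (`Shape`, `Gather`, `Unsqueeze`, `Concat`, `Reshape`) are simplifications of real ONNX semantics, and the node chain below is only an approximation of what the exporter might emit for the `OnlyReshape` example. Given the static input shape (2, 3, 4, 5), every shape-computing node folds into a constant and only the `Reshape` survives.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical toy graph format (an assumption, not ONNX's protobuf):
# each node is (op_type, input_names, output_name).
Node = Tuple[str, List[str], str]

# Evaluators for the ops this toy pass knows how to fold (simplified semantics).
EVAL: Dict[str, Callable[..., Any]] = {
    "Gather": lambda data, index: data[index],  # pick one dimension
    "Unsqueeze": lambda scalar: [scalar],       # scalar -> 1-element list
    "Concat": lambda *parts: [v for p in parts for v in p],
}


def fold_constants(nodes: List[Node],
                   input_shapes: Dict[str, Tuple[int, ...]],
                   constants: Dict[str, Any]):
    """Fold every node whose inputs are all known; keep the rest.

    Nodes must be in topological order. Returns the kept nodes and the
    constants that the kept nodes still read.
    """
    consts = dict(constants)
    kept: List[Node] = []
    for op, inputs, output in nodes:
        if op == "Shape" and inputs[0] in input_shapes:
            # A statically known input shape turns Shape into a constant.
            consts[output] = list(input_shapes[inputs[0]])
        elif op in EVAL and all(name in consts for name in inputs):
            consts[output] = EVAL[op](*(consts[name] for name in inputs))
        else:
            kept.append((op, inputs, output))
    used = {name for _, ins, _ in kept for name in ins if name in consts}
    return kept, {name: consts[name] for name in used}


# Approximation (assumption) of the exported OnlyReshape graph:
# Shape -> Gather -> Unsqueeze for dims 0, 1, 3, 2, then Concat -> Reshape.
constants: Dict[str, Any] = {}
nodes: List[Node] = []
for k, dim in enumerate([0, 1, 3, 2]):
    constants[f"idx{k}"] = dim
    nodes += [("Shape", ["input"], f"shape{k}"),
              ("Gather", [f"shape{k}", f"idx{k}"], f"dim{k}"),
              ("Unsqueeze", [f"dim{k}"], f"unsq{k}")]
nodes.append(("Concat", ["unsq0", "unsq1", "unsq2", "unsq3"], "target_shape"))
nodes.append(("Reshape", ["input", "target_shape"], "output"))

kept, used = fold_constants(nodes, {"input": (2, 3, 4, 5)}, constants)
print(kept)  # [('Reshape', ['input', 'target_shape'], 'output')]
print(used)  # {'target_shape': [2, 3, 5, 4]}
```

The design choice here is conservative: a node is folded only when all of its inputs are already known constants (or, for `Shape`, when its input shape is statically known). Anything that reads runtime data, such as the `Reshape` of `input`, stays in the graph.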

Just install it via pip (Python >= 3.5):

pip3 install onnx-simplifier

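Then simplify a model by passing the path of the original model and the path where the simplified model should be written: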

python3 -m onnxsim input_model output_model

Results

An overall comparison between a complicated model and its simplified version:

[Figure: Comparison between old model and new model]