Note: this works with PyTorch 1.7.0, but not with PyTorch 1.3.1.
This is a profiler that counts the number of MACs / FLOPs of PyTorch models, based on torch.jit.trace.
- It is more general than ONNX-based profilers as some operations in PyTorch are not supported by ONNX for now.
- It is more accurate than hook-based profilers, as hooks only fire on module calls and cannot profile operations (such as plain torch.* function calls) performed inside a torch.nn.Module's forward.
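To illustrate the difference with hook-based profiling, here is a minimal sketch. It assumes a recent PyTorch install; the module ScaledProjection and both helper functions are hypothetical, and the tracing code uses torch.jit.trace directly rather than torchprofile's internals. The model's forward mixes a submodule (nn.Linear) with a bare torch.matmul call. Forward hooks registered on leaf modules fire only for the nn.Linear, while the graph recorded by torch.jit.trace also contains the aten::matmul operator.

```python
import torch
import torch.nn as nn


class ScaledProjection(nn.Module):
    """Hypothetical module: an nn.Linear followed by a raw matmul inside forward()."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # A bare parameter: no submodule wraps the matmul that uses it.
        self.mix = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.matmul(self.proj(x), self.mix)


def modules_seen_by_hooks(model: nn.Module, x: torch.Tensor) -> list:
    """Record the class name of every leaf module whose forward hook fires."""
    seen = []
    handles = [
        m.register_forward_hook(lambda mod, inp, out: seen.append(type(mod).__name__))
        for m in model.modules()
        if len(list(m.children())) == 0  # leaf modules only
    ]
    model(x)
    for h in handles:
        h.remove()  # leave the model unmodified
    return seen


def ops_seen_by_trace(model: nn.Module, x: torch.Tensor) -> list:
    """Return the aten operator kinds recorded by torch.jit.trace."""
    traced = torch.jit.trace(model, x)
    # inlined_graph expands submodule calls so their operators are visible too
    return [n.kind() for n in traced.inlined_graph.nodes() if n.kind().startswith("aten::")]


if __name__ == "__main__":
    model = ScaledProjection()
    x = torch.randn(2, 8)
    print(modules_seen_by_hooks(model, x))  # ['Linear']: the matmul is invisible to hooks
    print(ops_seen_by_trace(model, x))  # includes 'aten::matmul'
```

The operator recorded for nn.Linear itself varies between PyTorch versions (for example aten::linear in newer releases, aten::addmm in older ones), so only the presence of aten::matmul is the point here.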
Install the package with pip:
pip install torchprofile
You should first define your PyTorch model and its (dummy) input:
import torch
from torchvision.models import resnet18

model = resnet18()
inputs = torch.randn(1, 3, 224, 224)  # dummy batch: one 3-channel 224x224 image
You can then measure the number of MACs using profile_macs:
from torchprofile import profile_macs
macs = profile_macs(model, inputs)  # total number of MACs for one forward pass
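profile_macs reports a single total. As a rough sanity check, per-layer counts can be worked out by hand. The sketch below uses the standard textbook formulas for convolution and fully connected layers (not torchprofile's own per-operator handlers, which may count some operators differently) and reports FLOPs under the common, assumed convention that one MAC equals two FLOPs. All function names are hypothetical helpers; the layer shapes are those of ResNet-18's first convolution and final fully connected layer.

```python
def conv2d_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a 2-D convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv2d_macs(in_channels: int, out_channels: int, kernel_size: tuple,
                out_h: int, out_w: int, groups: int = 1, batch: int = 1) -> int:
    """Each output element needs (in_channels / groups) * kh * kw multiply-accumulates.
    Bias additions are ignored."""
    kh, kw = kernel_size
    return batch * out_channels * out_h * out_w * (in_channels // groups) * kh * kw


def linear_macs(in_features: int, out_features: int, batch: int = 1) -> int:
    """A fully connected layer needs in_features MACs per output feature (bias ignored)."""
    return batch * in_features * out_features


def summarize_macs(macs: int) -> str:
    """Report MACs in billions, plus FLOPs under the assumed convention 1 MAC = 2 FLOPs."""
    return f"{macs / 1e9:.3f} GMACs (~{2 * macs / 1e9:.3f} GFLOPs)"


if __name__ == "__main__":
    # First convolution of ResNet-18: 3 -> 64 channels, 7x7 kernel, stride 2, padding 3.
    side = conv2d_output_size(224, kernel=7, stride=2, padding=3)
    conv1 = conv2d_macs(3, 64, (7, 7), side, side)
    # Final fully connected layer of ResNet-18: 512 features -> 1000 classes.
    fc = linear_macs(512, 1000)
    print(side)  # 112
    print(conv1)  # 118013952
    print(fc)  # 512000
    print(summarize_macs(conv1))  # 0.118 GMACs (~0.236 GFLOPs)
```

The same summarize_macs helper could be applied to the total returned by profile_macs to print it in a readable form.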
Here, inputs are passed as positional arguments only; keyword arguments are not supported.
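If a model's forward needs keyword arguments, one possible workaround (a sketch, not part of torchprofile) is to wrap it in a small module whose forward takes only positional inputs and passes them on as keywords. TwoInputModel and PositionalAdapter are hypothetical names used for illustration.

```python
import torch
import torch.nn as nn


class TwoInputModel(nn.Module):
    """Hypothetical model whose forward takes an optional keyword argument."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x, mask=None):
        out = self.fc(x)
        if mask is not None:
            out = out * mask  # apply the mask element-wise
        return out


class PositionalAdapter(nn.Module):
    """Hypothetical wrapper: accepts positional inputs only and forwards them as keywords."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x, mask):
        return self.model(x, mask=mask)


if __name__ == "__main__":
    model = TwoInputModel()
    x, mask = torch.randn(2, 4), torch.ones(2, 4)
    adapter = PositionalAdapter(model)
    # The adapter is traceable with a tuple of positional inputs.
    traced = torch.jit.trace(adapter, (x, mask))
    print(torch.allclose(traced(x, mask), model(x, mask=mask)))
```

The wrapped model could then be profiled with something like profile_macs(PositionalAdapter(model), (x, mask)), passing the positional inputs as a tuple. Keep in mind that tracing records only the branch taken for the given inputs, so here the mask=None path would not be profiled.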
This repository is released under the MIT license. See LICENSE for additional details.