[中文README | README in Chinese]
Torch-Pruning (TP) is a versatile library for Structural Network Pruning with the following features:
- General-purpose Pruning Toolkit: TP enables structural pruning for a wide range of neural networks, including Large Language Models (LLMs), Diffusion Models, Vision Transformers, YOLOv7, YOLOv8, FasterRCNN, SSD, KeypointRCNN, MaskRCNN, ResNe(X)t, ConvNext, DenseNet, RegNet, FCN, DeepLab, etc. Unlike torch.nn.utils.prune, which zeroizes parameters via masking, Torch-Pruning deploys a (non-deep) graph algorithm called DepGraph to physically remove parameters and channels.
- Reproducible Performance Benchmark and Prunability Benchmark: Currently, TP is able to prune approximately 81/85 = 95.3% of the models from Torchvision 0.13.1. Try this Colab Demo for a quick start.
For more technical details, please refer to our CVPR'23 paper:
DepGraph: Towards Any Structural Pruning
Gongfan Fang, Xinyin Ma, Mingli Song, Michael Bi Mi, Xinchao Wang
- 2023.05.20 :rocket: LLM-Pruner: On the Structural Pruning of Large Language Models [arXiv]
- 2023.05.19 Structural Pruning for Diffusion Models [arXiv]
- 2023.04.21 Join our Telegram or WeChat group for casual discussions:
- Telegram: https://t.me/+NwjbBDN2ao1lZjZl
- WeChat:
- 2023.04.15 Pruning and Post-training for YOLOv7 / YOLOv8
Please do not hesitate to open a discussion or issue if you encounter any problems with the library or the paper.
- Structural pruning for CNNs, Transformers, Detectors, Language Models and Diffusion Models. Please refer to the Prunability Benchmark.
- High-level pruners: MagnitudePruner, BNScalePruner, GroupNormPruner, RandomPruner, etc.
- Importance Criteria: L-p Norm, Taylor, Random, BNScaling, etc.
- Dependency Graph for dependency modeling.
- Supported modules: Linear, (Transposed) Conv, Normalization, PReLU, Embedding, MultiheadAttention, nn.Parameter, and customized modules.
- Supported operators: split, concatenation, skip connection, flatten, reshape, view, all element-wise ops, etc.
- Low-level pruning functions
- Benchmarks and tutorials
- A resource list for practical structural pruning.
- A strong baseline with bags of tricks from existing methods.
- A benchmark for Torchvision compatibility (81/85=95.3%, ✔️) and timm compatibility.
- Pruning from Scratch / at Initialization.
- More high-level pruners like FisherPruner, GrowingReg, etc.
- More Transformers like Vision Transformers (:heavy_check_mark:), Swin Transformers, PoolFormers.
- Block/Layer/Depth Pruning
- Pruning benchmarks for CIFAR, ImageNet and COCO.
Torch-Pruning is compatible with PyTorch 1.x and 2.x. PyTorch 1.12.1 is recommended!
pip install torch-pruning # v1.1.8
or
git clone https://github.com/VainF/Torch-Pruning.git
Here we provide a quick start for Torch-Pruning. More detailed explanations can be found in the tutorials.
In structural pruning, a Group is defined as the minimal removable unit within deep networks. Each group consists of multiple interdependent layers that must be pruned simultaneously to preserve the integrity of the resulting structures. However, deep networks often exhibit intricate dependencies among layers, which poses a significant challenge for structural pruning. This work addresses the challenge with an automated mechanism called DepGraph, which enables effortless parameter grouping and facilitates pruning for a diverse range of deep networks.
import torch
from torchvision.models import resnet18
import torch_pruning as tp
model = resnet18(pretrained=True).eval()
# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))
# 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )
# 3. prune all grouped layers that are coupled with model.conv1 (included).
if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
    group.prune()
# 4. Save & Load
model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the model object
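As a quick sanity check (our addition, assuming the three channels above were removed): conv1 should now report 61 output channels, while the forward pass still produces 1000 logits since the classifier is untouched.
print(model.conv1) # Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape) # torch.Size([1, 1000])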
The above example demonstrates the fundamental pruning pipeline with DepGraph. The target layer model.conv1 is coupled with several other layers, which require simultaneous removal in structural pruning. Let's print the group and observe how one pruning operation "triggers" others. In the following outputs, "A => B" means that pruning operation A triggers pruning operation B. group[0] refers to the pruning root passed to DG.get_pruning_group.
--------------------------------
Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
--------------------------------
For more details about grouping, please refer to tutorials/2 - Exploring Dependency Groups
We can use DG.get_all_groups(ignored_layers, root_module_types) to scan all groups sequentially. Each group begins with a layer whose type matches one of the root_module_types. Note that DG.get_all_groups is only responsible for grouping and has no knowledge of which parameters should be pruned, so you must specify the pruning indices via group.prune(idxs=idxs).
import torch.nn as nn

for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):
    # handle groups in sequential order
    idxs = [2, 4, 6] # your pruning indices
    group.prune(idxs=idxs)
    print(group)
Leveraging the DependencyGraph, we developed several high-level pruners in this repository to facilitate effortless pruning. By specifying the desired channel sparsity, you can prune the entire model and fine-tune it using your own training code. For detailed information on this process, please refer to this tutorial, which shows how to implement a slimming pruner from scratch. Additionally, you can find more practical examples in benchmarks/main.py.
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...
Some pruners, such as BNScalePruner and GroupNormPruner, require sparse training before pruning. This can be achieved by inserting a single line, pruner.regularize(model), into your training script; the pruner will update the gradients of the trainable parameters accordingly.
for epoch in range(epochs):
    model.train()
    for i, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, target)
        loss.backward()
        pruner.regularize(model) # <== for sparse learning
        optimizer.step()
All high-level pruners support interactive pruning. Use pruner.step(interactive=True) to get all groups and prune them interactively by calling group.prune(). This feature is useful if you want to control or monitor the pruning process.
for i in range(iterative_steps):
    for group in pruner.step(interactive=True): # Warning: groups must be handled sequentially. Do not keep them as a list.
        print(group)
        # do whatever you like with the group
        dep, idxs = group[0] # get the idxs
        target_module = dep.target.module # get the root module
        pruning_fn = dep.handler # get the pruning function
        # Don't forget to prune the group
        group.prune()
        # group.prune(idxs=[0, 2, 6]) # It is even possible to change the pruning behaviour with the idxs parameter
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...
With DepGraph, it is easy to design "group-level" criteria that estimate the importance of an entire group rather than a single layer. In Torch-Pruning, all pruners operate at the group level.
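For instance, a minimal sketch of a custom group-level magnitude criterion might look like the following (the GroupL1Importance class name and the restriction to Conv2d out-channels are our illustrative assumptions; TP ships its own importance classes such as tp.importance.TaylorImportance):
import torch
import torch.nn as nn
import torch_pruning as tp

class GroupL1Importance:
    # Hypothetical group-level criterion (illustrative sketch, not a built-in TP class):
    # for each candidate channel, compute the L1 norm of its weight slice in every
    # coupled Conv2d of the group, then average the scores across layers.
    def __call__(self, group, **kwargs):
        group_imp = []
        for dep, idxs in group:
            layer = dep.target.module
            prune_fn = dep.handler
            if isinstance(layer, nn.Conv2d) and prune_fn == tp.prune_conv_out_channels:
                w = layer.weight.data[idxs].flatten(1) # [len(idxs), C_in*kH*kW]
                group_imp.append(w.abs().sum(1))       # per-channel L1 norm
        if len(group_imp) == 0:
            return None # no prunable conv out-channels in this group
        return torch.stack(group_imp, dim=0).mean(dim=0) # aggregate over the group

# imp = GroupL1Importance() # can then be passed to a high-level pruner via importance=imp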
The following script saves the whole model object (structure + weights) as 'model.pth'.
model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the pruned model
Experimental Features: re-create pruned models from unpruned ones using tp.state_dict and tp.load_state_dict.
# save the pruned state_dict, which includes both pruned parameters and modified attributes
state_dict = tp.state_dict(pruned_model) # the pruned model, e.g., a resnet-18-half
torch.save(state_dict, 'pruned.pth')
# create a new model, e.g. resnet18
new_model = resnet18().eval()
# load the pruned state_dict into the unpruned model.
loaded_state_dict = torch.load('pruned.pth', map_location='cpu')
tp.load_state_dict(new_model, state_dict=loaded_state_dict)
print(new_model) # This will be a pruned model.
Refer to tests/test_serialization.py for a ViT example, in which we prune the model and modify some attributes such as model.hidden_dims.
While it is possible to manually prune your model using low-level functions, this approach can be quite laborious, as it requires careful management of the associated dependencies. As a result, we recommend utilizing the aforementioned high-level pruners to streamline the pruning process.
tp.prune_conv_out_channels( model.conv1, idxs=[2,6,9] )
# fix the broken dependencies manually
tp.prune_batchnorm_out_channels( model.bn1, idxs=[2,6,9] )
tp.prune_conv_in_channels( model.layer2[0].conv1, idxs=[2,6,9] )
...
The following pruning functions are available:
'prune_conv_out_channels',
'prune_conv_in_channels',
'prune_depthwise_conv_out_channels',
'prune_depthwise_conv_in_channels',
'prune_batchnorm_out_channels',
'prune_batchnorm_in_channels',
'prune_linear_out_channels',
'prune_linear_in_channels',
'prune_prelu_out_channels',
'prune_prelu_in_channels',
'prune_layernorm_out_channels',
'prune_layernorm_in_channels',
'prune_embedding_out_channels',
'prune_embedding_in_channels',
'prune_parameter_out_channels',
'prune_parameter_in_channels',
'prune_multihead_attention_out_channels',
'prune_multihead_attention_in_channels',
'prune_groupnorm_out_channels',
'prune_groupnorm_in_channels',
'prune_instancenorm_out_channels',
'prune_instancenorm_in_channels',
Please refer to tests/test_customized_layer.py.
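As a rough sketch of the idea (the Scale layer and ScalePruner below are hypothetical examples of our own; the registration call follows the pattern used in the test file, and the exact API may vary across versions), a customized pruner tells TP how to slice a layer along its channel dimension:
import torch
import torch.nn as nn
import torch_pruning as tp

class Scale(nn.Module):
    # A toy customized layer with one per-channel parameter (hypothetical example).
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.scale = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        return x * self.scale

class ScalePruner(tp.pruner.BasePruningFunc):
    # Remove the selected channels from the per-channel parameter.
    def prune_out_channels(self, layer, idxs):
        keep_idxs = sorted(set(range(layer.dim)) - set(idxs))
        layer.scale = nn.Parameter(layer.scale.data[keep_idxs])
        layer.dim = len(keep_idxs)
        return layer
    prune_in_channels = prune_out_channels # in == out for this layer
    def get_out_channels(self, layer):
        return layer.dim
    get_in_channels = get_out_channels

# Register the customized pruner before building the dependency graph, e.g.:
# DG = tp.DependencyGraph()
# DG.register_customized_layer(Scale, ScalePruner())
# DG.build_dependency(model, example_inputs=...)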
Our results on {ResNet-56 / CIFAR-10 / 2.00x}
| Method | Base (%) | Pruned (%) | Δ Acc (%) | Speed Up |
|:--|:--:|:--:|:--:|:--:|
| NIPS [1] | - | - | -0.03 | 1.76x |
| Geometric [2] | 93.59 | 93.26 | -0.33 | 1.70x |
| Polar [3] | 93.80 | 93.83 | +0.03 | 1.88x |
| CP [4] | 92.80 | 91.80 | -1.00 | 2.00x |
| AMC [5] | 92.80 | 91.90 | -0.90 | 2.00x |
| HRank [6] | 93.26 | 92.17 | -0.09 | 2.00x |
| SFP [7] | 93.59 | 93.36 | +0.23 | 2.11x |
| ResRep [8] | 93.71 | 93.71 | +0.00 | 2.12x |
| Ours-L1 | 93.53 | 92.93 | -0.60 | 2.12x |
| Ours-BN | 93.53 | 93.29 | -0.24 | 2.12x |
| Ours-Group | 93.53 | 93.77 | +0.38 | 2.13x |
Please refer to benchmarks for more details.
LLM-Pruner: On the Structural Pruning of Large Language Models [Project] [arXiv]
Xinyin Ma, Gongfan Fang, Xinchao Wang
Structural Pruning for Diffusion Models [Project] [arXiv]
Gongfan Fang, Xinyin Ma, Xinchao Wang
@inproceedings{fang2023depgraph,
title={Depgraph: Towards any structural pruning},
author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={16091--16101},
year={2023}
}