Torch-Pruning (TP) is a library for structural pruning with the following features:
- General-purpose Pruning Toolkit: TP enables structural pruning for a wide range of deep neural networks, including Large Language Models (LLMs), Diffusion Models, YOLOv7, YOLOv8, Vision Transformers, Swin Transformers, BERT, FasterRCNN, SSD, ResNe(X)t, ConvNext, DenseNet, RegNet, DeepLab, etc. Unlike torch.nn.utils.prune, which zeroes out parameters through masking, Torch-Pruning uses an algorithm called DepGraph to remove parameters physically. Check our Paper List for more details.
- Examples: Play around with off-the-shelf models from Huggingface Transformers, Timm, Torchvision, Yolo, etc.
- Benchmark: Reproduce our results in the DepGraph paper.
For more technical details, please refer to our CVPR'23 paper:
DepGraph: Towards Any Structural Pruning
Gongfan Fang, Xinyin Ma, Mingli Song, Michael Bi Mi, Xinchao Wang
Learning and Vision Lab, National University of Singapore
- 2023.09.06 Pruning & Finetuning Examples for Vision Transformers, Swin Transformers, BERT.
- 2023.07.19 🚀 Support LLaMA, LLaMA-2, Vicuna, Baichuan, Bloom in LLM-Pruner
- 2023.05.20 🚀 LLM-Pruner: On the Structural Pruning of Large Language Models [arXiv]
- 2023.05.19 Structural Pruning for Diffusion Models [arXiv]
- 2023.04.15 Pruning and Post-training for YOLOv7 / YOLOv8
- High-level Pruners: MetaPruner, MagnitudePruner, BNScalePruner, GroupNormPruner, GrowingRegPruner, RandomPruner, etc. A paper list is available on our wiki page.
- Dependency Graph for automatic structural pruning
- Low-level pruning functions
- Supported Importance Criteria: L-p Norm, Taylor, Random, BNScaling, etc.
- Supported modules: Linear, (Transposed) Conv, Normalization, PReLU, Embedding, MultiheadAttention, nn.Parameters, customized modules and nested/composed modules.
- Supported operators: split, concatenation, skip connection, flatten, reshape, view, all element-wise ops, etc.
- Benchmarks, Tutorials and Examples
Please do not hesitate to open a discussion or issue if you encounter any problems with the library or the paper.
Or join our Discord or WeChat group for a chat.
Torch-Pruning is compatible with both PyTorch 1.x and 2.x versions. However, it is highly recommended to use PyTorch 1.12.1 or higher.
pip install torch-pruning
or
git clone https://github.com/VainF/Torch-Pruning.git
Here we provide a quick start for Torch-Pruning. More detailed explanations can be found in the Tutorials.
In structural pruning, a "Group" is defined as the minimal unit that can be removed within deep networks. These groups are composed of multiple layers that are interdependent and need to be pruned together in order to maintain the integrity of the resulting structures. However, deep networks often have complex dependencies among their layers, making structural pruning a challenging task. This work addresses this challenge by introducing an automated mechanism called "DepGraph." DepGraph allows for seamless parameter grouping and facilitates pruning in various types of deep networks.
Please ensure that AutoGrad is enabled for your model, i.e., do not run it under torch.no_grad or set .requires_grad=False.
import torch
from torchvision.models import resnet18
import torch_pruning as tp
model = resnet18(pretrained=True).eval()
# 1. Build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))
# 2. Group coupled layers for model.conv1
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )
# 3. Prune grouped layers altogether
if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
    group.prune()
# 4. Save & Load
model.zero_grad() # clear gradients
torch.save(model, 'model.pth') # We can not use .state_dict as the model structure is changed.
model = torch.load('model.pth') # load the pruned model
The above example demonstrates the basic pruning pipeline with DepGraph. The target layer model.conv1 is coupled with multiple layers, which must be pruned simultaneously during structural pruning. To observe the cascading effect of pruning operations, we can print the group and see how one pruning operation "triggers" others. In the output below, "A => B" indicates that pruning operation "A" triggers pruning operation "B". group[0] refers to the pruning root in DG.get_pruning_group. For more details about grouping, please refer to Wiki - DepGraph & Group.
--------------------------------
Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
--------------------------------
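The "A => B" cascade printed above can be mimicked with a plain graph traversal. The sketch below is a toy illustration only, not the DepGraph implementation: the TRIGGERS map is hand-written from a few edges of the printout, and collect_group is a hypothetical helper.

```python
from collections import deque

# Hand-written subset of the "A => B" edges printed above (illustrative only).
TRIGGERS = {
    "conv1.out": ["bn1.out"],
    "bn1.out": ["relu.out"],
    "relu.out": ["maxpool.out"],
    "maxpool.out": ["add1.out", "layer1.0.conv1.in"],
    "add1.out": ["layer1.0.bn2.out"],
    "layer1.0.bn2.out": ["layer1.0.conv2.out"],
}


def collect_group(root, triggers):
    """Breadth-first walk returning every operation reachable from root."""
    group, queue, seen = [], deque([root]), {root}
    while queue:
        op = queue.popleft()
        group.append(op)
        for nxt in triggers.get(op, []):
            if nxt not in seen:  # visit each pruning operation once
                seen.add(nxt)
                queue.append(nxt)
    return group


group = collect_group("conv1.out", TRIGGERS)
for i, op in enumerate(group):
    print(f"[{i}] {op}")
```

As in the real printout, the first entry is the pruning root and every later entry is an operation that must be applied with the same indices.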
We can use DG.get_all_groups(ignored_layers, root_module_types) to scan and prune all groups sequentially. Each group begins with a layer that matches one of the types in the root_module_types parameter. Note that DG.get_all_groups is only responsible for grouping and has no knowledge of which parameters should be pruned. Therefore, you must specify the pruning indices with group.prune(idxs=idxs). This feature is useful when you want to implement your own pruning algorithms.
import torch.nn as nn

for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):
    # handle groups in sequential order
    idxs = [2,4,6] # your pruning indices
    group.prune(idxs=idxs)
    print(group)
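Because the loop above leaves the choice of indices to you, a custom algorithm mainly has to turn importance scores into a list of idxs. The following is a minimal sketch under stated assumptions: lowest_k_indices is a hypothetical helper and the scores are made up; neither is part of Torch-Pruning.

```python
def lowest_k_indices(scores, ratio):
    """Return the sorted indices of the `ratio` fraction of lowest scores."""
    n_prune = int(len(scores) * ratio)
    ranked = sorted(range(len(scores)), key=lambda i: scores[i])
    return sorted(ranked[:n_prune])


# Hypothetical per-channel importance scores for one group (8 channels).
scores = [0.9, 0.1, 0.5, 0.05, 0.7, 0.3, 0.8, 0.2]
idxs = lowest_k_indices(scores, ratio=0.5)
print(idxs)  # [1, 3, 5, 7]
```

The resulting list could then be handed to group.prune(idxs=idxs).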
With DepGraph, we developed several high-level pruners in this repository to facilitate effortless pruning. Given a desired channel sparsity, a pruner scans all prunable groups, estimates their importance, and prunes the entire model; you can then fine-tune the pruned model with your own training code. For detailed information on this process, please refer to this tutorial, which shows how to implement a slimming pruner from scratch. Additionally, a more practical example is available in benchmarks/main.py.
import torch
from torchvision.models import resnet18
import torch_pruning as tp
model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)
# 1. Importance criterion
imp = tp.importance.GroupTaylorImportance() # or GroupNormImportance(p=2), GroupHessianImportance(), etc.
# 2. Initialize a pruner with the model and the importance criterion
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!
pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
    model,
    example_inputs,
    importance=imp,
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    # ch_sparsity_dict = {model.conv1: 0.2, model.layer2: 0.8}, # customized sparsity for layers or blocks
    ignored_layers=ignored_layers,
)
# 3. Prune & finetune the model
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
if isinstance(imp, tp.importance.GroupTaylorImportance):
    # Taylor expansion requires gradients for importance estimation
    # A dummy loss, please replace it with your loss function and data!
    loss = model(example_inputs).sum()
    loss.backward() # before pruner.step()
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
# finetune the pruned model here
# finetune(model)
# ...
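The effect of ch_sparsity=0.5 on layer widths, as noted in the comment above, is simple arithmetic. The sketch below reproduces it with a hypothetical helper, pruned_widths; it assumes plain truncation, whereas the actual pruner may round channel counts differently.

```python
def pruned_widths(widths, ch_sparsity):
    """Channels kept per stage when a fraction ch_sparsity is removed."""
    return [w - int(w * ch_sparsity) for w in widths]


resnet18_widths = [64, 128, 256, 512]
half = pruned_widths(resnet18_widths, 0.5)
print(half)  # [32, 64, 128, 256]

# When both the input and output channels of a convolution are halved,
# its weight count shrinks roughly quadratically.
conv_param_ratio = (1 - 0.5) ** 2
print(conv_param_ratio)  # 0.25
```

This quadratic effect is why removing half of the channels removes far more than half of the parameters.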
With global pruning (global_pruning=True), adaptive sparsity is allocated to different layers based on their global rank of importance. While this strategy can offer performance advantages, it can also over-prune specific layers, causing a substantial decline in overall performance. If you're not very familiar with pruning, it's recommended to begin with global_pruning=False.
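The risk of over-pruning under global ranking can be seen on a toy example. This is a sketch with made-up scores and hypothetical helpers (local_prune, global_prune); it only illustrates the ranking idea and is not how the library allocates sparsity internally.

```python
def local_prune(layers, sparsity):
    """Remove the same fraction of lowest-scoring channels in every layer."""
    removed = {}
    for name, scores in layers.items():
        k = int(len(scores) * sparsity)
        order = sorted(range(len(scores)), key=lambda i: scores[i])
        removed[name] = sorted(order[:k])
    return removed


def global_prune(layers, sparsity):
    """Rank all channels together and remove the globally lowest fraction."""
    flat = [(s, name, i) for name, scores in layers.items() for i, s in enumerate(scores)]
    k = int(len(flat) * sparsity)
    removed = {name: [] for name in layers}
    for _, name, i in sorted(flat)[:k]:
        removed[name].append(i)
    return {name: sorted(idxs) for name, idxs in removed.items()}


# Hypothetical importance scores; layer "b" has uniformly small values.
layers = {"a": [0.9, 0.8, 0.7, 0.6], "b": [0.1, 0.2, 0.3, 0.4]}
print(local_prune(layers, 0.5))   # {'a': [2, 3], 'b': [0, 1]}
print(global_prune(layers, 0.5))  # {'a': [], 'b': [0, 1, 2, 3]}
```

With the global ranking, layer "b" loses every channel while layer "a" is untouched, which is exactly the failure mode described above.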
Some pruners, such as BNScalePruner and GroupNormPruner, support sparse training. This is achieved by inserting one line of code, pruner.regularize(model), between loss.backward() and optimizer.step(). The pruner will update the gradients of the trainable parameters.
for epoch in range(epochs):
    model.train()
    for i, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, target)
        loss.backward() # after loss.backward()
        pruner.regularize(model) # <== for sparse training
        optimizer.step() # before optimizer.step()
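To make the placement between loss.backward() and optimizer.step() concrete, here is a toy sketch of a regularizer that edits gradients in place. It assumes a simple L1 penalty; the penalty actually used by each pruner may differ, and l1_regularize and sgd_step are hypothetical helpers written for this illustration.

```python
def l1_regularize(weights, grads, reg=1e-4):
    """Add the gradient of reg * |w| to each existing gradient, in place."""
    for i, w in enumerate(weights):
        sign = (w > 0) - (w < 0)  # sign(w) as -1, 0 or 1
        grads[i] += reg * sign


def sgd_step(weights, grads, lr=0.1):
    """Plain gradient-descent update."""
    for i in range(len(weights)):
        weights[i] -= lr * grads[i]


weights = [0.5, -0.5, 0.0]
grads = [0.0, 0.0, 0.0]                 # pretend loss.backward() produced zeros
l1_regularize(weights, grads, reg=0.1)  # after backward, before the step
sgd_step(weights, grads, lr=1.0)
print(weights)
```

Even with a zero task gradient, the non-zero weights move toward zero, which is the sparsifying effect that makes later pruning less harmful.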
All high-level pruners support interactive pruning. Call pruner.step(interactive=True) to retrieve all the groups and prune them interactively by calling group.prune(). This feature is particularly useful when you want to control or monitor the pruning process.
for i in range(iterative_steps):
    for group in pruner.step(interactive=True): # Warning: groups must be handled sequentially. Do not keep them as a list.
        print(group)
        # do whatever you like with the group
        dep, idxs = group[0] # get the idxs
        target_module = dep.target.module # get the root module
        pruning_fn = dep.handler # get the pruning function
        group.prune()
        # group.prune(idxs=[0, 2, 6]) # It is even possible to change the pruning behaviour with the idxs parameter
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...
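The warning about handling groups sequentially is typical of lazily generated work items whose indices depend on the current state. The toy generator below illustrates that general pattern; it is an assumption about why the warning exists, not a description of the pruner's internals, and lazy_groups and prune are hypothetical names.

```python
def lazy_groups(channels, n_steps):
    """Yield one (name, idxs) item at a time, computed from the current state."""
    for _ in range(n_steps):
        # The index refers to the list as it looks right now.
        yield "layer", [len(channels) - 1]


def prune(channels, idxs):
    """Remove the given positions in place."""
    for i in sorted(idxs, reverse=True):
        del channels[i]


channels = list("abcde")
for name, idxs in lazy_groups(channels, 2):
    prune(channels, idxs)  # handled before the next item is generated
print(channels)  # ['a', 'b', 'c']

# Collecting the items first freezes indices computed for the unpruned state.
stale = list(lazy_groups(list("abcde"), 2))
print(stale)  # [('layer', [4]), ('layer', [4])]
```

In the second case both items point at position 4, so applying them one after another would target a position that no longer exists.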
With DepGraph, it is easy to design "group-level" criteria that estimate the importance of a whole group rather than a single layer. This feature can also be used to sparsify coupled layers, making all the to-be-pruned parameters consistently unimportant. In Torch-Pruning, all pruners work at the group level. Check the following results to see how grouping improves the performance of pruning.
- Pruning a ResNet50 pre-trained on ImageNet-1K without fine-tuning.
- Pruning a Vision Transformer pre-trained on ImageNet-1K without fine-tuning.
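The difference between a layer-level and a group-level criterion can be sketched in a few lines. The example assumes that group importance is the sum of per-channel L2 norms over the coupled layers; this aggregation is an assumption chosen for illustration, and channel_l2, group_importance, and the weights are all hypothetical.

```python
import math


def channel_l2(weight_rows):
    """L2 norm of each channel's weights (one row per channel)."""
    return [math.sqrt(sum(v * v for v in row)) for row in weight_rows]


def group_importance(layers):
    """Sum the per-channel norms across all coupled layers."""
    per_layer = [channel_l2(rows) for rows in layers]
    return [sum(scores) for scores in zip(*per_layer)]


# Two hypothetical coupled layers sharing the same 3 channels.
conv_rows = [[3.0, 4.0], [0.0, 0.1], [1.0, 0.0]]
bn_rows = [[0.1], [2.0], [0.1]]

layer_scores = channel_l2(conv_rows)
group_scores = group_importance([conv_rows, bn_rows])
print(layer_scores.index(min(layer_scores)))  # 1: weakest channel judged by conv alone
print(group_scores.index(min(group_scores)))  # 2: weakest channel judged by the group
```

Judged by the convolution alone, channel 1 looks removable, but it carries a large weight in the coupled layer; the group-level score picks channel 2 instead.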
In some implementations, the model's forward pass relies on static attributes. For example, in convformer_s18 of timm, we have:
class Scale(nn.Module):
    """
    Scale vector by element multiplications.
    """
    def __init__(self, dim, init_value=1.0, trainable=True, use_nchw=True):
        super().__init__()
        self.shape = (dim, 1, 1) if use_nchw else (dim,) # static shape, which should be updated after pruning
        self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
        return x * self.scale.view(self.shape) # => x * self.scale.view(-1, 1, 1), this works for pruning
where the forward function relies on self.shape. However, the true shape changes after pruning, so self.shape must be updated manually.
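The manual update can be a short pass over the modules after pruning. The sketch below uses a plain-Python stand-in for the Scale module so that it runs without PyTorch; prune_scale and refresh_static_shapes are hypothetical helpers, and the same loop would apply to the real module via its scale parameter.

```python
class Scale:
    """Toy stand-in for a module that caches its shape at construction."""

    def __init__(self, dim):
        self.shape = (dim, 1, 1)   # static attribute, fixed at init
        self.scale = [1.0] * dim   # stands in for the learnable parameter


def prune_scale(module, idxs):
    """Drop the given channels from the parameter only."""
    drop = set(idxs)
    module.scale = [v for i, v in enumerate(module.scale) if i not in drop]


def refresh_static_shapes(modules):
    """Re-sync the cached shape with the pruned parameter."""
    for m in modules:
        if isinstance(m, Scale):
            m.shape = (len(m.scale), 1, 1)


layer = Scale(8)
prune_scale(layer, [2, 6])
print(layer.shape, len(layer.scale))  # (8, 1, 1) 6  -> the cached shape is stale
refresh_static_shapes([layer])
print(layer.shape)                    # (6, 1, 1)
```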
The following script saves the whole model object (structure and weights) to 'model.pth'.
model.zero_grad() # Remove gradients
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the pruned model
Experimental Features: Re-create pruned models from unpruned ones using tp.state_dict and tp.load_state_dict.
# save the pruned state_dict, which includes both pruned parameters and modified attributes
state_dict = tp.state_dict(pruned_model) # the pruned model, e.g., a resnet-18-half
torch.save(state_dict, 'pruned.pth')
# create a new model, e.g. resnet18
new_model = resnet18().eval()
# load the pruned state_dict into the unpruned model.
loaded_state_dict = torch.load('pruned.pth', map_location='cpu')
tp.load_state_dict(new_model, state_dict=loaded_state_dict)
Refer to tests/test_serialization.py for a ViT example. In this example, we will prune the model and modify some attributes like model.hidden_dims.
Although it is possible to manually prune your model using low-level functions, this approach can be cumbersome and time-consuming because dependencies must be managed meticulously. We therefore strongly recommend the high-level pruners mentioned earlier, which handle these dependencies for you. To manually prune the model.conv1 of a ResNet-18, the pruning pipeline should look like this:
tp.prune_conv_out_channels( model.conv1, idxs=[2,6,9] )
# fix the broken dependencies manually
tp.prune_batchnorm_out_channels( model.bn1, idxs=[2,6,9] )
tp.prune_conv_in_channels( model.layer2[0].conv1, idxs=[2,6,9] )
...
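What "fixing the broken dependencies" means can be shown on nested lists: removing output channels of one layer deletes rows of its weight, and every consumer must delete the matching input columns. This is a toy sketch; drop_out_channels and drop_in_channels are hypothetical helpers, not the library's pruning functions.

```python
def drop_out_channels(weight, idxs):
    """Remove output channels: drop rows of a [out][in] weight matrix."""
    drop = set(idxs)
    return [row for i, row in enumerate(weight) if i not in drop]


def drop_in_channels(weight, idxs):
    """Remove input channels: drop columns of a [out][in] weight matrix."""
    drop = set(idxs)
    return [[v for j, v in enumerate(row) if j not in drop] for row in weight]


# Layer A maps 2 -> 4 channels; layer B consumes A's 4 outputs.
layer_a = [[1, 1], [2, 2], [3, 3], [4, 4]]      # shape [4][2]
layer_b = [[10, 20, 30, 40], [50, 60, 70, 80]]  # shape [2][4]

idxs = [1, 3]
layer_a = drop_out_channels(layer_a, idxs)
layer_b = drop_in_channels(layer_b, idxs)  # fix the broken dependency
print(layer_a)  # [[1, 1], [3, 3]]
print(layer_b)  # [[10, 30], [50, 70]]
```

Skipping the second call would leave layer B expecting 4 inputs while layer A produces only 2, which is the kind of mismatch the dependency graph prevents.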
The following pruning functions are available:
'prune_conv_out_channels',
'prune_conv_in_channels',
'prune_depthwise_conv_out_channels',
'prune_depthwise_conv_in_channels',
'prune_batchnorm_out_channels',
'prune_batchnorm_in_channels',
'prune_linear_out_channels',
'prune_linear_in_channels',
'prune_prelu_out_channels',
'prune_prelu_in_channels',
'prune_layernorm_out_channels',
'prune_layernorm_in_channels',
'prune_embedding_out_channels',
'prune_embedding_in_channels',
'prune_parameter_out_channels',
'prune_parameter_in_channels',
'prune_multihead_attention_out_channels',
'prune_multihead_attention_in_channels',
'prune_groupnorm_out_channels',
'prune_groupnorm_in_channels',
'prune_instancenorm_out_channels',
'prune_instancenorm_in_channels',
Please refer to examples/transformers/prune_hf_swin.py, which implements a new pruner for the customized module SwinPatchMerging. A simpler example is available at tests/test_customized_layer.py.
Method | Base (%) | Pruned (%) | Δ Acc (%) | Speed Up |
---|---|---|---|---|
NIPS [1] | - | - | -0.03 | 1.76x |
Geometric [2] | 93.59 | 93.26 | -0.33 | 1.70x |
Polar [3] | 93.80 | 93.83 | +0.03 | 1.88x |
CP [4] | 92.80 | 91.80 | -1.00 | 2.00x |
AMC [5] | 92.80 | 91.90 | -0.90 | 2.00x |
HRank [6] | 93.26 | 92.17 | -0.09 | 2.00x |
SFP [7] | 93.59 | 93.36 | +0.23 | 2.11x |
ResRep [8] | 93.71 | 93.71 | +0.00 | 2.12x |
Ours-L1 | 93.53 | 92.93 | -0.60 | 2.12x |
Ours-BN | 93.53 | 93.29 | -0.24 | 2.12x |
Ours-Group | 93.53 | 93.77 | +0.38 | 2.13x |
Latency test on ResNet-50, Batch Size=64.
[Iter 0] Sparsity: 0.00, MACs: 4.12 G, Params: 25.56 M, Latency: 45.22 ms +- 0.03 ms
[Iter 1] Sparsity: 0.05, MACs: 3.68 G, Params: 22.97 M, Latency: 46.53 ms +- 0.06 ms
[Iter 2] Sparsity: 0.10, MACs: 3.31 G, Params: 20.63 M, Latency: 43.85 ms +- 0.08 ms
[Iter 3] Sparsity: 0.15, MACs: 2.97 G, Params: 18.36 M, Latency: 41.22 ms +- 0.10 ms
[Iter 4] Sparsity: 0.20, MACs: 2.63 G, Params: 16.27 M, Latency: 39.28 ms +- 0.20 ms
[Iter 5] Sparsity: 0.25, MACs: 2.35 G, Params: 14.39 M, Latency: 34.60 ms +- 0.19 ms
[Iter 6] Sparsity: 0.30, MACs: 2.02 G, Params: 12.46 M, Latency: 33.38 ms +- 0.27 ms
[Iter 7] Sparsity: 0.35, MACs: 1.74 G, Params: 10.75 M, Latency: 31.46 ms +- 0.20 ms
[Iter 8] Sparsity: 0.40, MACs: 1.50 G, Params: 9.14 M, Latency: 29.04 ms +- 0.19 ms
[Iter 9] Sparsity: 0.45, MACs: 1.26 G, Params: 7.68 M, Latency: 27.47 ms +- 0.28 ms
[Iter 10] Sparsity: 0.50, MACs: 1.07 G, Params: 6.41 M, Latency: 20.68 ms +- 0.13 ms
[Iter 11] Sparsity: 0.55, MACs: 0.85 G, Params: 5.14 M, Latency: 20.48 ms +- 0.21 ms
[Iter 12] Sparsity: 0.60, MACs: 0.67 G, Params: 4.07 M, Latency: 18.12 ms +- 0.15 ms
[Iter 13] Sparsity: 0.65, MACs: 0.53 G, Params: 3.10 M, Latency: 15.19 ms +- 0.01 ms
[Iter 14] Sparsity: 0.70, MACs: 0.39 G, Params: 2.28 M, Latency: 13.47 ms +- 0.01 ms
[Iter 15] Sparsity: 0.75, MACs: 0.29 G, Params: 1.61 M, Latency: 10.07 ms +- 0.01 ms
[Iter 16] Sparsity: 0.80, MACs: 0.18 G, Params: 1.01 M, Latency: 8.96 ms +- 0.02 ms
[Iter 17] Sparsity: 0.85, MACs: 0.10 G, Params: 0.57 M, Latency: 7.03 ms +- 0.04 ms
[Iter 18] Sparsity: 0.90, MACs: 0.05 G, Params: 0.25 M, Latency: 5.81 ms +- 0.03 ms
[Iter 19] Sparsity: 0.95, MACs: 0.01 G, Params: 0.06 M, Latency: 5.70 ms +- 0.03 ms
[Iter 20] Sparsity: 1.00, MACs: 0.01 G, Params: 0.06 M, Latency: 5.71 ms +- 0.03 ms
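The log above makes it possible to compare the theoretical reduction in MACs with the measured latency speed-up. The short script below does that arithmetic for three rows copied from the log; it adds no new measurements.

```python
# sparsity -> (MACs in G, latency in ms), copied from the log above.
log = {
    0.00: (4.12, 45.22),
    0.50: (1.07, 20.68),
    0.90: (0.05, 5.81),
}

base_macs, base_latency = log[0.00]
ratios = {}
for sparsity in (0.50, 0.90):
    macs, latency = log[sparsity]
    ratios[sparsity] = (base_macs / macs, base_latency / latency)
    print(
        f"sparsity={sparsity:.2f}  "
        f"MACs reduction={ratios[sparsity][0]:.2f}x  "
        f"latency speed-up={ratios[sparsity][1]:.2f}x"
    )
```

In this log the latency speed-up grows more slowly than the reduction in MACs, so MACs alone should not be read as wall-clock gains.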
Please refer to benchmarks for more details.
DepGraph: Towards Any Structural Pruning [Project] [Paper]
Gongfan Fang, Xinyin Ma, Mingli Song, Michael Bi Mi, Xinchao Wang
CVPR 2023
LLM-Pruner: On the Structural Pruning of Large Language Models [Project] [arXiv]
Xinyin Ma, Gongfan Fang, Xinchao Wang
NeurIPS 2023
Structural Pruning for Diffusion Models [Project] [arXiv]
Gongfan Fang, Xinyin Ma, Xinchao Wang
NeurIPS 2023
@inproceedings{fang2023depgraph,
title={Depgraph: Towards any structural pruning},
author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={16091--16101},
year={2023}
}