PyTorch library for L0-based pruning, as proposed in the paper "Structured Pruning of Large Language Models".
pip install -U flop
Create a hard concrete mask of size N:
from flop import HardConcrete
N = 100
hardconcrete = HardConcrete(n_in=N)
You can then sample masks on the fly with:
mask = hardconcrete()
Note that during evaluation, a mask is compiled and fixed.
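For intuition, here is a minimal standalone sketch of the hard concrete distribution (Louizos et al., 2018) that such a mask is drawn from. It assumes the commonly used temperature beta = 2/3 and stretch interval (-0.1, 1.1); these values and all function names are illustrative assumptions, not flop's API, and flop's own settings and evaluation-time mask may differ. During training, uniform noise u is pushed through s = sigmoid((log u - log(1 - u) + log_alpha) / beta), stretched to (-0.1, 1.1) and clamped to [0, 1], so entries can be exactly 0 or 1. At evaluation time the noise is dropped.

```python
import math
import random
from typing import List

# Assumed hyperparameters (common choices in the L0 literature; not read from flop).
BETA = 2.0 / 3.0  # temperature of the concrete distribution
LEFT = -0.1       # lower end of the stretch interval
RIGHT = 1.1       # upper end of the stretch interval


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def sample_mask(log_alpha: List[float], rng: random.Random) -> List[float]:
    """Draw one stochastic mask with entries in [0, 1] (training-time behaviour)."""
    mask = []
    for la in log_alpha:
        u = rng.uniform(1e-6, 1.0 - 1e-6)  # uniform noise, kept away from 0 and 1
        s = sigmoid((math.log(u) - math.log(1.0 - u) + la) / BETA)
        s_bar = s * (RIGHT - LEFT) + LEFT  # stretch to (LEFT, RIGHT)
        mask.append(min(1.0, max(0.0, s_bar)))  # rectify: exact 0s and 1s become possible
    return mask


def deterministic_mask(log_alpha: List[float]) -> List[float]:
    """A fixed mask for evaluation: stretch and clamp sigmoid(log_alpha), no noise."""
    return [min(1.0, max(0.0, sigmoid(la) * (RIGHT - LEFT) + LEFT)) for la in log_alpha]


def expected_l0(log_alpha: List[float]) -> float:
    """Expected number of non-zero entries: sum over i of P(z_i > 0)."""
    shift = BETA * math.log(-LEFT / RIGHT)
    return sum(sigmoid(la - shift) for la in log_alpha)


if __name__ == "__main__":
    rng = random.Random(0)
    log_alpha = [3.0, -3.0, 0.0]
    print(sample_mask(log_alpha, rng))   # a fresh random mask on every call
    print(deterministic_mask(log_alpha))  # roughly [1.0, 0.0, 0.5]
    print(expected_l0(log_alpha))         # differentiable proxy for the mask's L0 norm
```

Because the sample is a differentiable function of log_alpha, gradients from the task loss and from the expected L0 term both reach the mask parameters.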
You may also find these other objects useful:
- ProjectedLinear: replaces a linear layer with one that includes an intermediate projection.
- HardConcreteProjectedLinear: the hard concrete version of the ProjectedLinear module.
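As a rough illustration of the idea in the spirit of the paper's factorization (not flop's implementation), a projected linear layer writes an n_out x n_in weight as W = P Q, where Q maps the input into an r-dimensional space and P maps it back. The hard concrete version places a mask z of size r between the two factors, computing y = P diag(z) Q x + b, so a zero in z removes one row of Q and one column of P. The class and method names are hypothetical, and plain lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import List, Optional

Matrix = List[List[float]]


def matvec(m: Matrix, x: List[float]) -> List[float]:
    return [sum(a * b for a, b in zip(row, x)) for row in m]


@dataclass
class MaskedProjectedLinearSketch:
    """Hypothetical stand-in computing y = P diag(z) Q x + b, with z of size r."""
    q: Matrix          # r x n_in: projection into the low-rank space
    p: Matrix          # n_out x r: projection back to the output space
    bias: List[float]  # n_out

    def forward(self, x: List[float], z: Optional[List[float]] = None) -> List[float]:
        h = matvec(self.q, x)  # r-dimensional intermediate vector
        if z is not None:
            h = [hi * zi for hi, zi in zip(h, z)]  # mask the rank dimension
        return [yi + bi for yi, bi in zip(matvec(self.p, h), self.bias)]

    def num_params(self, z: Optional[List[float]] = None) -> int:
        """Parameters left after dropping rank components whose mask entry is 0."""
        r = len(self.q)
        kept = r if z is None else sum(1 for zi in z if zi != 0.0)
        n_in, n_out = len(self.q[0]), len(self.p)
        return kept * (n_in + n_out) + n_out


if __name__ == "__main__":
    layer = MaskedProjectedLinearSketch(q=[[1.0, 0.0], [0.0, 1.0]], p=[[2.0, 3.0]], bias=[0.5])
    print(layer.forward([1.0, 1.0]))              # [5.5]
    print(layer.forward([1.0, 1.0], z=[1.0, 0.0]))  # [2.5]
    print(layer.num_params(z=[1.0, 0.0]))         # 4 (down from 7)
```

Masking whole rank components, rather than individual weights, is what makes the pruning structured: each zero in z shrinks both factors and yields a genuinely smaller dense layer.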
You may instantiate the HardConcrete objects directly, or you can first train with a ProjectedLinear module and then introduce the hard concrete mask with:
module = ProjectedLinear(...)
# Perform training
# ...
# Start pruning
pruning_module = HardConcreteProjectedLinear.from_module(module)
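Conceptually, a conversion like from_module copies the trained factors and attaches one fresh mask parameter per rank component. The sketch below illustrates that idea with hypothetical classes and a hypothetical init_log_alpha argument; it assumes the (-0.1, 1.1) stretch interval and the standard noise-free evaluation mask, and flop's actual initialization may differ.

```python
import math
from dataclasses import dataclass
from typing import List

Matrix = List[List[float]]


@dataclass
class ProjectedLinearSketch:
    """Hypothetical trained layer computing y = P Q x (bias omitted for brevity)."""
    q: Matrix  # r x n_in
    p: Matrix  # n_out x r


@dataclass
class HardConcreteProjectedSketch:
    """Hypothetical masked layer computing y = P diag(z) Q x, one mask logit per rank component."""
    q: Matrix
    p: Matrix
    log_alpha: List[float]

    @classmethod
    def from_module(cls, module: ProjectedLinearSketch,
                    init_log_alpha: float = 0.0) -> "HardConcreteProjectedSketch":
        # Copy the trained factors (so later updates do not alias the original layer)
        # and attach a fresh mask logit for each of the r rank components.
        q = [row[:] for row in module.q]
        p = [row[:] for row in module.p]
        return cls(q=q, p=p, log_alpha=[init_log_alpha] * len(q))

    def eval_mask(self) -> List[float]:
        # Deterministic mask under the assumed stretch (-0.1, 1.1): clamp(1.2 * sigmoid(a) - 0.1, 0, 1).
        return [min(1.0, max(0.0, 1.2 / (1.0 + math.exp(-a)) - 0.1)) for a in self.log_alpha]

    def forward(self, x: List[float]) -> List[float]:
        z = self.eval_mask()
        h = [zi * sum(a * b for a, b in zip(row, x)) for zi, row in zip(z, self.q)]  # diag(z) Q x
        return [sum(a * b for a, b in zip(row, h)) for row in self.p]  # P h


if __name__ == "__main__":
    trained = ProjectedLinearSketch(q=[[1.0, 0.0], [0.0, 1.0]], p=[[2.0, 3.0]])
    pruned = HardConcreteProjectedSketch.from_module(trained, init_log_alpha=5.0)
    print(pruned.forward([1.0, 1.0]))  # [5.0], same as P Q x for the trained layer
```

With a large initial logit the evaluation mask is all ones, so the converted layer starts out reproducing the trained layer. A logit of 0 instead starts every component at roughly half strength.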
We also provide utility functions to replace all ProjectedLinear modules in a model with their hard concrete versions:
from flop import make_hard_concrete
model = make_hard_concrete(model)
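Replacing every layer of one type throughout a model amounts to a recursive walk over submodules. The standard-library sketch below illustrates that pattern with a hypothetical minimal Module class standing in for torch.nn.Module; it is not flop's make_hard_concrete, and replace_modules, Projected and MaskedProjected are illustrative names.

```python
from typing import Callable, Dict, Type


class Module:
    """Hypothetical minimal container standing in for torch.nn.Module."""

    def __init__(self, **children: "Module") -> None:
        self.children: Dict[str, "Module"] = dict(children)


class Projected(Module):
    """Hypothetical stand-in for a ProjectedLinear layer."""


class MaskedProjected(Module):
    """Hypothetical stand-in for its hard concrete replacement."""

    def __init__(self, source: Projected) -> None:
        super().__init__()
        self.source = source  # keep a reference to the layer being converted


def replace_modules(root: Module, target: Type[Module],
                    convert: Callable[[Module], Module]) -> Module:
    """Return root with every module of type `target` swapped for `convert(module)`."""
    if isinstance(root, target):
        return convert(root)
    for name, child in list(root.children.items()):
        root.children[name] = replace_modules(child, target, convert)
    return root


if __name__ == "__main__":
    model = Module(encoder=Module(ff=Projected(), other=Module()), head=Projected())
    model = replace_modules(model, Projected, MaskedProjected)
    print(type(model.children["head"]).__name__)                     # MaskedProjected
    print(type(model.children["encoder"].children["ff"]).__name__)   # MaskedProjected
```

With real PyTorch modules, the same walk would typically iterate over named_children() and assign the replacement back with setattr().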
If you are using Flambe to train your models, you can simply use the provided HardConcreteTrainer, which applies the make_hard_concrete function to your input module and then performs training.
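Training a hard concrete model combines the task loss with a term that drives the masks toward a target sparsity. In the paper this is done with a Lagrangian-style penalty whose multipliers are updated by gradient ascent. The sketch below assumes the form lam1 * (s - t) + lam2 * (s - t)^2, where s is the model's expected sparsity (derived from the masks' expected L0 norm) and t is the target. The function names are hypothetical, and this is not HardConcreteTrainer's code.

```python
from typing import Tuple


def lagrangian_penalty(s: float, t: float, lam1: float, lam2: float) -> float:
    """Penalty added to the task loss; it vanishes when expected sparsity s equals target t."""
    gap = s - t
    return lam1 * gap + lam2 * gap * gap


def ascend_multipliers(s: float, t: float, lam1: float, lam2: float,
                       lr: float) -> Tuple[float, float]:
    """One gradient-ascent step on (lam1, lam2): d/dlam1 = gap, d/dlam2 = gap ** 2."""
    gap = s - t
    return lam1 + lr * gap, lam2 + lr * gap * gap


def total_loss(task_loss: float, s: float, t: float, lam1: float, lam2: float) -> float:
    """Objective minimized over the model and mask parameters (the multipliers are held fixed)."""
    return task_loss + lagrangian_penalty(s, t, lam1, lam2)


if __name__ == "__main__":
    lam1, lam2 = 0.0, 0.0
    for s in (0.1, 0.3, 0.5):  # hypothetical expected sparsities over training
        loss = total_loss(task_loss=2.0, s=s, t=0.5, lam1=lam1, lam2=lam2)
        lam1, lam2 = ascend_multipliers(s, 0.5, lam1, lam2, lr=0.1)
        print(f"s={s:.1f} loss={loss:.4f} lam1={lam1:.4f} lam2={lam2:.4f}")
```

While the model is less sparse than the target, lam1 drifts negative, so lowering the objective favours higher sparsity. Once s reaches t, the gap and hence the penalty vanish.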
To replicate the SRU numbers, please look at the script examples/train_enwik8.py.
@article{wang2019structured,
title={Structured Pruning of Large Language Models},
author={Wang, Ziheng and Wohlwend, Jeremy and Lei, Tao},
journal={arXiv preprint arXiv:1910.04732},
year={2019}
}