Prune transformer layers

✂️ Short Transformers

Normalized angular distance from initial layer l (x-axis) with block size n (y-axis).

Installation:

pip install short-transformers

Required additional dependencies: torch, transformers, datasets, accelerate.

Quickstart:

from short_transformers import ShortTransformer
from datasets import load_dataset

# load from path/hf_hub
model_name = "meta-llama/Meta-Llama-3-8B"  # or a local path
model = ShortTransformer.from_pretrained(model_name)

# or use hf model
# model = ShortTransformer.from_model(hf_model)

# load hf dataset
dataset = load_dataset("allenai/c4", "en", split="validation", streaming=True)

# remove 5 layers, use the dataset to find the least important layers to remove
short_model = model.remove_layers(block_size=5, dataset=dataset, limit=1000)

# continue training to heal after the cut
# ...

# save as hf model
output_path = "model_output_dir"
short_model.save_pretrained(output_path)

Both short_model and the saved model are fully compatible with transformers. See examples/basic.py for a complete working example.

Pruning in steps:

Pruning can be composed step by step and customized:

  1. Analyze model layers:
from datasets import load_dataset
from transformers import AutoTokenizer
from short_transformers import ShortTransformer
from short_transformers.utils import (
    draw_diagram,
    get_scored_blocks,
    get_best_pruning_start,
)

# load from path/hf_hub
model_name = "meta-llama/Meta-Llama-3-8B"

model = ShortTransformer.from_pretrained(model_name, device_map="auto")

# tokenizer matching the model, used by analyse_layers below
tokenizer = AutoTokenizer.from_pretrained(model_name)

dataset = load_dataset("allenai/c4", "en", split="validation", streaming=True)

# calculate distances between inputs/outputs from/to model layers
# results in a triangular numpy array of shape (layer_count, layer_count)
# results[x, y] - averaged distances for block of size x starting at layer y
results = model.analyse_layers(
    dataset=dataset,
    tokenizer=tokenizer,
    use_chat_template=False,
    key="text",
    limit=100,
    max_length=1000,
)

# draw results
# the diagram style matches that of the original article,
# "The Unreasonable Ineffectiveness of the Deeper Layers"
draw_diagram(results, "results.png", title="Meta-Llama-3-8B", normalized=True)

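Example output: a diagram of the normalized angular distance for each initial layer l (x-axis) and block size n (y-axis), saved to results.png.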
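To make the layout of `results` concrete, here is a minimal pure-Python sketch of how such a blockwise distance matrix could be assembled. It is an illustration, not the library's implementation: `angular_distance`, `block_distances`, and the toy `samples` input are hypothetical names, and it assumes that `hidden[l]` is the vector entering layer `l` for one sample.

```python
import math


def angular_distance(x, y):
    """Normalized angular distance in [0, 1] between two vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in y))
    cos = max(-1.0, min(1.0, dot / norm))  # clamp against rounding error
    return math.acos(cos) / math.pi


def block_distances(samples):
    """Build results[n][l]: mean distance across a block of n layers starting at l.

    samples: one list of vectors per sample, with layer_count + 1 entries;
    hidden[l] is the input to layer l, hidden[layer_count] the final output.
    Entries with no valid block (including n = 0) stay None.
    """
    layer_count = len(samples[0]) - 1
    results = [[None] * layer_count for _ in range(layer_count)]
    for n in range(1, layer_count):
        for l in range(layer_count - n + 1):
            total = sum(angular_distance(h[l], h[l + n]) for h in samples)
            results[n][l] = total / len(samples)
    return results


# toy input: one sample, three layers, two-dimensional hidden states
samples = [[[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]]
results = block_distances(samples)
print(results[1])  # distances for blocks of size 1
print(results[2])  # distances for blocks of size 2
```

In this toy input only layer 1 changes the vector, so every block that contains layer 1 has a non-zero distance and the others have distance zero.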

  2. Find optimal block_size and start_layer:
# find the optimal block of size 'block_size' to prune
start_layer = get_best_pruning_start(results, block_size=5)

# evaluate all possible block sizes to prune;
# each block gets a score in the range 0-1:
# the distance between the block's input and output, averaged over samples
block_score = get_scored_blocks(results, return_md=True, threshold=0.3)

Example output:

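| Block_size | Removed_layers | Score (avg dist) |
|------------|----------------|------------------|
| 1          | 25-25          | 0.123            |
| 2          | 24-25          | 0.155            |
| 3          | 25-27          | 0.181            |
| 4          | 24-27          | 0.204            |
| 5          | 23-27          | 0.226            |
| 6          | 22-27          | 0.248            |
| 7          | 22-28          | 0.268            |
| 8          | 20-27          | 0.291            |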
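The selection step can be sketched in a few lines of plain Python. This is an assumption about the logic, not the code behind `get_best_pruning_start` or `get_scored_blocks`: the helper names `best_pruning_start` and `scored_blocks` are hypothetical, the toy matrix is made up, and the sketch assumes that `threshold` keeps only blocks whose best score is below it.

```python
def best_pruning_start(results, block_size):
    """Return the start layer whose block of `block_size` layers scores lowest."""
    candidates = [
        (score, start)
        for start, score in enumerate(results[block_size])
        if score is not None
    ]
    _, start = min(candidates)
    return start


def scored_blocks(results, threshold):
    """For each block size, keep the best block if its score is below `threshold`.

    Returns tuples (block_size, first_removed, last_removed, score).
    """
    rows = []
    for block_size in range(1, len(results)):
        candidates = [
            (score, start)
            for start, score in enumerate(results[block_size])
            if score is not None
        ]
        if not candidates:
            continue
        score, start = min(candidates)
        if score < threshold:
            rows.append((block_size, start, start + block_size - 1, score))
    return rows


# toy matrix: results[n][l] is the score for a block of n layers starting at l
toy_results = [
    [None, None, None, None],
    [0.30, 0.10, 0.20, 0.25],
    [0.45, 0.22, 0.35, None],
    [0.60, 0.50, None, None],
]
print(best_pruning_start(toy_results, block_size=2))  # 1
print(scored_blocks(toy_results, threshold=0.3))
```

A lower score means the block changes its input less, which is the reason the block with the minimum score is treated as the cheapest one to remove.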

  3. Pruning layers:
# prune a 5-layer block
model.prune(start_layer=start_layer, block_size=5)

# save the pruned model
model.save_pretrained("model_output_dir")

See examples/prune_in_steps.py for a complete working example.
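Conceptually, pruning a block means dropping a contiguous run of layers from the layer stack. The toy sketch below shows that operation on a plain list; `prune_block` and the placeholder layer names are hypothetical, the 0-based indexing is an assumption, and a real model would also need its configuration (for example the stored layer count) kept in sync, which `model.prune` presumably handles.

```python
def prune_block(layers, start_layer, block_size):
    """Return a new list with `block_size` consecutive layers removed,
    beginning at index `start_layer`."""
    if block_size < 1 or start_layer < 0 or start_layer + block_size > len(layers):
        raise ValueError("block does not fit inside the layer stack")
    return layers[:start_layer] + layers[start_layer + block_size:]


# placeholder names standing in for the 32 decoder layers of Meta-Llama-3-8B
layers = [f"layer_{i}" for i in range(32)]
pruned = prune_block(layers, start_layer=23, block_size=5)

print(len(pruned))    # 27
print(pruned[22:24])  # ['layer_22', 'layer_28']
```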

  4. Changing the pruning method:

The default pruning method is based on the angular distance of the last token. To use a different distance, call model.set_metric(some_callable) before model.analyse_layers().

# ...
from short_transformers.dist import get_angular_distance_ith_token

model_name = "meta-llama/Meta-Llama-3-8B"
model = ShortTransformer.from_pretrained(model_name, device_map="auto")

# choose metric
# calculate distances based on the angular distance of the i=0 token
model.set_metric(get_angular_distance_ith_token(i=0))

# load dataset ...

results = model.analyse_layers(
    dataset=dataset,
    tokenizer=tokenizer,
    key="text",
    limit=1,
    max_length=1000,
)
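For intuition, the angular distance between two vectors is arccos of their cosine similarity, divided by π so that it falls in the range 0-1. The sketch below mirrors the factory pattern of `get_angular_distance_ith_token(i=0)` in plain Python. It is not the library's code: the signature a callable must have for `set_metric` is not documented here, so the two-argument `metric(block_input, block_output)` form and the list-of-vectors inputs are assumptions.

```python
import math


def angular_distance(x, y):
    """Normalized angular distance in [0, 1] between two vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in y))
    cos = max(-1.0, min(1.0, dot / norm))  # clamp against rounding error
    return math.acos(cos) / math.pi


def angular_distance_ith_token(i):
    """Build a metric that compares only token `i` of two token sequences."""
    def metric(block_input, block_output):
        return angular_distance(block_input[i], block_output[i])
    return metric


first_token = angular_distance_ith_token(0)
last_token = angular_distance_ith_token(-1)

# toy sequences of two tokens with two-dimensional hidden states
block_input = [[1.0, 0.0], [1.0, 0.0]]
block_output = [[1.0, 0.0], [0.0, 1.0]]

print(first_token(block_input, block_output))  # 0.0
print(last_token(block_input, block_output))   # 0.5
```

Choosing which token to compare matters: in the toy data the first token is unchanged by the block while the last token is rotated by 90 degrees, so the two metrics disagree about how important the block is.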

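Supported metrics for layer importance calculation include those shown in the example outputs below: Euclidean distance of the last token, relative magnitude, BI score, linear approximation of the last token, angular distance of all tokens, and angular distance of the last token.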
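Two of the metric names that appear in the example outputs below can be illustrated in plain Python. The BI (Block Influence) score follows the definition in the ShortGPT paper cited at the end, one minus the mean cosine similarity between a layer's input and output over tokens. The function names and the list-of-vectors inputs are hypothetical, and the library's own implementations may differ in detail.

```python
import math


def euclidean_distance(x, y):
    """Euclidean (L2) distance between two vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def cosine_similarity(x, y):
    """Cosine of the angle between two vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in y))
    return dot / norm


def block_influence(inputs, outputs):
    """BI score: 1 minus the mean per-token cosine similarity
    between a layer's input and output hidden states."""
    sims = [cosine_similarity(x, y) for x, y in zip(inputs, outputs)]
    return 1.0 - sum(sims) / len(sims)


# toy sequences of two tokens with two-dimensional hidden states
inputs = [[1.0, 0.0], [1.0, 0.0]]
outputs = [[1.0, 0.0], [0.0, 1.0]]

print(euclidean_distance([0.0, 0.0], [3.0, 4.0]))  # 5.0
print(block_influence(inputs, outputs))            # 0.5
```

Unlike the angular and BI metrics, the Euclidean distance is sensitive to vector magnitude, which is one reason normalised and unnormalised plots are shown side by side.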

Example outputs:

Meta-Llama-3-8B-Instruct

Layerwise distances:

Blockwise distances:

Euclidean Dist Last Token

Figure 1: Euclidean Dist Last Token. Figure 2: Euclidean Dist Last Token Normalised

Relative Magnitude

Figure 1: Relative Magnitude. Figure 2: Relative Magnitude Normalised

Bi Score

Figure 1: Bi Score. Figure 2: Bi Score Normalised

Linear Approximation Last Token

Figure 1: Linear Approximation Last Token. Figure 2: Linear Approximation Last Token Normalised

Angular Distance All Tokens

Figure 1: Angular Distance All Tokens. Figure 2: Angular Distance All Tokens Normalised

Angular Distance Last Token

Figure 1: Angular Distance Last Token. Figure 2: Angular Distance Last Token Normalised

Yi-1.5-9B-Chat-16K

Layerwise distances:

Blockwise distances:

Euclidean Dist Last Token

Figure 1: Euclidean Dist Last Token. Figure 2: Euclidean Dist Last Token Normalised

Relative Magnitude

Figure 1: Relative Magnitude. Figure 2: Relative Magnitude Normalised

Bi Score

Figure 1: Bi Score. Figure 2: Bi Score Normalised

Linear Approximation Last Token

Figure 1: Linear Approximation Last Token. Figure 2: Linear Approximation Last Token Normalised

Angular Distance Last Token

Figure 1: Angular Distance Last Token. Figure 2: Angular Distance Last Token Normalised

Citing:

If you use Short Transformers in your research, please cite it using the following BibTeX:

@misc{russak2024shorttransformers,
    title  = {ShortTransformers, optimal layer pruning tools},
    author = {Melisa Russak},
    url    = {https://github.com/melisa/short-transformers},
    year   = {2024}
}
@misc{gromov2024unreasonable,
      title={The Unreasonable Ineffectiveness of the Deeper Layers}, 
      author={Andrey Gromov and Kushal Tirumala and Hassan Shapourian and Paolo Glorioso and Daniel A. Roberts},
      year={2024},
      eprint={2403.17887},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
@misc{razzhigaev2024transformer,
      title={Your Transformer is Secretly Linear}, 
      author={Anton Razzhigaev and Matvey Mikhalchuk and Elizaveta Goncharova and Nikolai Gerasimenko and Ivan Oseledets and Denis Dimitrov and Andrey Kuznetsov},
      year={2024},
      eprint={2405.12250},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{men2024shortgpt,
      title={ShortGPT: Layers in Large Language Models are More Redundant Than You Expect}, 
      author={Xin Men and Mingyu Xu and Qingyu Zhang and Bingning Wang and Hongyu Lin and Yaojie Lu and Xianpei Han and Weipeng Chen},
      year={2024},
      eprint={2403.03853},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
@misc{samragh2023weight,
      title={Weight subcloning: direct initialization of transformers using larger pretrained ones}, 
      author={Mohammad Samragh and Mehrdad Farajtabar and Sachin Mehta and Raviteja Vemulapalli and Fartash Faghri and Devang Naik and Oncel Tuzel and Mohammad Rastegari},
      year={2023},
      eprint={2312.09299},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}