
Code for Learned Thresholds Token Merging and Pruning for Vision Transformers (LTMP), a technique that reduces vision transformers to any desired size with minimal loss of accuracy.


Learned Thresholds Token Merging and Pruning for Vision Transformers

By Maxim Bonnaerens and Joni Dambre.

Abstract

Vision transformers have demonstrated remarkable success in a wide range of computer vision tasks in recent years. However, their high computational costs remain a significant barrier to their practical deployment. In particular, the complexity of transformer models is quadratic with respect to the number of input tokens. Therefore, techniques that reduce the number of input tokens that need to be processed have been proposed. This paper introduces Learned Thresholds Token Merging and Pruning (LTMP), a novel approach that leverages the strengths of both token merging and token pruning. LTMP uses learned threshold masking modules that dynamically determine which tokens to merge and which to prune. We demonstrate our approach with extensive experiments on vision transformers on the ImageNet classification task. Our results show that LTMP achieves state-of-the-art accuracy across reduction rates while requiring only a single fine-tuning epoch, which is an order of magnitude faster than previous methods.

TL;DR

token pruning + token merging = token merging and pruning

Overview

An overview of our framework is shown below. Given any vision transformer, our approach adds merging (LTM) and pruning (LTP) components with learned threshold masking modules in each transformer block between the Multi-head Self-Attention (MSA) and MLP components. Based on the attention in the MSA, importance scores for each token and similarity scores between tokens are computed. Learned threshold masking modules then learn the thresholds that decide which tokens to prune and which ones to merge.

framework overview
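
As an illustration of the core idea, the snippet below sketches a learned threshold masking module in PyTorch: the threshold is a trainable parameter, the hard keep/drop decision is relaxed with a sigmoid of temperature tau, and a straight-through estimator keeps the forward pass binary. The class name, the sigmoid relaxation and the straight-through trick are assumptions based on the description above, not necessarily the exact implementation in this repository:

import torch
import torch.nn as nn

class LearnedThresholdMask(nn.Module):
    """Sketch of a learned threshold masking module (illustrative, not the repo code)."""

    def __init__(self, tau: float = 0.1):
        super().__init__()
        self.tau = tau
        self.threshold = nn.Parameter(torch.zeros(1))  # learned threshold

    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        # scores: (batch, num_tokens) importance or similarity scores
        soft = torch.sigmoid((scores - self.threshold) / self.tau)
        hard = (scores > self.threshold).float()
        # Straight-through estimator: binary mask in the forward pass,
        # sigmoid gradients for the threshold in the backward pass.
        return soft + (hard - soft).detach()

At inference time the same threshold acts as a hard cut-off, so tokens that fall below it can actually be removed or merged rather than masked.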

Installation

pip install -e .

Usage

This repository builds on the vision transformer implementations of 🤗timm (v0.9.5).

Training

LTMP vision transformers can be created for training as follows:

import timm
import ltmp  # importing ltmp registers the LTMP model variants with timm

model = timm.create_model("ltmp_vit_base_patch16_224", pretrained=True, tau=0.1, **kwargs)

To reproduce the results from the paper, run:

python tools/train.py /path/to/imagenet/ --model ltmp_deit_small_patch16_224 --pretrained -b 128 --lr 0.000005 0.005 --reduction-target 0.75
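
The two values passed to --lr and the --reduction-target flag suggest separate learning rates for the backbone weights and the threshold parameters, together with a regularizer that steers the achieved reduction toward the target. The sketch below illustrates such a setup; the parameter-group split, the loss form, the regularizer weight and the way the achieved reduction is measured are assumptions, not code taken from tools/train.py:

import timm
import torch
import ltmp  # registers the ltmp_* model variants with timm

model = timm.create_model("ltmp_deit_small_patch16_224", pretrained=True, tau=0.1)

# Assumed split: tiny learning rate for the pretrained backbone,
# a larger one for the newly added threshold parameters.
threshold_params = [p for n, p in model.named_parameters() if "threshold" in n]
backbone_params = [p for n, p in model.named_parameters() if "threshold" not in n]
optimizer = torch.optim.AdamW([
    {"params": backbone_params, "lr": 5e-6},
    {"params": threshold_params, "lr": 5e-3},
])

reduction_target = 0.75  # assumed semantics: keep roughly 75% of the original cost
reg_weight = 10.0        # hypothetical weight of the reduction regularizer

def ltmp_loss(logits, labels, achieved_reduction):
    # Cross-entropy plus a penalty for missing the reduction target (sketch);
    # achieved_reduction is assumed to be measured from the kept-token masks.
    ce = torch.nn.functional.cross_entropy(logits, labels)
    return ce + reg_weight * (achieved_reduction - reduction_target) ** 2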

Inference

"ltmp_{vit_model}" models obtained through the training detailed above can be used for inference by using the following variant which actually prunes and merges tokens.

import timm
import ltmp  # importing ltmp registers the inference_ltmp_* model variants with timm

model = timm.create_model("inference_ltmp_vit_base_patch16_224", pretrained=True, tau=0.1, **kwargs)

To check the accuracy of trained models:

python tools/validate_timm.py /path/to/imagenet/ --model ltmp_deit_small_patch16_224 --checkpoint /path/to/checkpoint.pth.tar -b 1

Adapt to other vision transformers

See ./ltmp/timm/lt_mergeprune.py and ./ltmp/timm/lt_mergeprune_inference.py for the changes required to add LTMP to a vision transformer. ./tools/train.py contains the code to train LTMP models.
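
For orientation, the sketch below shows where the LTMP step would sit inside a transformer block, between the attention and MLP sub-blocks, as described in the overview. The class and the assumption that the attention module also returns its attention weights are illustrative, not the repository's actual API:

import torch.nn as nn

class LTMPBlock(nn.Module):
    """Illustrative transformer block with an LTMP step between MSA and MLP."""

    def __init__(self, attn, mlp, ltmp_module, norm1, norm2):
        super().__init__()
        self.attn, self.mlp, self.ltmp = attn, mlp, ltmp_module
        self.norm1, self.norm2 = norm1, norm2

    def forward(self, x):
        # The attention module is assumed to also return its attention weights
        # so that token importance and similarity scores can be derived from them.
        attn_out, attn_weights = self.attn(self.norm1(x))
        x = x + attn_out
        # LTMP: merge and prune tokens based on scores computed from the attention.
        x = self.ltmp(x, attn_weights)
        # The MLP then only processes the reduced set of tokens.
        return x + self.mlp(self.norm2(x))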

Citation

If you find this work useful, consider citing it:

@article{bonnaerens2023learned,
    title={Learned Thresholds Token Merging and Pruning for Vision Transformers},
    author={Maxim Bonnaerens and Joni Dambre},
    journal={Transactions on Machine Learning Research},
    issn={2835-8856},
    year={2023},
    url={https://openreview.net/forum?id=WYKTCKpImz}
}