
A method to increase the speed and lower the memory footprint of existing vision transformers.


Token Merging: Your ViT but Faster

Official PyTorch implementation of ToMe from our paper: Token Merging: Your ViT but Faster.
Daniel Bolya, Cheng-Yang Fu, Xiaoliang Dai, Peizhao Zhang, Christoph Feichtenhofer, Judy Hoffman.

What is ToMe?

ToMe Concept Figure

Token Merging (ToMe) allows you to take an existing Vision Transformer architecture and efficiently merge tokens inside of the network for 2-3x faster evaluation (see benchmark script). ToMe is tuned to seamlessly fit inside existing vision transformers, so you can use it without having to do additional training (see eval script). And if you do use ToMe during training, you can reduce the accuracy drop even further while also speeding up training considerably.
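Speedups in this README are reported as throughput in images per second (im/s). Below is a rough, framework-agnostic sketch of how such a number can be measured. It is not the repository's benchmark script, and `images_per_second` is a hypothetical helper.

```python
import time


def images_per_second(run_batch, batch_size, runs=50, warmup=5):
    """Measure throughput (im/s) of a zero-argument callable that processes one batch.

    For GPU models, run_batch should call torch.cuda.synchronize() before
    returning; otherwise asynchronous kernel launches make timings look too fast.
    """
    for _ in range(warmup):  # untimed warm-up iterations
        run_batch()
    start = time.perf_counter()
    for _ in range(runs):
        run_batch()
    elapsed = time.perf_counter() - start
    return runs * batch_size / elapsed


# Toy workload standing in for a model forward pass.
print(images_per_second(lambda: sum(range(1000)), batch_size=8, runs=5, warmup=1))
```

A speedup is then the ratio of two such measurements, for example the patched model's im/s divided by the unpatched model's im/s at the same batch size.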

What ToMe does

ToMe Visualization

ToMe merges tokens based on their similarity, implicitly grouping parts of objects together. This is in contrast to token pruning, which only removes background tokens. ToMe can get away with reducing more tokens because we can merge redundant foreground tokens in addition to background ones. Visualization of merged tokens on ImageNet-1k val using a trained ViT-H/14 MAE model with ToMe. See this example for how to produce these visualizations. For more, see the paper appendix.
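The paper calls the similarity-based merging step bipartite soft matching. The following plain-Python sketch is a simplified illustration, not the repository's implementation. Tokens are split into two alternating sets A and B, each A token is linked to its most similar B token, and the r most similar links are merged. The sketch measures similarity on the token features themselves (the paper uses attention keys), does not protect the class token, merges with a size-weighted average, and uses hypothetical helper names.

```python
import math


def cosine(a, b):
    """Cosine similarity between two feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def bipartite_soft_matching(tokens, sizes, r):
    """One simplified merging step: returns (tokens, sizes) with up to r fewer entries.

    tokens: list of feature vectors; sizes: patches each token already represents.
    """
    a_idx = list(range(0, len(tokens), 2))  # set A: even positions
    b_idx = list(range(1, len(tokens), 2))  # set B: odd positions
    r = min(r, len(a_idx))  # only A tokens can be merged away
    if r <= 0 or not b_idx:
        return [list(t) for t in tokens], list(sizes)

    # Draw one edge from every A token to its most similar B token.
    edges = []
    for i in a_idx:
        j = max(b_idx, key=lambda k: cosine(tokens[i], tokens[k]))
        edges.append((cosine(tokens[i], tokens[j]), i, j))
    edges.sort(key=lambda e: e[0], reverse=True)  # most similar first

    # Merge the r best edges by size-weighted averaging into the B token.
    sums = {j: [x * sizes[j] for x in tokens[j]] for j in b_idx}
    totals = {j: sizes[j] for j in b_idx}
    merged_away = set()
    for _, i, j in edges[:r]:
        sums[j] = [s + x * sizes[i] for s, x in zip(sums[j], tokens[i])]
        totals[j] += sizes[i]
        merged_away.add(i)

    out_tokens, out_sizes = [], []
    for idx in range(len(tokens)):  # keep the original order for readability
        if idx in merged_away:
            continue
        if idx in sums:
            out_tokens.append([s / totals[idx] for s in sums[idx]])
            out_sizes.append(totals[idx])
        else:
            out_tokens.append(list(tokens[idx]))
            out_sizes.append(sizes[idx])
    return out_tokens, out_sizes


tokens = [[1.0, 0.0], [1.0, 0.1], [0.0, 1.0], [0.0, 1.1]]
merged, sizes = bipartite_soft_matching(tokens, [1, 1, 1, 1], r=1)
print(sizes)  # [1, 1, 2]: tokens 2 and 3 (the most similar pair) were merged
```

Tracking `sizes` matters because a merged token stands in for several patches, which later steps such as size-weighted averaging or proportional attention can take into account.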

News

  • [2023.02.08] We are delighted to announce that the Meta Research Blog has highlighted our work, Token Merging! Check out the article at Meta Research Blog for more information.
  • [2023.01.31] We are happy to announce that our paper has been accepted for an oral presentation at ICLR 2023.
  • [2023.01.30] We've released checkpoints trained with ToMe for DeiT-Ti, DeiT-S, ViT-B, ViT-L, and ViT-H!
  • [2022.10.18] Initial release.

Installation

See INSTALL.md for installation details.

Usage

This repo does not include training code. Instead, we provide a set of tools to patch existing vision transformer implementations, which you can then use out of the box. Currently, we support the following ViT implementations:

  • timm (including DeiT models)
  • SWAG (through Torch Hub)
  • MAE

See the examples/ directory for a set of usage examples.
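Conceptually, patching modifies an already-constructed model in place instead of re-implementing it. The toy sketch below illustrates one way to do that in Python: it swaps the classes of existing objects for subclasses that add a per-layer reduction `r`. Everything here (`Block`, `Model`, `ReducingBlock`, `PatchedModel`, `patch_model`, and the truncation standing in for merging) is hypothetical and only illustrates the idea. It is not the tome.patch code.

```python
class Block:
    """Stand-in for an existing transformer block (identity here)."""

    def forward(self, tokens):
        return tokens


class Model:
    """Stand-in for an existing ViT implementation."""

    def __init__(self, depth):
        self.blocks = [Block() for _ in range(depth)]

    def forward(self, tokens):
        for block in self.blocks:
            tokens = block.forward(tokens)
        return tokens


class ReducingBlock(Block):
    """Runs the original block, then removes r tokens."""

    r = 0

    def forward(self, tokens):
        tokens = super().forward(tokens)
        keep = max(len(tokens) - self.r, 1)
        return tokens[:keep]  # placeholder for an actual merge step


class PatchedModel(Model):
    """Passes the model-level r down to every block before running."""

    r = 0

    def forward(self, tokens):
        for block in self.blocks:
            block.r = self.r
        return super().forward(tokens)


def patch_model(model):
    """Swap classes in place so the existing objects (and any weights) are kept."""
    model.__class__ = PatchedModel
    for block in model.blocks:
        block.__class__ = ReducingBlock
    return model


model = patch_model(Model(depth=4))
model.r = 2
print(len(model.forward(list(range(10)))))  # 10 -> 8 -> 6 -> 4 -> 2
```

Swapping `__class__` keeps the original objects and their state, which is why a patch applied after loading pretrained weights does not need to reload them.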

Using timm models

Timm is a commonly used implementation of vision transformers in PyTorch. As of version 0.4.12, it uses AugReg weights.

import timm, tome

# Load a pretrained model, can be any vit / deit model.
model = timm.create_model("vit_base_patch16_224", pretrained=True)
# Patch the model with ToMe.
tome.patch.timm(model)
# Set the number of tokens reduced per layer. See paper for details.
model.r = 16
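Because up to r tokens are removed in every block, the token count shrinks roughly linearly with depth. Below is a minimal sketch of that arithmetic for vit_base_patch16_224 (196 patch tokens plus one class token, 12 blocks). The per-block cap of half the unprotected tokens is an assumption based on how bipartite matching pairs tokens, and `token_schedule` is a hypothetical helper.

```python
def token_schedule(num_tokens, num_layers, r, protected=1):
    """Token count after each block when up to r tokens are merged per block.

    Assumption: one matching step can merge at most half of the unprotected
    tokens (the class token is protected), so r is capped per block.
    """
    counts = []
    t = num_tokens
    for _ in range(num_layers):
        t -= min(r, (t - protected) // 2)
        counts.append(t)
    return counts


# vit_base_patch16_224: (224 // 16) ** 2 = 196 patches + 1 class token, 12 blocks.
print(token_schedule(197, 12, 16))
# -> [181, 165, 149, 133, 117, 101, 85, 69, 53, 37, 21, 11]
```

Under these assumptions, the reduction is exactly r per block until the cap applies in the last block, which is why r cannot be raised arbitrarily for a given depth and input size.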

Here are some expected results when using the timm implementation off-the-shelf on ImageNet-1k val using a V100:

Model | Original acc (%) | Original im/s | r | ToMe acc (%) | ToMe im/s
ViT-S/16 | 81.41 | 953 | 13 | 79.30 | 1564
ViT-B/16 | 84.57 | 309 | 13 | 82.60 | 511
ViT-L/16 | 85.82 | 95 | 7 | 84.26 | 167
ViT-L/16 @ 384 | 86.92 | 28 | 23 | 86.14 | 56

See the paper for full results with all models and all values of r.
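The off-the-shelf timm results can be condensed into a speedup and an accuracy drop per model. The snippet below only re-derives those two numbers from the rows of the timm table. `summarize` is a hypothetical helper. By this arithmetic, the off-the-shelf speedups range from about 1.6x to 2x.

```python
# Rows copied from the off-the-shelf timm table:
# (model, original acc, original im/s, r, ToMe acc, ToMe im/s)
TIMM_RESULTS = [
    ("ViT-S/16", 81.41, 953, 13, 79.30, 1564),
    ("ViT-B/16", 84.57, 309, 13, 82.60, 511),
    ("ViT-L/16", 85.82, 95, 7, 84.26, 167),
    ("ViT-L/16 @ 384", 86.92, 28, 23, 86.14, 56),
]


def summarize(rows):
    """Return (model, speedup, accuracy drop in points) for each row."""
    return [
        (name, round(tome_ims / ims, 2), round(acc - tome_acc, 2))
        for name, acc, ims, _r, tome_acc, tome_ims in rows
    ]


for row in summarize(TIMM_RESULTS):
    print(row)
```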

We've trained some DeiT (v1) models using the official implementation. To use them, instantiate a DeiT timm model, patch it with the timm patch (prop_attn=True), and use the ImageNet mean and standard deviation for data loading.

Model | Original acc (%) | Original im/s | r | ToMe acc (%) | ToMe im/s | Checkpoint
DeiT-S/16 | 79.8 | 930 | 13 | 79.36 | 1550 | deit_S_r13
DeiT-Ti/16 | 71.8 | 2558 | 13 | 71.27 | 3980 | deit_T_r13
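The ImageNet normalization mentioned above for the DeiT checkpoints subtracts a per-channel mean and divides by a per-channel standard deviation. Below is a minimal sketch using the standard ImageNet RGB statistics for pixels scaled to [0, 1]. `normalize_pixel` is a hypothetical helper. timm exposes the same constants as `timm.data.IMAGENET_DEFAULT_MEAN` and `timm.data.IMAGENET_DEFAULT_STD`.

```python
# Standard ImageNet per-channel RGB statistics for pixels scaled to [0, 1].
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Normalize one RGB pixel (values in [0, 1]) with ImageNet statistics."""
    return tuple((c - m) / s for c, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


print(normalize_pixel((0.485, 0.456, 0.406)))  # the mean maps to (0.0, 0.0, 0.0)
```

In a torchvision pipeline, the same constants would be passed to `transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)` after `transforms.ToTensor()`.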

Using SWAG models through Torch Hub

SWAG is a repository of massive weakly-supervised ViT models. They are available from Torch Hub and we include a function to patch its implementation.

import torch, tome

# Load a pretrained model, can be one of "vit_b16_in1k", "vit_l16_in1k", or "vit_h14_in1k".
model = torch.hub.load("facebookresearch/swag", model="vit_b16_in1k")
# Patch the model with ToMe.
tome.patch.swag(model)
# Set the amount of reduction. See paper for details.
model.r = 45

Here are some results using these SWAG models off-the-shelf on ImageNet-1k val using a V100:

Model | Original acc (%) | Original im/s | r | ToMe acc (%) | ToMe im/s
ViT-B/16 @ 384 | 85.30 | 85.7 | 45 | 84.59 | 167.7
ViT-L/16 @ 512 | 88.06 | 12.8 | 40 | 87.80 | 26.3
ViT-H/14 @ 518 | 88.55 | 4.7 | 40 | 88.25 | 9.8

Full results for other values of r are available in the paper appendix.
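The r values used for the SWAG models are larger than those for the 224px timm models. A plausible factor is that higher input resolutions produce many more tokens to merge. Below is a quick sketch of the token count, assuming a plain ViT with a single class token. `num_tokens` is a hypothetical helper. For comparison, a 224px image with 16px patches gives 197 tokens.

```python
def num_tokens(image_size, patch_size, class_tokens=1):
    """Tokens a plain ViT processes for a square image: patches + class token(s)."""
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + class_tokens


for name, size, patch in [("ViT-B/16 @ 384", 384, 16),
                          ("ViT-L/16 @ 512", 512, 16),
                          ("ViT-H/14 @ 518", 518, 14)]:
    print(name, num_tokens(size, patch))  # 577, 1025, 1370 tokens
```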

Training with MAE

We fine-tune models pretrained with MAE using the official MAE codebase. Apply the patch as shown in this example and set r as desired (see the paper appendix for the full list of accuracies vs. r). Then, follow the instructions in the MAE codebase to fine-tune your model from pretrained weights.

Here are some results on ImageNet-1k val after training with ToMe, using a V100 for evaluation:

Model | Original acc (%) | Original im/s | r | ToMe acc (%) | ToMe im/s | Checkpoint
ViT-B/16 | 83.62 | 309 | 16 | 81.91 | 603 | vit_B_16_r16
ViT-L/16 | 85.66 | 93 | 8 | 85.09 | 183 | vit_L_16_r8
ViT-H/14 | 86.88 | 35 | 7 | 86.46 | 63 | vit_H_14_r7

To use the checkpoints, apply the MAE patch (tome.patch.mae) to an MAE model from the official MAE codebase as shown in this example. Pass global_pool=True to the MAE ViT constructors and use the ImageNet mean for data loading. For the models we trained (the checkpoints above), we used prop_attn=True when patching with ToMe; leave it as False for off-the-shelf models. Note that the original models in this table were also trained by us.
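The paper describes proportional attention as adding log s to each attention logit, where s is the number of patches a token represents. This lets a merged token receive as much attention as the separate tokens it replaced. The assumption here is that this is what the prop_attn flag toggles. The single-head sketch below uses hypothetical helper names and is not the repository's code.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def proportional_attention(query, keys, sizes):
    """Single-head attention weights with log(size) added to each logit.

    sizes[i] is how many original patches key i represents after merging.
    """
    scale = 1.0 / math.sqrt(len(query))
    logits = [
        sum(q * k for q, k in zip(query, key)) * scale + math.log(s)
        for key, s in zip(keys, sizes)
    ]
    return softmax(logits)


# A merged token of size 2 receives the same weight as two identical copies.
q = [1.0, 0.5]
k_a, k_b = [0.2, 0.1], [1.0, -0.3]
merged = proportional_attention(q, [k_a, k_b], [1, 2])
copies = proportional_attention(q, [k_a, k_b, k_b], [1, 1, 1])
print(merged, copies)
```

With all sizes equal to 1, log s is 0 and this reduces to ordinary scaled dot-product attention weights.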

As a sanity check, here is our baseline result without training, using the off-the-shelf ViT-L model available here, as described in Table 1 of the paper:

Model | Original acc (%) | Original im/s | r | ToMe acc (%) | ToMe im/s
ViT-L/16 | 85.96 | 93 | 8 | 84.22 | 183

Applying ToMe to Stable Diffusion

Stable Diffusion is a large open-source text-to-image model whose U-Net backbone contains transformer blocks. ToMe can be applied to the U-Net to increase inference speed and reduce memory usage. An implementation of ToMe for Stable Diffusion is on its way, but for now, here are some 4K (3840x2176) wallpapers generated natively in PyTorch on a single 4090 GPU with Stable Diffusion + ToMe in under 2 minutes:

Old World Wallpaper

Cherry Blossoms Wallpaper

See examples for more.

Of course, running Stable Diffusion at this size is kind of ridiculous. We present these results to emphasize the power of ToMe to decrease memory usage and speed up evaluation, even at insane image resolutions. Stable Diffusion was trained on 512px images, so we don't actually recommend natively generating at 4K (you should upscale or inpaint instead).
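To see why memory is the bottleneck at this size, here is rough arithmetic. It assumes Stable Diffusion's usual 8x VAE downsampling and naive self-attention that materializes a full n x n matrix (memory-efficient attention kernels avoid this). The helper names are hypothetical, and the sketch does not model how ToMe is actually applied inside the U-Net.

```python
def latent_tokens(width, height, downsample=8):
    """Token count at the highest-resolution attention layers (assumes 8x downsampling)."""
    return (width // downsample) * (height // downsample)


def attention_entries(n_tokens):
    """Entries in one head's naive n x n self-attention matrix."""
    return n_tokens * n_tokens


native = latent_tokens(512, 512)        # 64 * 64 = 4096 tokens
wallpaper = latent_tokens(3840, 2176)   # 480 * 272 = 130560 tokens

# Naive attention grows quadratically with the token count (about 1016x here).
growth = attention_entries(wallpaper) / attention_entries(native)
# Halving the token count (e.g. by merging) shrinks the matrix 4x.
halved = attention_entries(wallpaper // 2) / attention_entries(wallpaper)
print(native, wallpaper, growth, halved)
```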

License and Contributing

Please refer to the CC-BY-NC 4.0 license. For contributing, see contributing and the code of conduct.

Citation

If you use ToMe or this repository in your work, please cite:

@inproceedings{bolya2022tome,
  title={Token Merging: Your {ViT} but Faster},
  author={Bolya, Daniel and Fu, Cheng-Yang and Dai, Xiaoliang and Zhang, Peizhao and Feichtenhofer, Christoph and Hoffman, Judy},
  booktitle={International Conference on Learning Representations},
  year={2023}
}