/TokenReduction

Official PyTorch implementation of Which Tokens to Use? Investigating Token Reduction in Vision Transformers presented at ICCV 2023 NIVT workshop

Primary LanguagePythonMIT LicenseMIT

Which Tokens to Use? Investigating Token Reduction in Vision Transformers

This repository is the official implementation of Which Tokens to Use? Investigating Token Reduction in Vision Transformers.

We conduct the first systematic comparison and analysis of 10 state-of-the-art token reduction methods across four image classification datasets, trained using a single codebase and consistent training protocol. Through extensive experiments providing deeper insights into the core mechanisms of the token reduction methods. We make all of the training and analysis code used for the project available.

The project page can be found here.

Requirements

Requirements can be found in the requirements.txt file.

Datasets

We test on the ImageNet-1K, NABirds, COCO 2014, and NUS-Wide datasets. They are available through the following links:

Training

The models are trained by finetuning a pretrained DeiT model. Reduction blocks are inserted using the --reduction_loc argument. In our paper we focus on reducing at the 4th, 7th, and 10th block (note that the argument is 0 indexed), as commonly done in the litterature. In order to work with a larger effective batch size than what can be fitted on the GPU VRAM, we use the --grad_accum_steps argument to aggregate over batches.

An example training command is shown below, training on ImageNet with the Top-K reduction method and a keep rate of 0.9. Note that for merging based methods the keep-rate argument denotes the number of clusters, while for the ATS method the keep rate is the maximum number of tokens kept.

train.py --dataset imagenet --data <path_to_data> --batch-size 256 --lr 0.001 --epochs 30 --warmup-epochs 20 --lr_batch_normalizer 1024 --sched_in_steps --use_amp --grad_accum_steps 2 --wandb_project <wandb_project> --wandb_group <wandb_group> --output_dir <path_to_model_output> --model topk_small_patch16_224 --reduction_loc 3 6 9 --keep_rate 0.9

Evaluation

In order to evaluate the trained models, we first extract an overview of the Weights and Biases (W&B) project where we have logged the training runs.

get_wandb_tables.py --entity <wandb_entity> --project <wandb_project> --output_path <path_to_wandb_overviews>

Using the W&B overview, we extract the metric performance as well as reduction patterns of all models in the W&B project if the viz_mode argument is provided.

validate_dirs.py --dataset imagenet --data <path_to_data> --output_dir <path_to_eval_output> --dataset_csv <path_to_wandb_overviews> --viz_mode --use_amp

Analysis

In our analysis we investigate how reduction patterns and CLS Features compares across trained models.

Comparing CLS Features and per-image Reduction Patterns

The reduction pattern analysis is conducted using the scripts in the reduction_methods_analysis directory. CLS features and reduction patterns are compare within models across keep rates and model capacities using the compare_{cls_features, merging, pruning}_rates.py and compare_{cls_features, merging, pruning}_capacity.py scripts, respectively. In order to compare across reduction methods we use the compare_{cls_features, merging, pruning}_models.py scripts.

The scripts take 4 common arguments:

  • parrent_dir: The path to where the extracted reduction patterns (using validate.py) are stored
  • dataset_csv: The path to the extracted W&B overview csv file
  • output_file: Name of the output file
  • output_path: Path to the where the output_file will be saved

The *_models.py scripts also take a capacity argument to indicate which model capacities to compare.

In order to compare CLS features, we need to extract them using the extract_cls_features_dirs.py script, which shares the same arguments as the validate_dirs.py script.

Extracting and Comparing Global Pruning-based Reduction Patterns

In order to compare reduction patterns across datasets we first determine the average depth of each token. This is extracted using the compute_token_statistics.py script, which shares arguments with the previous compare scripts.

Using the extracted token statistics the global pruning-based reduction patterns can be compared using the compare_heatmaps.py script. Per default the script simply compares the same reduction method at the same capacity and keep rate across datasets. This cna be modified using hte following arguments:

  • compare_within_dataset
  • compare_across_rates
  • compare_across_capacities
  • compare_across_models

Collating data and calculating correlations

The output data of the compare scripts are collated using the corresponding collate_{capacity, models, rates}_data.py scripts. The scripts assumes the cls_features, merging, and pruning output files from the compare scripts follow a specific naming scheme. See the scripts for details.

The collated files are then used to compute the Spearman's and Kendall's correlations coefficients between the difference in metric performance and the different reduction pattern and CLS feature similaritiy metrics. This is done using the calculate_correlation_{capacity, models, rates}.py scripts.

Code Credits

The token reduction method code is based and inspired by:

Parts of the training code and large part of the ViT implementation is based and inspired by:

Parts of the analysis code is based and inspiredby:

License

The Code is licensed under an MIT License, with exceptions of the afforementioned code credits which follows the license of the original authors.

Bibtex

@InProceedings{Haurum_2023_ICCVW,
author = {Joakim Bruslund Haurum and Sergio Escalera and Graham W. Taylor and Thomas B. Moeslund},
title = {Which Tokens to Use? Investigating Token Reduction in Vision Transformers},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) Workshops},
month = {October},
year = {2023},
}