
Primary language: Python. License: BSD 2-Clause "Simplified" License (BSD-2-Clause).

Knowledge Distillation via the Target-aware Transformer (CVPR2022)

This is the codebase of TaT on ImageNet. For the semantic segmentation experiments, refer to TaT-seg.

Overview

The executable code is in examples/image_classification.py. TaT is implemented in AttnEmbed. The loss function, MaskedFM, is decoupled from the model.
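The README does not spell out the attention mechanism itself, so the following is only one plausible reading of a target-aware attention step, not the repo's AttnEmbed. It assumes that teacher features act as queries over all student positions, that hypothetical 1x1 convolutions (`key_proj`, `value_proj`) map the student into the teacher's channel space, and that both feature maps share a spatial size. The loss is a separate function, mirroring the note that MaskedFM is decoupled from the model; plain MSE is a stand-in and does not reproduce MaskedFM's masking.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TargetAwareAttention(nn.Module):
    """Hypothetical sketch of target-aware attention; not the repo's AttnEmbed."""

    def __init__(self, student_channels: int, teacher_channels: int):
        super().__init__()
        # Assumed 1x1 projections from the student's channel space to the teacher's.
        self.key_proj = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)
        self.value_proj = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, f_s: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
        # f_s: (B, Cs, H, W) student map; f_t: (B, Ct, H, W) teacher map.
        b, c, h, w = f_t.shape
        query = f_t.flatten(2).transpose(1, 2)                   # (B, HW, Ct)
        key = self.key_proj(f_s).flatten(2)                      # (B, Ct, HW)
        value = self.value_proj(f_s).flatten(2).transpose(1, 2)  # (B, HW, Ct)
        # Each teacher position attends over every student position.
        attn = torch.softmax(query @ key / c ** 0.5, dim=-1)     # (B, HW, HW)
        out = attn @ value                                       # (B, HW, Ct)
        # Back to (B, Ct, H, W) so it can be compared with f_t position by position.
        return out.transpose(1, 2).reshape(b, c, h, w)


def feature_matching_loss(f_s_reconfigured: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
    # Kept outside the module; plain MSE stands in for the repo's MaskedFM.
    return F.mse_loss(f_s_reconfigured, f_t.detach())


if __name__ == "__main__":
    embed = TargetAwareAttention(student_channels=64, teacher_channels=128)
    f_s, f_t = torch.randn(2, 64, 7, 7), torch.randn(2, 128, 7, 7)
    loss = feature_matching_loss(embed(f_s, f_t), f_t)
    loss.backward()
```

Keeping the loss outside the module means the same attention block can be paired with a different matching criterion without touching the model code.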

Note

  1. This codebase currently does not support resuming training. However, it does allow you to load a pre-trained model for specific purposes, e.g., distilling a contrastive-learning model.
  2. The classification model is wrapped together with the learnable KD parameters, so be careful about which model parameters you save.
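Because the classifier is wrapped together with the KD parameters, saving the wrapper's full `state_dict()` would also store the KD modules. A minimal sketch, assuming a hypothetical wrapper whose classifier sits under the attribute `student` (the repo's real class and attribute names may differ), keeps only the classifier's entries and strips the prefix so the weights load into a bare model:

```python
import torch
import torch.nn as nn


class KDWrapper(nn.Module):
    """Hypothetical stand-in for a classifier wrapped with learnable KD modules."""

    def __init__(self, student: nn.Module, kd_modules: nn.Module):
        super().__init__()
        self.student = student        # the classifier you actually want to keep
        self.kd_modules = kd_modules  # learnable KD parameters, only needed for training


def student_state_dict(wrapper: nn.Module, prefix: str = "student.") -> dict:
    """Return only the classifier's weights, with the wrapper prefix removed."""
    return {
        key[len(prefix):]: value
        for key, value in wrapper.state_dict().items()
        if key.startswith(prefix)
    }


if __name__ == "__main__":
    wrapper = KDWrapper(nn.Linear(8, 4), nn.Conv2d(8, 8, kernel_size=1))
    weights = student_state_dict(wrapper)  # pass this to torch.save(...) instead of wrapper.state_dict()
    classifier = nn.Linear(8, 4)
    classifier.load_state_dict(weights)    # loads cleanly: no "kd_modules.*" keys
```

In the other direction, when a pre-trained checkpoint only covers part of a network, PyTorch's `load_state_dict(..., strict=False)` returns the missing and unexpected keys instead of raising.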

Customization

If you would like to customize your own model, please put all the learnable parameters here. You can also set up the calculation of the loss function here.
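To illustrate why the learnable parameters belong in one module, here is a hypothetical wrapper (`DistillWrapper`, its `projector`, and `compute_loss` are assumptions, not this repo's classes). Because the projector is registered on the wrapper, `model.parameters()` exposes it to the optimizer, while the frozen teacher is filtered out; the loss calculation stays in its own function.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillWrapper(nn.Module):
    """Hypothetical wrapper; the repo's actual class and attribute names may differ."""

    def __init__(self, student_backbone, classifier, teacher_backbone, dim_s: int, dim_t: int):
        super().__init__()
        self.student_backbone = student_backbone
        self.classifier = classifier
        self.teacher_backbone = teacher_backbone
        for p in self.teacher_backbone.parameters():
            p.requires_grad_(False)  # the teacher is frozen
        # Learnable KD parameter registered on the wrapper, so model.parameters() includes it.
        self.projector = nn.Linear(dim_s, dim_t)
        self.train()

    def train(self, mode: bool = True):
        super().train(mode)
        self.teacher_backbone.eval()  # keep the teacher in inference mode (e.g. BatchNorm statistics)
        return self

    def forward(self, x):
        f_s = self.student_backbone(x)
        with torch.no_grad():
            f_t = self.teacher_backbone(x)
        return self.classifier(f_s), self.projector(f_s), f_t


def compute_loss(outputs, targets, kd_weight: float = 1.0):
    # Loss kept outside the model: cross-entropy on the labels plus a feature term.
    logits, f_s_proj, f_t = outputs
    return F.cross_entropy(logits, targets) + kd_weight * F.mse_loss(f_s_proj, f_t)


if __name__ == "__main__":
    model = DistillWrapper(nn.Linear(16, 8), nn.Linear(8, 10), nn.Linear(16, 32), dim_s=8, dim_t=32)
    optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
    x, y = torch.randn(4, 16), torch.randint(0, 10, (4,))
    loss = compute_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```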

We use forward hooks to extract the intermediate representations. To access the model layers you are interested in, simply modify the YAML file. This example notebook gives a better idea of the usage. You may also refer to our config.
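The hook mechanism itself is PyTorch's `register_forward_hook`. The sketch below assumes the layer list from the YAML file has already been read into a Python list of module names as reported by `named_modules()`; the helper `register_feature_hooks` and the `student_layers` key are hypothetical, and the repo's own config keys and hook handling (inherited from torchdistill) may work differently.

```python
import torch
import torch.nn as nn


def register_feature_hooks(model: nn.Module, layer_names):
    """Hypothetical helper: hook each named submodule and collect its output."""
    features, handles = {}, []
    modules = dict(model.named_modules())
    for name in layer_names:
        # Bind `name` as a default argument so each hook keeps its own layer name.
        def hook(module, inputs, output, name=name):
            features[name] = output
        handles.append(modules[name].register_forward_hook(hook))
    return features, handles


if __name__ == "__main__":
    # Stand-in for the layer list you would put in the YAML file (assumed format).
    config = {"student_layers": ["0", "2"]}
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1))
    features, handles = register_feature_hooks(model, config["student_layers"])
    model(torch.randn(2, 3, 32, 32))
    print({k: tuple(v.shape) for k, v in features.items()})
    # {'0': (2, 8, 32, 32), '2': (2, 8, 1, 1)}
    for handle in handles:
        handle.remove()  # detach the hooks when they are no longer needed
```

Hooks leave the model's `forward` untouched, which is why changing the layer list in a config is enough to tap different layers.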

Examples

Requirements

  • Python 3.7
  • PyTorch 1.5
  • einops
  • ml-collections

Before getting started

Please modify the ImageNet path in the config.

We use 8 GPUs with 256 images per GPU.
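With 8 GPUs and 256 images per GPU, the global batch size is 8 × 256 = 2048. If you run on a different number of GPUs, a small hypothetical helper like the one below gives the per-GPU batch that keeps that global batch unchanged; it does not adjust any other hyperparameter, such as the learning rate.

```python
# Setting stated above: 8 GPUs x 256 images per GPU.
NUM_GPUS = 8
BATCH_PER_GPU = 256
GLOBAL_BATCH = NUM_GPUS * BATCH_PER_GPU  # 2048


def per_gpu_batch(num_gpus: int, global_batch: int = GLOBAL_BATCH) -> int:
    """Hypothetical helper: per-GPU batch size that preserves `global_batch`."""
    if global_batch % num_gpus:
        raise ValueError(f"{global_batch} images cannot be split evenly across {num_gpus} GPUs")
    return global_batch // num_gpus


if __name__ == "__main__":
    print(per_gpu_batch(4))  # 512
```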

Training

sh ./train_local.sh

Testing

sh ./test_local.sh

Issues / Contact

If you have a question, feel free to create an issue or email me (sihao.lin@student.rmit.edu.au).

Acknowledgement

This repo is built upon torchdistill. Thanks to Yoshitomo.