

KS-DETR: Knowledge Sharing in Attention Learning for Detection Transformer

We release the code for our submitted manuscript, KS-DETR: Knowledge Sharing in Attention Learning for Detection Transformer.


Main Results and Pretrained Models

Here we provide the pretrained KS-DETR weights based on detrex.

Name                       Backbone  Pretrain  Epochs  Box AP  Download
KS-DAB-DETR-R50            R-50      IN1k      50      43.9    model
KS-DAB-DETR-R101           R-101     IN1k      50      45.3    model
KS-DAB-DETR-Swin-T         Swin-T    IN1k      50      47.1    model
KS-Conditional-DETR-R50    R-50      IN1k      50      45.3    model
KS-Conditional-DETR-R101   R-101     IN1k      50      47.1    model
KS-DN-DETR-R50             R-50      IN1k      50      45.2    model
KS-DN-DETR-R101            R-101     IN1k      50      46.5    model
KS-Deformable-DETR-R50     R-50      IN1k      12      36.4    model
KS-Deformable-DETR-R101    R-101     IN1k      12      38.4    model
KS-DN-Deformable-DETR-R50  R-50      IN1k      12      46.5    model
KS-Deformable-DETR-R50     R-50      IN1k      50      44.8    model
KS-Deformable-DETR-R101    R-101     IN1k      50      46.0    model

Installation

conda create -n ksdetr python=3.8 -y
conda activate ksdetr
# Install PyTorch and torchvision (matching your CUDA version) before building detectron2.

git clone https://github.com/edocanonymous/KS-DETR
cd KS-DETR
python -m pip install -e detectron2
pip install -e .
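After installation, a quick sanity check can confirm that the key packages are visible in the new environment. The snippet below is a minimal sketch and is not part of the repository; it assumes the editable installs expose top-level packages named `detectron2` and `detrex` (as in upstream detrex). It only locates packages with `importlib.util.find_spec` rather than importing them, so it will not catch errors inside compiled extensions.

```python
import importlib.util
import sys
from typing import Dict, Iterable


def check_packages(names: Iterable[str]) -> Dict[str, bool]:
    """Report whether each top-level package can be located, without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # The environment above is created with Python 3.8.
    print("python >= 3.8:", sys.version_info >= (3, 8))
    # Assumption: the editable installs expose `detectron2` and `detrex` packages;
    # `torch` is listed because detectron2 is built on PyTorch.
    for name, found in check_packages(["torch", "detectron2", "detrex"]).items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```

Any package reported as MISSING points to the corresponding install step above.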

Training

To train models with an R-101 backbone, the ImageNet-1k (IN1k) pretrained weights must be available at output/weights/R-101.pkl. You can obtain R-101.pkl by downloading the torchvision checkpoint https://download.pytorch.org/models/resnet101-5d3b4d8f.pth and converting it to Detectron2 format with https://github.com/facebookresearch/detectron2/blob/main/tools/convert-torchvision-to-d2.py:

 mkdir -p output/weights
 wget https://download.pytorch.org/models/resnet101-5d3b4d8f.pth -O output/r101.pth
 python ./detectron2/tools/convert-torchvision-to-d2.py output/r101.pth output/weights/R-101.pkl

We provide our converted R-101.pkl file here.
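For reference, the conversion mainly renames torchvision ResNet parameter keys to Detectron2's naming and wraps them in a Detectron2-style checkpoint dict. The sketch below reproduces the renaming rules as we understand them from convert-torchvision-to-d2.py; it is illustrative only, so use the actual script to produce R-101.pkl. The helper names `convert_key` and `convert_state_dict` are hypothetical, and in the real script the values are PyTorch tensors converted to NumPy arrays before pickling.

```python
from typing import Any, Dict


def convert_key(key: str) -> str:
    """Rename one torchvision ResNet key to Detectron2 naming (illustrative rules)."""
    if "layer" not in key:
        # Top-level conv1/bn1 belong to the stem (fc.* also lands here and is unused).
        key = "stem." + key
    for t in range(1, 5):
        key = key.replace(f"layer{t}", f"res{t + 1}")  # layer1..layer4 -> res2..res5
    for t in range(1, 4):
        key = key.replace(f"bn{t}", f"conv{t}.norm")  # BatchNorm nested under its conv
    key = key.replace("downsample.0", "shortcut")  # projection conv
    key = key.replace("downsample.1", "shortcut.norm")  # projection BatchNorm
    return key


def convert_state_dict(state: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap renamed weights in the top-level layout of a converted checkpoint."""
    model = {convert_key(k): v for k, v in state.items()}
    return {"model": model, "__author__": "torchvision", "matching_heuristics": True}
```

For example, `layer1.0.bn2.weight` becomes `res2.0.conv2.norm.weight`. As far as we know, the `matching_heuristics` flag tells Detectron2's checkpointer to align these renamed keys with the model's parameters heuristically when loading.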

All configs can be trained with:

cd detrex
python tools/train_net.py --config-file projects/dab_detr/configs/path/to/config.py --num-gpus 8

To train KS-DAB-DETR-R50, KS-DAB-DETR-R101, and KS-DAB-DETR-Swin-T, run:

python tools/train_net.py --config-file projects/ks_detr/configs/ks_dab_detr/ks_dab_detr_r50_50ep_smlp_qkv_triple_attn.py --num-gpus 8

python tools/train_net.py --config-file projects/ks_detr/configs/ks_dab_detr/ks_dab_detr_r101_50ep_smlp_qkv_triple_attn.py --num-gpus 8

python tools/train_net.py --config-file projects/ks_detr/configs/ks_dab_detr/ks_dab_detr_swin_tiny_50ep_smlp_qkv_triple_attn.py --num-gpus 8
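If you want to queue the three KS-DAB-DETR runs one after another, a small wrapper can generate the commands above. This is a hypothetical helper (`build_train_cmd` is our name, not a repository script) that reuses only the config paths and flags shown above; it assumes it is run from the detrex directory and, by default, just prints the commands.

```python
import shlex
from typing import List

# Config paths copied from the commands above (relative to the detrex directory).
KS_DAB_DETR_CONFIGS = [
    "projects/ks_detr/configs/ks_dab_detr/ks_dab_detr_r50_50ep_smlp_qkv_triple_attn.py",
    "projects/ks_detr/configs/ks_dab_detr/ks_dab_detr_r101_50ep_smlp_qkv_triple_attn.py",
    "projects/ks_detr/configs/ks_dab_detr/ks_dab_detr_swin_tiny_50ep_smlp_qkv_triple_attn.py",
]


def build_train_cmd(config: str, num_gpus: int = 8, python: str = "python") -> List[str]:
    """Return the argv list for one training run, using the same flags as above."""
    return [python, "tools/train_net.py", "--config-file", config, "--num-gpus", str(num_gpus)]


if __name__ == "__main__":
    # Dry run: print each command instead of executing it.
    for cfg in KS_DAB_DETR_CONFIGS:
        print(shlex.join(build_train_cmd(cfg)))
```

To launch the runs sequentially, replace the `print` call with `subprocess.run(cmd, check=True)` (after `import subprocess`), which stops at the first failed run.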

Evaluation

Model evaluation can be done as follows:

cd detrex
python tools/train_net.py --config-file projects/dab_detr/configs/path/to/config.py --eval-only train.init_checkpoint=/path/to/model_checkpoint
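The trailing `train.init_checkpoint=/path/to/model_checkpoint` argument is a dotted-path config override. As far as we know, detrex's train_net.py applies such overrides with Detectron2's `LazyConfig.apply_overrides`; the pure-Python sketch below only illustrates the idea on a nested dict and is not Detectron2's implementation. `apply_overrides` here is a hypothetical stand-in whose value-parsing rules may differ from the real one.

```python
import ast
from typing import Any, Dict, List


def apply_overrides(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply 'a.b.c=value' strings to a nested dict in place (illustrative only)."""
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        try:
            # Numbers, booleans, lists and quoted strings become Python values.
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            # Anything else (e.g. a bare file path) is kept as a plain string.
            value = raw
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # in this sketch, a missing intermediate key raises KeyError
        node[leaf] = value
    return cfg
```

For example, applying `["train.init_checkpoint=/path/to/model_checkpoint"]` to `{"train": {"init_checkpoint": ""}}` sets the checkpoint path while leaving every other field untouched.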

License

This project is released under the Apache 2.0 license.

Acknowledgement

  • Our code is built on detrex, an open-source toolbox for Transformer-based detection algorithms created by researchers of IDEA-CVR.

  • detrex is built on Detectron2, and part of its module design is borrowed from MMDetection, DETR, and Deformable-DETR.