Official repository for the paper DL2PA: Hyperspherical Classification with Dynamic Label-to-Prototype Assignment (CVPR 2024).

DL2PA: Hyperspherical Classification with Dynamic Label-to-Prototype Assignment

arXiv: 2403.16937

Abstract

Aiming to enhance the utilization of metric space by the parametric softmax classifier, recent studies suggest replacing it with a non-parametric alternative. Although a non-parametric classifier may provide better metric space utilization, it introduces the challenge of capturing inter-class relationships. A shared characteristic among prior non-parametric classifiers is the static assignment of labels to prototypes during the training, i.e., each prototype consistently represents a class throughout the training course. Orthogonal to previous works, we present a simple yet effective method to optimize the category assigned to each prototype (label-to-prototype assignment) during the training. To this aim, we formalize the problem as a two-step optimization objective over network parameters and label-to-prototype assignment mapping. We solve this optimization using a sequential combination of gradient descent and bipartite matching. We demonstrate the benefits of the proposed approach by conducting experiments on balanced and long-tail classification problems using different backbone network architectures. In particular, our method outperforms its competitors by 1.22% accuracy on CIFAR-100, and 2.15% on ImageNet-200 using a metric space dimension half of the size of its competitors.

Demo

Comparison of the proposed method with the conventional parametric softmax classifier (PSC) and the previous fixed-classifier setup, using a toy example with three classes. Each color denotes a distinct class.

a) Label-to-prototype assignment remains static during training. In PSC, optimization targets the whole network, consisting of the backbone and the prototypes W. With a fixed classifier, only the backbone is optimized and the prototypes remain fixed.
b) In the proposed method, the prototypes within the hypersphere are fixed, and optimization targets the backbone and the label that each prototype represents.
c) Toy example showing changes in label-to-prototype assignment during training.
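The label-to-prototype assignment step described above can be illustrated with a minimal sketch on a three-class toy setup. Everything in it is an assumption made for illustration and is not taken from this repository: the function names `assign_labels_to_prototypes` and `predict`, the use of per-class mean features, and the score being maximized (total cosine similarity). The matching is done by exhaustive search over permutations, which is only feasible for a handful of classes.

```python
from itertools import permutations
from math import sqrt


def cosine(u, v):
    """Cosine similarity between two non-zero vectors given as sequences of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = sqrt(sum(a * a for a in u))
    norm_v = sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def assign_labels_to_prototypes(class_means, prototypes):
    """Return a list whose entry c is the prototype index assigned to class c.

    Exhaustive search over all one-to-one assignments (toy sizes only),
    maximizing the total cosine similarity. ASSUMPTION: this score is
    illustrative; it is not taken from the paper or the repository.
    """
    num_classes = len(class_means)
    best_perm, best_score = None, float("-inf")
    for perm in permutations(range(len(prototypes)), num_classes):
        score = sum(cosine(class_means[c], prototypes[perm[c]]) for c in range(num_classes))
        if score > best_score:
            best_perm, best_score = perm, score
    return list(best_perm)


def predict(feature, prototypes, assignment):
    """Classify a feature as the class whose assigned prototype is most similar."""
    return max(range(len(assignment)), key=lambda c: cosine(feature, prototypes[assignment[c]]))


# Three fixed prototypes on the unit circle, 120 degrees apart.
prototypes = [(1.0, 0.0), (-0.5, sqrt(3) / 2), (-0.5, -sqrt(3) / 2)]
# Hypothetical per-class mean features produced by the backbone at some epoch.
class_means = [(-0.4, 0.9), (-0.6, -0.8), (0.9, 0.1)]

assignment = assign_labels_to_prototypes(class_means, prototypes)
print(assignment)  # prints [1, 2, 0]: class c is represented by prototypes[assignment[c]]
print(predict((0.8, 0.2), prototypes, assignment))  # prints 2
```

In a training loop, a step like this would alternate with gradient updates of the backbone while the prototypes themselves stay fixed. For realistic class counts, exhaustive search is impractical; a Hungarian-algorithm solver such as `scipy.optimize.linear_sum_assignment` is the usual substitute.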

Usage

Prototype Estimation

To generate equidistributed prototypes with the desired dimension (--space_dim) and number of prototypes (--num_centroids), run:

python Prototype_Estimation.py --seed 100 --num_centroids 100 --batch_size 100 --space_dim 50 --num_epoch 1000

The estimated prototypes are also available at the link.
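However the prototypes are obtained, a quick sanity check on how well separated they are is the largest cosine similarity between any two of them: the lower it is, the farther apart the closest pair. The sketch below is illustrative only. The helper name `max_pairwise_cosine` is hypothetical, and because the file format written by Prototype_Estimation.py is not described here, the prototypes are assumed to be already loaded as plain lists of vectors.

```python
from math import sqrt


def normalize(v):
    """Scale a non-zero vector to unit length."""
    norm = sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def max_pairwise_cosine(prototypes):
    """Largest cosine similarity between any two distinct prototypes."""
    unit = [normalize(p) for p in prototypes]
    worst = -1.0
    for i in range(len(unit)):
        for j in range(i + 1, len(unit)):
            cos = sum(a * b for a, b in zip(unit[i], unit[j]))
            worst = max(worst, cos)
    return worst


# Three prototypes spread evenly on the unit circle: every pair is 120 degrees apart.
spread = [(1.0, 0.0), (-0.5, sqrt(3) / 2), (-0.5, -sqrt(3) / 2)]
# Two of these three prototypes nearly coincide.
crowded = [(1.0, 0.0), (0.99, 0.1), (-1.0, 0.0)]

print(round(max_pairwise_cosine(spread), 3))  # prints -0.5
print(max_pairwise_cosine(crowded) > 0.9)     # prints True
```

The double loop is quadratic in the number of prototypes, which is fine for a one-off check on a few hundred or a few thousand vectors.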

Training classifier

The config files are located at ./config/Blanced/FILENAME.yaml (balanced) or ./config/LongTail/FILENAME.yaml (long-tailed).

python train.py --cfg {path to config}

Balanced Classification Results

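| Method | CIFAR-10, d=10 | CIFAR-10, d=25 | CIFAR-10, d=50 | CIFAR-10, d=100 | ImageNet-200, d=25 | ImageNet-200, d=50 | ImageNet-200, d=100 | ImageNet-200, d=200 |
|---|---|---|---|---|---|---|---|---|
| PSC | 25.67 | 60.0 | 60.6 | 62.1 | 60.0 | 60.6 | 62.1 | 33.1 |
| Word2Vec | 29.0 | 44.5 | 54.3 | 57.6 | 44.5 | 54.3 | 57.6 | 30.0 |
| HPN | 51.1 | 63.0 | 64.7 | 65.0 | 63.0 | 64.7 | 65.0 | 44.7 |
| Ours | 57.21 | 64.63 | 66.22 | 62.85 | 64.63 | 66.22 | 62.85 | 37.28 |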

ImageNet-1K Classification Accuracy (%) when $d=512$:

| Method | Venue | Backbone | Optimizer | Accuracy (%) |
|---|---|---|---|---|
| PSC | CVPR 2016 | ResNet-50 | SGD | 76.51 |
| DNC | ICLR 2022 | ResNet-50 | SGD | 76.49 |
| Goto et al. | WACV 2024 | ResNet-50 | SGD | 77.19 |
| Kasarla et al. | NeurIPS 2022 | ResNet-50 | SGD | 74.80 |
| Ours | CVPR 2024 | ResNet-50 | SGD | 77.47 |
| DNC | ICLR 2022 | ResNet-101 | SGD | 77.80 |
| Goto et al. | WACV 2024 | ResNet-101 | SGD | 78.27 |
| Kasarla et al. | NeurIPS 2022 | ResNet-152 | SGD | 78.50 |
| Ours | CVPR 2024 | ResNet-101 | SGD | 79.63 |
| PSC | CVPR 2016 | Swin-T | AdamW | 76.91 |
| Ours | CVPR 2024 | Swin-T | AdamW | 77.26 |

Long-tailed Classification Results

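| Method | CIFAR-10 LT (d=64), 0.005 | CIFAR-10 LT (d=64), 0.01 | CIFAR-10 LT (d=64), 0.02 | SVHN LT (d=64), 0.005 | SVHN LT (d=64), 0.01 | SVHN LT (d=64), 0.02 | STL-10 LT (d=64), 0.005 | STL-10 LT (d=64), 0.01 | STL-10 LT (d=64), 0.02 |
|---|---|---|---|---|---|---|---|---|---|
| PSC | 67.3 | 72.8 | 78.6 | 40.5 | 40.9 | 49.3 | 33.1 | 37.9 | 38.8 |
| ETF | 71.9 | 76.5 | 81.0 | 42.8 | 45.7 | 49.8 | 33.5 | 37.2 | 37.9 |
| Ours | 71.5 | 76.9 | 81.4 | 40.9 | 47.0 | 49.7 | 35.7 | 35.6 | 38.0 |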

CIFAR-100 LT Classification Accuracy (%):

| Method | d | 0.005 | 0.01 | 0.02 |
|---|---|---|---|---|
| PSC | 128 | 38.7 | 43.0 | 48.1 |
| ETF | 128 | 40.9 | 45.3 | 50.4 |
| Ours | 128 | 41.3 | 44.9 | 50.7 |

Citation

@misc{saadabadi2024hyperspherical,
    title={Hyperspherical Classification with Dynamic Label-to-Prototype Assignment},
    author={Mohammad Saeed Ebrahimi Saadabadi and Ali Dabouei and Sahar Rahimi Malakshan and Nasser M. Nasrabadi},
    year={2024},
    eprint={2403.16937},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

Acknowledgments

Here are some great resources we benefit from:

Contact

If you have a question about any part of the code, or something needs further clarification, please create an issue or email me at me00018@mix.wvu.edu.