Differentiable Top-k Classification Learning

Official implementation for our ICML 2022 Paper "Differentiable Top-k Classification Learning".

The difftopk library provides different differentiable sorting and ranking methods as well as a wrapper for using them in a TopKCrossEntropyLoss. difftopk builds on PyTorch.

Paper @ arXiv, Video @ YouTube.

💻 Installation

difftopk can be installed via pip from PyPI with

pip install difftopk

Sparse Computation

To evaluate the differentiable top-k operators sparsely, the package torch-sparse has to be installed. This can be done, e.g., via

pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.13.0+cpu.html

For more information on how to install torch-sparse, see the torch-sparse installation instructions.
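
If it is unclear whether the sparse backend is available on a given machine, a quick check like the one below can help pick a method. This is a minimal sketch: it assumes the torch-sparse package is importable under the module name `torch_sparse`, and `has_torch_sparse` is a hypothetical helper that is not part of difftopk.

```python
import importlib.util


def has_torch_sparse() -> bool:
    """Return True if the optional torch-sparse package can be imported."""
    # Assumption: torch-sparse installs a top-level module named `torch_sparse`.
    return importlib.util.find_spec("torch_sparse") is not None


# Fall back to the non-sparse bitonic variant when torch-sparse is missing.
diffsort_method = "bitonic" if has_torch_sparse() else "bitonic__non_sparse"
print(diffsort_method)
```

The resulting string could then be passed as the `diffsort_method` argument of the loss described in the documentation section.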

Example for Full Installation from Scratch and with all Dependencies


Depending on your system, the following commands will have to be adjusted.

virtualenv -p python3 .env_topk
. .env_topk/bin/activate
pip install boto3 numpy requests scikit-learn tqdm 
pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
pip install diffsort

# optional for torch-sparse
FORCE_CUDA=1 pip install --no-cache-dir torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.13.0+cu116.html

pip install .

👩‍💻 Documentation

The difftopk library provides a set of differentiable sorting and ranking methods as well as a wrapper for using them in a TopKCrossEntropyLoss. The differentiable sorting and ranking methods included are:

  • Variants of Differentiable Sorting Networks
    • bitonic: Bitonic Differentiable Sorting Networks (sparse)
    • bitonic__non_sparse: Bitonic Differentiable Sorting Networks
    • splitter_selection: Differentiable Splitter Selection Networks (sparse)
    • odd_even: Odd-Even Differentiable Sorting Networks
  • neuralsort: NeuralSort
  • softsort: SoftSort

Furthermore, this library includes the smooth top-k losses from Lapin et al. (SmoothTopKLoss and SmoothHardTopKLoss).
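
To give a feel for the sorting-network variants listed above, the following is a simplified sketch of a relaxed odd-even transposition sort in plain PyTorch. It replaces each hard compare-and-swap with a logistic-sigmoid interpolation. This is an illustration under our own assumptions, not the library's implementation: the function name `soft_odd_even_sort` is hypothetical, and the real networks also produce permutation matrices and support other distributions such as Cauchy.

```python
import torch


def soft_odd_even_sort(x, steepness=10.0):
    """Relaxed odd-even transposition sort along the last dim (ascending)."""
    n = x.shape[-1]
    cols = list(x.unbind(dim=-1))  # n tensors of shape (b,)
    for layer in range(n):  # n layers suffice for odd-even transposition sort
        for i in range(layer % 2, n - 1, 2):
            a, b = cols[i], cols[i + 1]
            # alpha is close to 1 if the pair is already ordered (a < b).
            alpha = torch.sigmoid(steepness * (b - a))
            cols[i] = alpha * a + (1 - alpha) * b      # soft minimum
            cols[i + 1] = (1 - alpha) * a + alpha * b  # soft maximum
    return torch.stack(cols, dim=-1)


x = torch.tensor([[3.0, 1.0, 2.0]], requires_grad=True)
y = soft_odd_even_sort(x, steepness=50.0)
print(y)  # close to [[1., 2., 3.]]
y[0, -1].backward()  # gradients flow back to the inputs
print(x.grad)
```

A larger `steepness` brings the output closer to a hard sort, while a smaller value gives smoother gradients.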

TopKCrossEntropyLoss

At the core of the library is difftopk.TopKCrossEntropyLoss, which may be used as a drop-in replacement for torch.nn.CrossEntropyLoss. Its constructor takes the following arguments:

import difftopk

loss_fn = difftopk.TopKCrossEntropyLoss(
    diffsort_method='odd_even',       # the sorting / ranking method as discussed above
    inverse_temperature=2,            # the inverse temperature / steepness
    p_k=[.5, 0., 0., 0., .5],         # the distribution P_K
    n=1000,                           # number of classes
    m=16,                             # the number m of scores to be sorted (can be smaller than n to make it efficient)
    distribution='cauchy',            # the distribution used for differentiable sorting networks
    art_lambda=None,                  # the lambda for the ART used if `distribution='logistic_phi'`
    device='cpu',                     # the device to compute the loss on
    top1_mode='sm'                    # makes training more stable and is the default value
)

It can be used as loss_fn(outputs, labels).
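
To make the role of `p_k` more concrete, here is a minimal sketch of how a distribution P_K might weight the individual top-k terms. It assumes a loss of the form -sum_k P_K(k) * log P(true class is in the top k) and an attribution tensor laid out as (batch, classes, ranks) with the top-1 rank first. Both are assumptions made for illustration, and `topk_weighted_nll` is a hypothetical name; refer to the paper for the exact definition.

```python
import math

import torch


def topk_weighted_nll(attribution, labels, p_k, eps=1e-8):
    """attribution[i, c, j]: relaxed probability that class c holds rank j + 1."""
    batch = torch.arange(attribution.shape[0])
    rank_probs = attribution[batch, labels]  # (b, k): true class at each rank
    in_top_k = rank_probs.cumsum(dim=-1)     # (b, k): true class within top-(j + 1)
    weights = torch.as_tensor(p_k, dtype=attribution.dtype)
    return -(weights * torch.log(in_top_k + eps)).sum(dim=-1).mean()


# One sample, three classes, k = 2; each rank column sums to 1.
attribution = torch.tensor([[[0.6, 0.3], [0.3, 0.5], [0.1, 0.2]]])
labels = torch.tensor([0])
loss = topk_weighted_nll(attribution, labels, p_k=[0.5, 0.5])
print(loss.item())
```

With `p_k=[1., 0.]` the sketch reduces to a pure top-1 term, and with `p_k=[0., 1.]` only the top-2 term contributes.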

DiffTopkNet

difftopk.DiffTopkNet follows the signature of diffsort.DiffSortNet from the diffsort package. However, instead of returning the full differentiable permutation matrices of size n×n, it returns differentiable top-k attribution matrices of size n×k. More specifically, given an input of shape b×n, the module returns a tuple of None and a Tensor of shape b×n×k. (It returns a tuple to maintain the signature of DiffSortNet.)

import torch
import difftopk

sorter = difftopk.DiffTopkNet(
    sorting_network_type='bitonic',
    size=16,                          # Number of inputs
    k=5,                              # k
    sparse=True,                      # whether to use a sparse implementation
    device='cpu',                     # the device
    steepness=10.0,                   # the inverse temperature
    art_lambda=0.25,                  # the lambda for the ART used if `distribution='logistic_phi'`
    distribution='cauchy'             # the distribution used for the differentiable relaxation
)

# Usage example for difftopk on a random input:
# sorter(x)[1] has shape (1, 16, 5), so the product below has shape (1, 5)
x = torch.randperm(16).unsqueeze(0).float() * 100.
print(x @ sorter(x)[1][0])
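
For comparison, the hard (non-differentiable) analogue of such an attribution matrix can be built with `torch.topk`. The sketch below only illustrates the shapes involved and why `x @ attribution` recovers the top-k values. The helper `hard_topk_attribution` is hypothetical, and the column order (largest value first) is an assumption that may differ from the order used by DiffTopkNet.

```python
import torch


def hard_topk_attribution(x, k):
    """Return one-hot attribution matrices of shape (b, n, k) for x of shape (b, n)."""
    idx = x.topk(k, dim=-1).indices  # (b, k), indices of the k largest values
    one_hot = torch.nn.functional.one_hot(idx, num_classes=x.shape[-1])  # (b, k, n)
    return one_hot.transpose(1, 2).to(x.dtype)  # (b, n, k)


x = torch.tensor([[3.0, 9.0, 1.0, 7.0]])
A = hard_topk_attribution(x, k=2)
print(A.shape)   # torch.Size([1, 4, 2])
print(x @ A[0])  # tensor([[9., 7.]])
```

The differentiable version replaces the one-hot columns with soft distributions over the inputs, so that gradients can flow through the selection.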

NeuralSort / SoftSort

import difftopk

sorter = difftopk.NeuralSort(
    tau=2.,     # A temperature parameter
)
sorter = difftopk.SoftSort(tau=2.)
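
For intuition about what `tau` controls, the SoftSort relaxation can be written in a few lines of plain PyTorch. This is an independent sketch of the published SoftSort operator (a row-wise softmax over negative absolute differences between sorted and unsorted scores). It is not the difftopk implementation, and the function name `soft_sort` is hypothetical.

```python
import torch


def soft_sort(scores, tau=1.0):
    """Return relaxed permutation matrices of shape (b, n, n) for scores of shape (b, n)."""
    sorted_scores = scores.sort(dim=-1, descending=True).values  # (b, n)
    # dist[b, i, j] = |i-th largest score - j-th input score|
    dist = (sorted_scores.unsqueeze(-1) - scores.unsqueeze(-2)).abs()
    return torch.softmax(-dist / tau, dim=-1)  # each row is a distribution over inputs


scores = torch.tensor([[0.3, 2.0, -1.0]])
P = soft_sort(scores, tau=0.1)
print(P.argmax(dim=-1))  # tensor([[1, 0, 2]])
print(P @ scores.unsqueeze(-1))  # approximately the scores in descending order
```

A small `tau` yields nearly one-hot rows, while a large `tau` blurs the rows toward a uniform distribution.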

🧪 Experiments

🧫 ImageNet Fine-Tuning

We provide pre-computed embeddings for the ImageNet data set. ⚠️ These embedding files are very large (>10 GB each). Feel free to also use the embeddings for other fine-tuning experiments.

# Resnext101 WSL ImageNet-1K (~11GB each)
wget https://nyc3.digitaloceanspaces.com/publicdata1/ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x48d_320.p
wget https://nyc3.digitaloceanspaces.com/publicdata1/ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x32d_320.p
wget https://nyc3.digitaloceanspaces.com/publicdata1/ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x16d_320.p
wget https://nyc3.digitaloceanspaces.com/publicdata1/ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x8d_320.p

# Resnext101 WSL ImageNet-21K-P (~50GB)
wget https://publicdata1.nyc3.digitaloceanspaces.com/ImageNet21K-P_embeddings_labels_train_test_IGAM_Resnext101_32x48d_224_float16.p

# Noisy Student EfficientNet-L2 ImageNet-1K (~14GB)
wget https://publicdata1.nyc3.digitaloceanspaces.com/ImageNet_embeddings_labels_train_test_tf_efficientnet_l2_ns_timm_transform_800_float16.p

The following are the hyperparameter combinations for reproducing the tables in the paper. The DiffSortNet entries in Table 5 can be reproduced using

python experiments/train_imagenet.py  -d ./ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x48d_320.p  --nloglr 4.5 \
    --p_k .2 .2 .2 .2 .2  --m 16  --method bitonic  --distribution logistic_phi  --inverse_temperature 1.  --art_lambda .5 
python experiments/train_imagenet.py  -d ./ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x48d_320.p  --nloglr 4.5 \
    --p_k .2 .2 .2 .2 .2  --m 16  --method splitter_selection  --distribution logistic_phi  --inverse_temperature 1.  --art_lambda .5

python experiments/train_imagenet.py  -d ./ImageNet_embeddings_labels_train_test_tf_efficientnet_l2_ns_timm_transform_800_float16.p  --nloglr 4.5 \
    --p_k .25 .0 .0 .0 .75  --m 16  --method bitonic  --distribution logistic  --inverse_temperature .5 
python experiments/train_imagenet.py  -d ./ImageNet_embeddings_labels_train_test_tf_efficientnet_l2_ns_timm_transform_800_float16.p  --nloglr 4.5 \
    --p_k .25 .0 .0 .0 .75  --m 16  --method splitter_selection  --distribution logistic  --inverse_temperature .5 

For the remaining methods and tables, the hyperparameters are as follows:

# Tables 2+3 (1K):
python experiments/train_imagenet.py -d ./ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x48d_320.p  --m 16  --nloglr 4.5

# Tables 2+3 (21K):
python experiments/train_imagenet.py -d ./ImageNet21K-P_embeddings_labels_train_test_IGAM_Resnext101_32x48d_224_float16.p  --m 50  --nloglr 4.  --n_epochs 40 

# combined with one of each of the following

--method softmax_cross_entropy
--method bitonic             --distribution logistic_phi  --inverse_temperature 1.  --art_lambda .5
--method splitter_selection  --distribution logistic_phi  --inverse_temperature 1.  --art_lambda .5
--method neuralsort          --inverse_temperature 0.5
--method softsort            --inverse_temperature 0.5
--method smooth_hard_topk    --inverse_temperature 1. 

--p_k 1.  0. 0. 0. 0.
--p_k 0.  0. 0. 0. 1.
--p_k .5  0. 0. 0. .5
--p_k .25 0. 0. 0. .75
--p_k .1  0. 0. 0. .9
--p_k .2  .2 .2 .2 .2
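
The grid above can also be enumerated programmatically. The following sketch builds the full list of command lines for the 1K setting by combining every method line with every `--p_k` line. The helper `build_commands` is hypothetical and not part of the repository; it only assembles strings and does not launch any training.

```python
import itertools

BASE = (
    "python experiments/train_imagenet.py "
    "-d ./ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x48d_320.p "
    "--m 16 --nloglr 4.5"
)
METHODS = [
    "--method softmax_cross_entropy",
    "--method bitonic --distribution logistic_phi --inverse_temperature 1. --art_lambda .5",
    "--method splitter_selection --distribution logistic_phi --inverse_temperature 1. --art_lambda .5",
    "--method neuralsort --inverse_temperature 0.5",
    "--method softsort --inverse_temperature 0.5",
    "--method smooth_hard_topk --inverse_temperature 1.",
]
P_KS = [
    "1. 0. 0. 0. 0.",
    "0. 0. 0. 0. 1.",
    ".5 0. 0. 0. .5",
    ".25 0. 0. 0. .75",
    ".1 0. 0. 0. .9",
    ".2 .2 .2 .2 .2",
]


def build_commands(base=BASE):
    """Combine the base command with each (method, p_k) pair."""
    return [
        f"{base} {method} --p_k {p_k}"
        for method, p_k in itertools.product(METHODS, P_KS)
    ]


commands = build_commands()
print(len(commands))  # 36
print(commands[0])
```

Each resulting string can be pasted into a shell or handed to a job scheduler of your choice.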

🎆 CIFAR-100 Training from Scratch

In addition to ImageNet fine-tuning, we can also train a ResNet18 on CIFAR-100 from scratch.

# Tables 1+4:
python experiments/train_cifar100.py

--method softmax_cross_entropy
--method bitonic             --distribution logistic_phi  --inverse_temperature .5  --art_lambda .5
--method splitter_selection  --distribution logistic_phi  --inverse_temperature .5  --art_lambda .5
--method neuralsort          --inverse_temperature 0.0625
--method softsort            --inverse_temperature 0.0625
--method smooth_hard_topk    --inverse_temperature 1.

--p_k 1.  0. 0. 0. 0.
--p_k 0.  0. 0. 0. 1.
--p_k .5  0. 0. 0. .5
--p_k .25 0. 0. 0. .75
--p_k .1  0. 0. 0. .9
--p_k .2  .2 .2 .2 .2
  
# Examples:
python experiments/train_cifar100.py --method softmax_cross_entropy --p_k 1. 0. 0. 0. 0.
python experiments/train_cifar100.py --method splitter_selection --distribution logistic_phi --inverse_temperature .5 --art_lambda .5 --p_k .2 .2 .2 .2 .2

📖 Citing

@inproceedings{petersen2022difftopk,
  title={{Differentiable Top-k Classification Learning}},
  author={Petersen, Felix and Kuehne, Hilde and Borgelt, Christian and Deussen, Oliver},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2022}
}

📜 License

difftopk is released under the MIT license. See LICENSE for details.