
PyTorch implementation of "Self-Supervised Classification Network" from ECCV 2022

Primary language: Python. License: BSD 3-Clause "New" or "Revised" License (BSD-3-Clause).

Self-Classifier: Self-Supervised Classification Network

Official PyTorch implementation and pretrained models of the paper Self-Supervised Classification Network from ECCV 2022.

Self-Classifier architecture. Two augmented views of the same image are processed by a shared network consisting of a backbone (e.g., a CNN) and a classifier (e.g., a projection MLP followed by a linear classification head). The cross-entropy between the predictions for the two views is minimized to encourage both views to be assigned the same class. Degenerate solutions are avoided by imposing a uniform prior on the class predictions. The resulting model learns representations and discovers the underlying classes in a single-stage, end-to-end, unsupervised manner.
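The objective can be illustrated with a minimal pure-Python sketch of a symmetric two-view cross-entropy. In this sketch, each view's target is computed under a uniform class prior in two steps:

1. Logits are softmax-normalized over the batch, giving an estimate of p(x|y).
2. Each row is renormalized. By Bayes' rule with uniform p(y), this yields p(y|x).

This is a simplified illustration of the idea described above, not the repository's training loss. The function names and the temperatures `tau_row` and `tau_col` (including their default values) are assumptions.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def transpose(matrix):
    return [list(col) for col in zip(*matrix)]


def posteriors_uniform_prior(logits, tau_col=0.1):
    """Targets p(y|x) for a batch, assuming a uniform class prior p(y).

    1. Softmax over the batch dimension (per class column) ~ p(x|y).
    2. Renormalize each row: with uniform p(y), Bayes' rule gives
       p(y|x) = p(x|y) / sum_y' p(x|y').
    """
    cols = [softmax([v / tau_col for v in col]) for col in transpose(logits)]
    posteriors = []
    for row in transpose(cols):
        s = sum(row)
        posteriors.append([v / s for v in row])
    return posteriors


def row_softmax(logits, tau_row=0.1):
    """Per-sample class predictions: softmax over the class dimension."""
    return [softmax([v / tau_row for v in row]) for row in logits]


def cross_entropy(targets, preds, eps=1e-8):
    """Mean over samples of -sum_y target(y) * log pred(y)."""
    total = 0.0
    for t_row, p_row in zip(targets, preds):
        total -= sum(t * math.log(p + eps) for t, p in zip(t_row, p_row))
    return total / len(targets)


def two_view_loss(logits1, logits2, tau_row=0.1, tau_col=0.1):
    """Symmetric loss: each view's uniform-prior target supervises the other view."""
    t1 = posteriors_uniform_prior(logits1, tau_col)
    t2 = posteriors_uniform_prior(logits2, tau_col)
    p1 = row_softmax(logits1, tau_row)
    p2 = row_softmax(logits2, tau_row)
    return 0.5 * (cross_entropy(t1, p2) + cross_entropy(t2, p1))
```

The batch-wise normalization is what discourages collapse. If every sample produced the same logits, each column softmax would be uniform over the batch. Every target row would then become uniform over classes, so a confident collapsed prediction would be penalized.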

If you find this repository useful in your research, please cite:

@article{amrani2021self,
  title={Self-Supervised Classification Network},
  author={Amrani, Elad and Karlinsky, Leonid and Bronstein, Alex},
  journal={arXiv preprint arXiv:2103.10994},
  year={2021}
}

Pretrained Models

Download pretrained models here.

Setup

  1. Create the Conda environment:

     conda env create -f ./environment.yml
    
  2. Install Apex with CUDA extension:

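     export TORCH_CUDA_ARCH_LIST="7.0"  # set to your GPU's compute capability, see https://en.wikipedia.org/wiki/CUDA#GPUs_supported
     # GitHub no longer serves the git:// protocol, so install over https.
     # --install-option was removed in pip 23.1; use an older pip if the flag is rejected.
     pip install git+https://github.com/NVIDIA/apex.git@4a1aa97e31ca87514e17c3cd3bbc03f4204579d0 --install-option="--cuda_ext"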
    

Training & Evaluation

Distributed training and evaluation are available via Slurm. See SBATCH scripts here.

IMPORTANT: set DATASET_PATH, EXPERIMENT_PATH and PRETRAINED_PATH to match your local paths.

Training

method | epochs | NMI | AMI | ARI | ACC | linear probing top-1 acc. | training script
--- | --- | --- | --- | --- | --- | --- | ---
Self-Classifier | 100 | 71.2 | 49.2 | 26.1 | 37.3 | 72.4 | script
Self-Classifier | 200 | 72.5 | 51.6 | 28.1 | 39.4 | 73.5 | script
Self-Classifier | 400 | 72.9 | 52.3 | 28.8 | 40.2 | 74.2 | script
Self-Classifier | 800 | 73.3 | 53.1 | 29.5 | 41.1 | 74.1 | script

NMI: Normalized Mutual Information, AMI: Adjusted Mutual Information, ARI: Adjusted Rand Index, ACC: unsupervised clustering accuracy. Linear probing: top-1 accuracy of a supervised linear classifier trained on top of frozen self-supervised features.
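NMI and ARI compare predicted class indices against ground-truth labels without requiring the indices themselves to match. The pure-Python sketch below computes both from a contingency table. NMI uses arithmetic-mean normalization, which is the default in scikit-learn's `normalized_mutual_info_score`. The function names are illustrative, and this is not the repository's evaluation code. In practice, `sklearn.metrics` provides `normalized_mutual_info_score`, `adjusted_mutual_info_score` and `adjusted_rand_score`.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (natural log) of a labeling."""
    n = len(labels)
    return -sum((c / n) * math.log(c / n) for c in Counter(labels).values())


def mutual_information(y_true, y_pred):
    """Mutual information between two labelings, from their joint counts."""
    n = len(y_true)
    joint = Counter(zip(y_true, y_pred))
    ct, cp = Counter(y_true), Counter(y_pred)
    return sum((nij / n) * math.log(n * nij / (ct[t] * cp[p]))
               for (t, p), nij in joint.items())


def nmi(y_true, y_pred):
    """NMI with arithmetic-mean normalization; invariant to label permutation."""
    h_true, h_pred = entropy(y_true), entropy(y_pred)
    if h_true == 0 and h_pred == 0:
        return 1.0  # both labelings are a single cluster
    return mutual_information(y_true, y_pred) / ((h_true + h_pred) / 2)


def comb2(k):
    """Number of unordered pairs among k items."""
    return k * (k - 1) / 2


def ari(y_true, y_pred):
    """Adjusted Rand Index: pair-counting agreement corrected for chance."""
    n = len(y_true)
    sum_ij = sum(comb2(v) for v in Counter(zip(y_true, y_pred)).values())
    sum_a = sum(comb2(v) for v in Counter(y_true).values())
    sum_b = sum(comb2(v) for v in Counter(y_pred).values())
    expected = sum_a * sum_b / comb2(n)
    max_index = (sum_a + sum_b) / 2
    if max_index == expected:
        return 1.0
    return (sum_ij - expected) / (max_index - expected)
```

Both scores equal 1.0 when the predicted classes are a relabeling of the ground truth. ARI can be negative when agreement is worse than chance.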

Evaluation

Unsupervised Image Classification

dataset | NMI | AMI | ARI | ACC | evaluation script
--- | --- | --- | --- | --- | ---
ImageNet 1K classes | 73.3 | 53.1 | 29.5 | 41.1 | script
ImageNet 10 superclasses (level #2 in hierarchy) | 74.0 | 54.3 | 30.9 | 85.7 | script
ImageNet 29 superclasses (level #3 in hierarchy) | 74.0 | 54.3 | 30.9 | 79.7 | script
ImageNet 128 superclasses (level #4 in hierarchy) | 74.0 | 54.3 | 30.9 | 71.8 | script
ImageNet 466 superclasses (level #5 in hierarchy) | 73.9 | 54.3 | 30.8 | 60.0 | script
ImageNet 591 superclasses (level #6 in hierarchy) | 74.1 | 55.3 | 32.1 | 46.7 | script
BREEDS Entity13 (ImageNet based) | 73.6 | 54.1 | 30.7 | 84.4 | script
BREEDS Entity30 (ImageNet based) | 72.9 | 53.4 | 29.8 | 81.0 | script
BREEDS Living17 (ImageNet based) | 67.2 | 51.8 | 26.4 | 90.8 | script
BREEDS Nonliving26 (ImageNet based) | 72.2 | 57.0 | 36.8 | 76.7 | script

NMI: Normalized Mutual Information, AMI: Adjusted Mutual Information, ARI: Adjusted Rand Index, ACC: unsupervised clustering accuracy.
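ACC requires a one-to-one mapping between predicted classes and ground-truth classes. The chosen mapping is the one that maximizes the number of correctly matched samples. The sketch below finds it by brute force over permutations, which is only feasible for a handful of classes. At ImageNet scale one would use the Hungarian algorithm instead, for example `scipy.optimize.linear_sum_assignment` on the negated contingency matrix. The function name is hypothetical, and this is not the repository's evaluation script.

```python
from collections import Counter
from itertools import permutations


def clustering_accuracy(y_true, y_pred):
    """Unsupervised accuracy under the best one-to-one cluster-to-class mapping.

    Brute force over permutations: only practical for very few classes.
    """
    counts = Counter(zip(y_pred, y_true))  # (cluster, class) -> count; missing keys are 0
    clusters = sorted(set(y_pred))
    classes = sorted(set(y_true))
    best = 0
    if len(clusters) <= len(classes):
        # Assign each cluster a distinct class.
        for perm in permutations(classes, len(clusters)):
            best = max(best, sum(counts[(c, k)] for c, k in zip(clusters, perm)))
    else:
        # More clusters than classes: each class gets a distinct cluster, the rest count as wrong.
        for perm in permutations(clusters, len(classes)):
            best = max(best, sum(counts[(c, k)] for c, k in zip(perm, classes)))
    return best / len(y_true)
```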

K-Means Baselines Using Self-Supervised Pretrained Models

method | NMI | AMI | ARI | ACC | evaluation script
--- | --- | --- | --- | --- | ---
BarlowTwins | 68.8 | 48.3 | 14.7 | 33.2 | script
OBoW | 66.5 | 42.0 | 16.9 | 31.1 | script
DINO | 66.2 | 42.3 | 15.6 | 30.7 | script
MoCov2 | 66.6 | 45.3 | 12.0 | 30.6 | script
SwAV | 64.1 | 38.8 | 13.4 | 28.1 | script
SimSiam | 62.2 | 34.9 | 11.6 | 24.9 | script

NMI: Normalized Mutual Information, AMI: Adjusted Mutual Information, ARI: Adjusted Rand Index, ACC: unsupervised clustering accuracy. All methods are evaluated on ImageNet 1K classes using the original pretrained models released for MoCov2, OBoW, SimSiam and SwAV. DINO and BarlowTwins are loaded via PyTorch Hub (i.e., no direct download is needed).
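For these baselines, k-means clusters features from each pretrained model, and the resulting assignments are scored with the same metrics. A minimal pure-Python sketch of Lloyd's algorithm is shown below. The random initialization, iteration cap and function names are assumptions. The repository's evaluation scripts may use a different k-means implementation and settings.

```python
import random


def sq_dist(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points, k, iters=50, seed=0):
    """Lloyd's algorithm: alternate nearest-centroid assignment and centroid update."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]  # init from k random points
    assign = None
    for _ in range(iters):
        new_assign = [min(range(k), key=lambda j: sq_dist(p, centroids[j])) for p in points]
        if new_assign == assign:
            break  # assignments stopped changing: converged
        assign = new_assign
        for j in range(k):
            members = [p for p, a in zip(points, assign) if a == j]
            if members:  # keep the old centroid if a cluster becomes empty
                centroids[j] = [sum(dim) / len(members) for dim in zip(*members)]
    return assign, centroids
```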

Image Classification with Linear Models

To train a supervised linear classifier on top of the frozen backbone, run:

    sbatch ./scripts/lincls_eval.sh
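In linear probing, only a linear layer is trained while the backbone's features stay fixed. The pure-Python sketch below trains a multinomial logistic-regression probe with full-batch gradient descent on precomputed feature vectors. The learning rate, epoch count and function names are assumptions. `lincls_eval.sh` may use a different optimizer, schedule and data augmentation.

```python
import math


def softmax(z):
    """Numerically stable softmax."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def train_linear_probe(features, labels, num_classes, lr=0.5, epochs=200):
    """Fit W, b on fixed (frozen) features; the features themselves are never updated."""
    dim = len(features[0])
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    n = len(features)
    for _ in range(epochs):
        gW = [[0.0] * dim for _ in range(num_classes)]
        gb = [0.0] * num_classes
        for x, y in zip(features, labels):
            probs = softmax([sum(w * xi for w, xi in zip(W[c], x)) + b[c] for c in range(num_classes)])
            for c in range(num_classes):
                g = probs[c] - (1.0 if c == y else 0.0)  # d(cross-entropy)/d(logit_c)
                gb[c] += g / n
                for d in range(dim):
                    gW[c][d] += g * x[d] / n
        for c in range(num_classes):  # gradient-descent step on the probe only
            b[c] -= lr * gb[c]
            for d in range(dim):
                W[c][d] -= lr * gW[c][d]
    return W, b


def predict(W, b, x):
    """Top-1 prediction of the linear probe."""
    scores = [sum(w * xi for w, xi in zip(Wc, x)) + bc for Wc, bc in zip(W, b)]
    return max(range(len(scores)), key=scores.__getitem__)
```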

Image Classification with kNN

To run a k-nearest-neighbor (kNN) classifier on the ImageNet validation set, run:

    sbatch ./scripts/knn_eval.sh
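A kNN classifier on self-supervised features labels each validation image by a vote among its most similar training images. The sketch below uses cosine similarity and a temperature-weighted vote, `exp(sim / temperature)`, which is one common variant. The values of `k` and `temperature` and the function names are assumptions, not the settings used by `knn_eval.sh`.

```python
import math
from collections import defaultdict


def l2_normalize(v):
    """Scale a vector to unit length (leaves an all-zero vector unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def knn_predict(train_feats, train_labels, query, k=20, temperature=0.07):
    """Predict a label for `query` from its k most cosine-similar training features."""
    q = l2_normalize(query)
    sims = [(sum(a * b for a, b in zip(q, l2_normalize(f))), y)
            for f, y in zip(train_feats, train_labels)]
    top = sorted(sims, key=lambda s: s[0], reverse=True)[:k]
    votes = defaultdict(float)
    for sim, y in top:
        votes[y] += math.exp(sim / temperature)  # closer neighbors get larger weight
    return max(votes, key=votes.get)
```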

Transferring to Object Detection and Instance Segmentation

See ./detection.

Ablation study

To train the 100-epoch ablation study baseline, run:

    sbatch ./scripts/ablation/train_100ep.sh

To train any of the ablation study runs presented in the paper, run:

    sbatch ./scripts/ablation/<ablation_name>/<ablation_script>.sh

Qualitative Examples (classes predicted by Self-Classifier on ImageNet validation set)

High-accuracy classes predicted by Self-Classifier on the ImageNet validation set. Images are sampled randomly from each predicted class. Note that the predicted classes capture a wide variety of backgrounds and viewpoints.

To reproduce qualitative examples, run:

    sbatch ./scripts/cls_eval.sh
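The selection described in the caption can be sketched in three steps: group validation images by predicted class, rank classes by a per-class accuracy score, and sample images at random from the top-ranked classes. As an assumption, per-class accuracy is approximated here by purity, the fraction of a predicted class whose ground-truth label is the majority label. The function name, its parameters and this scoring rule are hypothetical and may differ from what `cls_eval.sh` does.

```python
import random
from collections import Counter, defaultdict


def sample_high_accuracy_classes(image_ids, y_pred, y_true, num_classes=3, per_class=4, seed=0):
    """Return {predicted_class: sampled image ids} for the highest-purity predicted classes."""
    groups = defaultdict(list)
    for img, p, t in zip(image_ids, y_pred, y_true):
        groups[p].append((img, t))

    def purity(members):
        # Hypothetical accuracy proxy: share of the majority ground-truth label.
        return Counter(t for _, t in members).most_common(1)[0][1] / len(members)

    ranked = sorted(groups, key=lambda c: purity(groups[c]), reverse=True)[:num_classes]
    rng = random.Random(seed)
    return {c: rng.sample([img for img, _ in groups[c]], min(per_class, len(groups[c])))
            for c in ranked}
```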

License

See the LICENSE file for more details.