SSLight

[ICLR'23] Effective Self-supervised Pre-training on Low-compute networks without Distillation

Fuwen Tan, Fatemeh Saleh, Brais Martinez, ICLR 2023.

Abstract

Despite the impressive progress of self-supervised learning (SSL), its applicability to low-compute networks has received limited attention. Reported performance has trailed behind standard supervised pre-training by a large margin, barring self-supervised learning from making an impact on models that are deployed on device. Most prior works attribute this poor performance to the capacity bottleneck of the low-compute networks and opt to bypass the problem through the use of knowledge distillation (KD). In this work, we revisit SSL for efficient neural networks, taking a closer look at which detrimental factors cause the practical limitations, and whether they are intrinsic to the self-supervised low-compute setting. We find that, contrary to accepted knowledge, there is no intrinsic architectural bottleneck; we diagnose that the performance bottleneck is related to the model complexity vs. regularization strength trade-off. In particular, we start by empirically observing that the use of local views can have a dramatic impact on the effectiveness of the SSL methods. This hints at view sampling being one of the performance bottlenecks for SSL on low-capacity networks. We hypothesize that the view sampling strategy for large neural networks, which requires matching views in very diverse spatial scales and contexts, is too demanding for low-capacity architectures. We systematize the design of the view sampling mechanism, leading to a new training methodology that consistently improves the performance across different SSL methods (e.g. MoCo-v2, SwAV or DINO), different low-size networks (convolution-based networks, e.g. MobileNetV2, ResNet18, ResNet34, and vision transformers, e.g. ViT-Ti), and different tasks (linear probe, object detection, instance segmentation and semi-supervised learning). Our best models establish a new state of the art for SSL methods on low-compute networks despite not using a KD loss term.
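
As a rough illustration of the view sampling discussed in the abstract, the sketch below samples crop boxes for global and local views at different area scales. It is not the sampling code used in this repo: the function names, the number of views, and the scale ranges are illustrative assumptions, chosen only to show where the scale ranges enter.

```python
import math
import random
from typing import List, Tuple

Box = Tuple[int, int, int, int]  # (left, top, width, height)


def sample_crop(img_w: int, img_h: int, scale: Tuple[float, float],
                rng: random.Random) -> Box:
    """Sample one crop whose area is a random fraction of the image area."""
    for _ in range(10):
        area = rng.uniform(*scale) * img_w * img_h
        # Log-uniform aspect ratio in [3/4, 4/3].
        ratio = math.exp(rng.uniform(math.log(3 / 4), math.log(4 / 3)))
        w = round(math.sqrt(area * ratio))
        h = round(math.sqrt(area / ratio))
        if 0 < w <= img_w and 0 < h <= img_h:
            left = rng.randint(0, img_w - w)
            top = rng.randint(0, img_h - h)
            return left, top, w, h
    # Fallback: use the whole image if no valid crop was found.
    return 0, 0, img_w, img_h


def sample_views(img_w: int, img_h: int, n_global: int = 2, n_local: int = 6,
                 global_scale: Tuple[float, float] = (0.4, 1.0),
                 local_scale: Tuple[float, float] = (0.05, 0.4),
                 seed: int = 0) -> Tuple[List[Box], List[Box]]:
    """Sample crop boxes for global and local views (all defaults are illustrative)."""
    rng = random.Random(seed)
    global_views = [sample_crop(img_w, img_h, global_scale, rng) for _ in range(n_global)]
    local_views = [sample_crop(img_w, img_h, local_scale, rng) for _ in range(n_local)]
    return global_views, local_views


global_views, local_views = sample_views(224, 224)
print("global:", global_views)
print("local: ", local_views)
```

Narrowing or widening `global_scale` and `local_scale` changes how different in spatial scale the views to be matched can be.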

Software required

The code has only been tested on 64-bit Linux:

  cd ${SSLIGHT_ROOT}/src
  conda env create -f environment.yml
  conda activate ssl

Experiments

This repo supports pre-training with DINO, SwAV, or MoCo as the baseline method, using MobileNet V2, ResNets, or ViTs as the backbone. To run the training:

  cd ${SSLIGHT_ROOT}/src
  python3 main.py --cfg config/exp_yamls/dino/dino_cnn_sslight.yaml DATA.PATH_TO_DATA_DIR $IN1K_PATH OUTPUT_DIR $OUTPUT_PATH
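
When launching several runs from a script, the same command line can be assembled programmatically. The following is a minimal sketch: `build_train_command` is a hypothetical helper (not part of this repo), and the two paths are placeholders standing in for `$IN1K_PATH` and `$OUTPUT_PATH`. It only builds and prints the command; it does not start training.

```python
import shlex
from typing import Dict, List


def build_train_command(cfg_path: str, overrides: Dict[str, str]) -> List[str]:
    """Return the argv list for a main.py run with KEY VALUE config overrides."""
    argv = ["python3", "main.py", "--cfg", cfg_path]
    for key, value in overrides.items():
        # Each override is passed as two separate tokens: KEY, then VALUE.
        argv.extend([key, str(value)])
    return argv


cmd = build_train_command(
    "config/exp_yamls/dino/dino_cnn_sslight.yaml",
    {
        "DATA.PATH_TO_DATA_DIR": "/data/imagenet",  # placeholder for $IN1K_PATH
        "OUTPUT_DIR": "/tmp/sslight_run",           # placeholder for $OUTPUT_PATH
    },
)
print(shlex.join(cmd))
```

The resulting list can be handed to `subprocess.run(cmd, cwd=...)` with `cwd` pointing at the `src` directory.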

To assess the quality of the features during pre-training, an additional linear classifier can be trained on the detached features. Detaching ensures that the gradient from the linear classifier does not interfere with the feature learning process:

  python3 main.py --cfg config/exp_yamls/dino/dino_cnn_sslight.yaml DATA.PATH_TO_DATA_DIR $IN1K_PATH OUTPUT_DIR $OUTPUT_PATH TRAIN.JOINT_LINEAR_PROBE True

Note that the accuracy of this extra classifier is typically lower than a standard linear probing evaluation.
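
One way to write the joint objective described above, assuming the classifier receives the features through a stop-gradient. The symbols are notation introduced for this note, not names from the repo: f_theta is the encoder, g_phi the linear classifier, and sg the stop-gradient operator.

```latex
\mathcal{L}(\theta, \phi)
  = \mathcal{L}_{\mathrm{SSL}}(\theta)
  + \mathcal{L}_{\mathrm{CE}}\big(g_\phi(\operatorname{sg}[f_\theta(x)]),\, y\big),
\qquad
\frac{\partial \mathcal{L}}{\partial \theta}
  = \frac{\partial \mathcal{L}_{\mathrm{SSL}}}{\partial \theta}
```

Under this formulation the cross-entropy term only updates phi, so the encoder is trained exactly as it would be without the probe.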

Pretraining

The table below includes the scripts for the pre-training experiments:

Model          Backbone      Pre-training  IN1K linear eval (Accu.)  Pretrained ckpts (re-trained)
-------------  ------------  ------------  ------------------------  -----------------------------
DINO baseline  ViT-Tiny/16   script        66.7, script              ckpt / log
DINO SSLight   ViT-Tiny/16   script        69.5 (+2.8), script       ckpt / log
DINO baseline  ResNet18      script        62.2, script              ckpt / log
DINO SSLight   ResNet18      script        65.7 (+3.5), script       ckpt / log
DINO baseline  ResNet34      script        67.7, script              ckpt / log
DINO SSLight   ResNet34      script        69.7 (+2.0), script       ckpt / log
DINO baseline  MobileNet V2  script        66.2, script              ckpt / log
DINO SSLight   MobileNet V2  script        68.3 (+2.1), script       ckpt / log
SwAV baseline  MobileNet V2  script        65.2, script              ckpt / log
SwAV SSLight   MobileNet V2  script        67.3 (+2.1), script       ckpt / log
MoCo baseline  MobileNet V2  script        60.6, script              ckpt / log
MoCo SSLight   MobileNet V2  script        61.6 (+1.0), script       ckpt / log
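
The gains in parentheses are the differences between the SSLight and baseline linear-eval accuracies for the same method and backbone. A short sketch that recomputes them from the numbers in the table (the dictionary layout is just a convenient assumption for this example):

```python
# Linear-eval accuracies copied from the table above: (baseline, SSLight).
results = {
    ("DINO", "ViT-Tiny/16"): (66.7, 69.5),
    ("DINO", "ResNet18"): (62.2, 65.7),
    ("DINO", "ResNet34"): (67.7, 69.7),
    ("DINO", "MobileNet V2"): (66.2, 68.3),
    ("SwAV", "MobileNet V2"): (65.2, 67.3),
    ("MoCo", "MobileNet V2"): (60.6, 61.6),
}

# Gain of SSLight over the baseline, rounded to one decimal as in the table.
gains = {
    key: round(sslight - baseline, 1)
    for key, (baseline, sslight) in results.items()
}

for (method, backbone), gain in gains.items():
    print(f"{method:5s} {backbone:13s} +{gain:.1f}")
```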

Downstream evaluations

The table below includes the scripts for the semi-supervised, object detection, and instance segmentation evaluations:

Model          Backbone      IN1K semi-sup 1% (Accu.)  IN1K semi-sup 10% (Accu.)  COCO object detection (AP)  COCO instance segmentation (AP)
-------------  ------------  ------------------------  -------------------------  --------------------------  -------------------------------
DINO baseline  ResNet18      44.5, script              59.2, script               32.7, script                30.6, script
DINO SSLight   ResNet18      49.8 (+5.3), script       63.0 (+3.8), script        34.1 (+1.4), script         31.8 (+1.2), script
DINO baseline  ResNet34      52.4, script              65.4, script               37.6, script                34.6, script
DINO SSLight   ResNet34      55.2 (+2.8), script       67.2 (+1.8), script        38.6 (+1.0), script         35.5 (+0.9), script
DINO baseline  MobileNet V2  47.9, script              61.3, script               30.9, script                28.1, script
DINO SSLight   MobileNet V2  50.6 (+2.7), script       63.5 (+2.2), script        32.1 (+1.2), script         29.1 (+1.0), script

Citing

If you find our paper/code useful, please consider citing:

@inproceedings{sslight2023,
    author = {Fuwen Tan and Fatemeh Saleh and Brais Martinez},
    title = {Effective Self-supervised Pre-training on Low-compute Networks without Distillation},
    booktitle = {International Conference on Learning Representations (ICLR)},
    year = {2023},
}