This repository houses a collection of self-supervised learning models.

I implemented most of the current state-of-the-art self-supervised learning methods including SimCLRv2, BYOL, SimSiam, MoCov2, and SwAV.


  • the supervised approach to learning features from labeled data has reached saturation because of the intense labor required to manually annotate millions of data samples
  • in self-supervised learning, the data itself provides the supervision
  • basic intuition behind the contrastive learning paradigm: pull the original and augmented images closer together and push the original and negative images apart



Implementation Details

  • ResNet-50 is widely used because of its balance between size and learning capability
  • the output features of res5 (the final residual stage) are average-pooled to get a 2048-dimensional vector
  • apply a single linear projection to get a 128-dimensional feature vector
  • or use a shallow MLP (1 hidden layer) that maps representations to a latent space where the contrastive loss is applied
  • features from later stages of the encoder prove to be better representations of the input
  • cosine similarity of two vectors is the cosine of the angle between them, i.e., the dot product of the L2-normalized vectors
  • InfoNCE (Noise Contrastive Estimation) loss, where q is the original sample
  • NCE performs a non-linear logistic regression that discriminates between observed data and artificially generated noise
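The cosine similarity and the InfoNCE loss mentioned in the list can be sketched in a few lines. This is a minimal, framework-free illustration rather than the code used in this repository; the function names, the toy vectors, and the temperature of 0.07 are assumptions chosen for the example.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between u and v (dot product of L2-normalized vectors)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def info_nce(q, k_pos, k_negs, temperature=0.07):
    """InfoNCE loss for one query q, one positive key, and a list of negative keys."""
    logits = [cosine_similarity(q, k_pos) / temperature]
    logits += [cosine_similarity(q, k) / temperature for k in k_negs]
    # numerically stable log-sum-exp over the positive and all negatives
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    # cross-entropy with the positive key treated as the correct "class"
    return log_sum_exp - logits[0]

q = [1.0, 0.0]                          # original sample (query)
positive = [0.9, 0.1]                   # augmented view of the same image
negatives = [[0.0, 1.0], [-1.0, 0.2]]   # embeddings of other images

loss_aligned = info_nce(q, positive, negatives)
loss_misaligned = info_nce(q, negatives[0], [positive, negatives[1]])
```

The loss is small when the query is closest to its positive key and grows when a negative is closer instead, which is the "pull together, push apart" behaviour that training minimizes.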


  • training involves learning the parameters of the encoder network by minimizing the loss function


  • models are used as frozen encoders for a linear classifier (linear evaluation protocol)


Install Package

pip install ConSSL


import torch
from pytorch_lightning import Trainer
from torchvision import models

# import paths follow the other examples in this README
from ConSSL.models.self_supervised import SimSiam
from ConSSL.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform

# MyDataset is a placeholder for your own torch Dataset
train_dataset = MyDataset(transform=SimCLRTrainDataTransform())
val_dataset = MyDataset(transform=SimCLREvalDataTransform())

# train from scratch
model = SimSiam()
trainer = Trainer(gpus=4)

Notes On Implementation

  • I found that using the SimCLR augmentation directly sometimes causes the model to collapse. This may be because the SimCLR augmentation is too strong.
  • Adopting the MoCo augmentation during the warmup stage helps.
  • Gradient check for batch optimization: gradient descent over a batch of samples not only benefits the optimization but also leverages data parallelism. However, you have to be careful not to mix data across the batch dimension. Even a small error in a reshape or permute operation can leave the optimization stuck without raising a runtime error, so you should check every operation that reshapes or permutes tensor dimensions in your model. The check works as follows:
  • run the model on an example batch (can be random data)
  • get the output batch and select the n-th sample (choose n)
  • compute a dummy loss value of only that sample and compute the gradient w.r.t. the entire input batch
  • observe that only the n-th sample in the input batch has a non-zero gradient
from pytorch_lightning import Trainer
from ConSSL.callbacks import BatchGradientVerificationCallback

model = YourLightningModule()
verification = BatchGradientVerificationCallback()
trainer = Trainer(callbacks=[verification])
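The gradient check can also be tried by hand on a toy model. The sketch below is not the BatchGradientVerificationCallback implementation: it replaces autograd with finite differences so that it runs without any framework, and `good_model`, `leaky_model`, and `samples_with_gradient` are hypothetical names.

```python
def good_model(batch):
    """Processes every sample independently: y_i is the sum of squares of x_i."""
    return [sum(v * v for v in x) for x in batch]

def leaky_model(batch):
    """Buggy model that mixes data across the batch dimension (subtracts the batch mean)."""
    mean = sum(sum(x) for x in batch) / sum(len(x) for x in batch)
    return [sum(v - mean for v in x) for x in batch]

def samples_with_gradient(model, batch, n, eps=1e-4, tol=1e-8):
    """Return the indices of input samples whose values influence the n-th output."""
    base = model(batch)[n]                       # dummy loss: the n-th output only
    touched = set()
    for i, x in enumerate(batch):
        for j in range(len(x)):
            perturbed = [list(s) for s in batch]
            perturbed[i][j] += eps
            grad = (model(perturbed)[n] - base) / eps   # finite-difference gradient
            if abs(grad) > tol:
                touched.add(i)
    return sorted(touched)

batch = [[0.5, -1.0], [2.0, 0.3], [1.5, 1.5]]    # example batch (could be random data)
ok = samples_with_gradient(good_model, batch, n=1)
bad = samples_with_gradient(leaky_model, batch, n=1)
```

For the well-behaved model only sample 1 carries a gradient, while the leaky model reports every sample in the batch, which is the symptom the check is designed to catch.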
  • this is how to extract features from your own data with a ConSSL model
# trained without labels
from ConSSL.models.self_supervised import SimCLR
weight_path = 'path/to/your/checkpoint/file'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
resnet50_unsupervised = simclr.encoder.eval()

# trained with labels
from torchvision.models import resnet50
resnet50_supervised = resnet50(pretrained=True)

# image_sample() is a placeholder that returns a batch of input images
x = image_sample()
unsup_feats = resnet50_unsupervised(x)
sup_feats = resnet50_supervised(x)


Collection of useful datasets including STL10, MNIST, CIFAR10, CIFAR100, ImageNet.

from ConSSL.datamodules import CIFAR10DataModule, ImagenetDataModule, MNISTDataModule, STL10DataModule
# ImagenetDataModule assumes you have already downloaded the ILSVRC2012 ImageNet data. It validates the data using meta.bin.

# datamodules for debugging
from ConSSL.datasets import DummyDataset
from torch.utils.data import DataLoader 

# mnist dims
ds = DummyDataset((1,28,28), (1,))
dl = DataLoader(ds, batch_size=256)
# get first batch
batch = next(iter(dl))
x, y = batch
x.size() # torch.Size([256, 1, 28, 28])
y.size() # torch.Size([256,1])

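# the standard transforms are defined as follows
# (transform_lib is assumed to be an alias for torchvision.transforms)
from torchvision import transforms as transform_lib
mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])

# PATH is the data directory; replace each ... with your own transforms
dm = CIFAR10DataModule(PATH)
dm.train_transforms = ...
dm.test_transforms = ...
dm.val_transforms = ...
model = LiteModel()

Trainer().fit(model, datamodule=dm)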

The dataset is downloaded automatically and placed in the hierarchy below.

ImageNet is too large to download from code, so download it manually and place it accordingly.



Pretraining (Data Modules)

Data Modules (introduced in PyTorch Lightning 0.9.0) decouple the data from a model.

A Data Module is simply a collection of a training dataloader, a validation dataloader, and a test dataloader. It specifies how to

  • download/prepare data
  • create train/val/test splits
  • apply transforms

You can use it like this.

dm = MNISTDataModule('path/to/data')
model = LiteModel()

trainer = Trainer()
trainer.fit(model, datamodule=dm)

You can also use it manually.

dm = MNISTDataModule('/path/to/data')
for batch in dm.train_dataloader():
    ...  # training step
for batch in dm.val_dataloader():
    ...  # validation step
for batch in dm.test_dataloader():
    ...  # test step
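To make the idea concrete, here is a minimal, framework-free sketch of what a data module bundles together. `ToyDataModule` and its in-memory data are hypothetical and do not mirror the ConSSL or PyTorch Lightning classes; only the method names `prepare_data`, `setup`, and the three `*_dataloader` methods follow the Lightning convention.

```python
class ToyDataModule:
    """Bundles data preparation, splitting, transforms, and loaders in one object."""

    def __init__(self, batch_size=4, transform=None):
        self.batch_size = batch_size
        self.transform = transform or (lambda x: x)

    def prepare_data(self):
        # stand-in for downloading: 20 (input, label) pairs held in memory
        self.data = [(float(i), i % 2) for i in range(20)]

    def setup(self):
        # 60/20/20 train/val/test split
        self.train, self.val, self.test = self.data[:12], self.data[12:16], self.data[16:]

    def _loader(self, split):
        # yield transformed mini-batches as (inputs, labels) tuples
        for start in range(0, len(split), self.batch_size):
            chunk = split[start:start + self.batch_size]
            yield [self.transform(x) for x, _ in chunk], [y for _, y in chunk]

    def train_dataloader(self):
        return self._loader(self.train)

    def val_dataloader(self):
        return self._loader(self.val)

    def test_dataloader(self):
        return self._loader(self.test)

dm = ToyDataModule(batch_size=4, transform=lambda x: x / 20.0)
dm.prepare_data()
dm.setup()
train_batches = list(dm.train_dataloader())
```

Because the model only ever sees the three loader methods, swapping datasets means swapping the data module and leaving the model code untouched.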

Contrastive Self-Supervised Learning Models


SimCLR

import pytorch_lightning as pl
from ConSSL.models.self_supervised import SimCLR
from ConSSL.datamodules import CIFAR10DataModule
from ConSSL.models.self_supervised.simclr.transforms import (SimCLREvalDataTransform, SimCLRTrainDataTransform)

# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

# model
model = SimCLR(num_samples=dm.num_samples, batch_size=dm.batch_size, dataset='cifar10')

# fit 
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)

# to finetune
python simclr_finetuner.py --gpus 4 --ckpt_path path/to/simclr/ckpt --dataset cifar10 --batch_size 64 --num_workers 8 --learning_rate 0.3 --num_epochs 100


Implementation | Dataset | Architecture | Optimizer | Batch size | Epochs | Linear Evaluation
Original | CIFAR10 | ResNet50 | LARS | 512 | 500 | 0.72
Mine | CIFAR10 | ResNet50 | LARS-SGD | 512 | 200 | 0.68
Original | ImageNet | ResNet50 | LARS | 256 | 500 | 0.67
Mine | ImageNet | ResNet50 | LARS-SGD | 256 | 200 | 0.64

to reproduce

cd code
# change the configuration setting in config.py
python cli.py pretrain
python cli.py linear_evaluation


MoCo v2

# CLI command
# imagenet
python moco2_module.py --gpus 8 --dataset imagenet2012 --data_dir /path/to/imagenet --meta_dir /path/to/folder/with/meta.bin/ --batch_size 512

ConSSL.models.self_supervised.MocoV2(base_encoder='resnet18', emb_dim=128, num_negatives=65536, encoder_momentum=0.999, softmax_temperature=0.07, learning_rate=0.03, momentum=0.9, weight_decay=0.0001, data_dir='./', batch_size=256, use_mlp=False, num_workers=8, *args, **kwargs)

base_encoder: torchvision model name or torch.nn.Module
emb_dim: feature dimension
num_negatives: queue size (number of negative keys)
encoder_momentum: MoCo momentum for updating the key encoder
use_mlp: add an MLP to the encoder
from pytorch_lightning import Trainer
from ConSSL.models.self_supervised import MocoV2

model = MocoV2()
trainer = Trainer()

Implementation | Dataset | Architecture | LR schedule | Batch size | Epochs | Linear Evaluation
Mine | ImageNet | ResNet50 | Cosine | 512 | 200 | 0.65
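The roles of num_negatives and encoder_momentum can be illustrated with a small sketch. It is framework-free and deliberately simplified: encoder weights are plain lists of floats, the queue holds 4 keys instead of 65536, and the names `momentum_update` and `enqueue` are hypothetical rather than ConSSL functions.

```python
from collections import deque

def momentum_update(query_weights, key_weights, m=0.999):
    """Key encoder slowly follows the query encoder: k <- m * k + (1 - m) * q."""
    return [m * k + (1.0 - m) * q for q, k in zip(query_weights, key_weights)]

def enqueue(queue, new_keys):
    """Add the newest keys; a bounded deque drops the oldest ones automatically."""
    queue.extend(new_keys)

query_weights = [1.0, 2.0]
key_weights = [0.0, 0.0]
queue = deque(maxlen=4)            # num_negatives=4 for illustration

for step in range(3):
    key_weights = momentum_update(query_weights, key_weights, m=0.9)
    # two new keys per step, tagged with the step that produced them
    enqueue(queue, [(step, 0), (step, 1)])
```

After three steps the queue holds only the keys from the two most recent steps, and the key weights have moved part of the way towards the query weights. A larger momentum makes the key encoder change more slowly, which keeps the queued keys consistent with each other.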


BYOL

from pytorch_lightning import Trainer
from ConSSL.callbacks.byol_updates import BYOLMAWeightUpdate

'''The exponential moving average weight update rule from BYOL.
Your model should have self.online_network and self.target_network.'''

# Model is a placeholder for your own module; replace each ... with your networks
model = Model()
model.online_network = ...
model.target_network = ...

trainer = Trainer(callbacks=[BYOLMAWeightUpdate(initial_tau=0.996)])
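For reference, the moving-average rule can be written out directly. The sketch below follows the schedule described in the BYOL paper, where tau is increased from its initial value towards 1 with a cosine schedule. It is an illustration with hypothetical function names and plain lists standing in for network weights; the BYOLMAWeightUpdate callback may differ in detail.

```python
import math

def tau_at(step, max_steps, initial_tau=0.996):
    """Cosine schedule from the BYOL paper: tau goes from initial_tau to 1.0."""
    return 1.0 - (1.0 - initial_tau) * (math.cos(math.pi * step / max_steps) + 1.0) / 2.0

def ema_update(online, target, tau):
    """target <- tau * target + (1 - tau) * online, applied weight by weight."""
    return [tau * t + (1.0 - tau) * o for o, t in zip(online, target)]

online_weights = [1.0, -2.0]   # stand-in for self.online_network parameters
target_weights = [0.0, 0.0]    # stand-in for self.target_network parameters

max_steps = 100
for step in range(max_steps):
    target_weights = ema_update(online_weights, target_weights, tau_at(step, max_steps))
```

With tau this close to 1, the target weights drift only a fraction of the way towards the online weights over the 100 steps, which is the slow-moving target the method relies on.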
# CLI command
python byol_module.py --gpus 8 --dataset imagenet2012 --data_dir /path/to/imagenet/ --meta_dir /path/to/folder/with/meta.bin/ --batch_size 512

ConSSL.models.self_supervised.BYOL(num_classes, learning_rate=0.2, weight_decay=1.5e-6, input_height=32,
batch_size=32, num_workers=0, warmup_epochs=10, max_epochs=1000, **kwargs)
Implementation | Dataset | Architecture | LR schedule | Batch size | Epochs | Linear Evaluation
Original | ImageNet | ResNet50 | Cosine | 4096 | 300 | 0.72
Mine | ImageNet | ResNet50 | Cosine | 512 | 200 | 0.66


SwAV

import pytorch_lightning as pl
from ConSSL.models.self_supervised import SwAV
from ConSSL.datamodules import STL10DataModule
from ConSSL.models.self_supervised.swav.transform import (SwAVTrainDataTransform, SwAVEvalDataTransform)
from ConSSL.transforms.dataset_normalization import stl10_normalization

# data 
batch_size = 128
dm = STL10DataModule(data_dir='.', batch_size=batch_size)
dm.train_dataloader = dm.train_dataloader_mixed
dm.val_dataloader = dm.val_dataloader_mixed

dm.train_transforms = SwAVTrainDataTransform(normalize=stl10_normalization())
dm.val_transforms = SwAVEvalDataTransform(normalize=stl10_normalization())

# model
model = SwAV(gpus=1, num_samples=dm.num_unlabeled_samples, dataset='stl10', batch_size=batch_size)

# fit 
trainer = pl.Trainer(precision=16)


Implementation | Dataset | Architecture | Optimizer | Batch size | Epochs | Linear Evaluation
Mine | STL10 | ResNet50 | LARS-SGD | 128 | 100 | 0.86

to reproduce

cd code
# change the configuration setting in config.py
python cli.py pretrain
python cli.py linear_evaluation


Implementation | Dataset | Architecture | Batch size | Epochs | Linear Evaluation
Original | CIFAR10 | ResNet18 | 512 | 800 | 0.91
Mine | CIFAR10 | ResNet18 | 512 | 300 | 0.82

Linear Evaluation Protocol

import pytorch_lightning as pl
from ConSSL.models.regression import LogisticRegression
from ConSSL.datamodules import ImagenetDataModule

# PATH is the directory that holds the ImageNet data
imagenet = ImagenetDataModule(PATH)

# 224x224x3
pixels_per_image = 150528
model = LogisticRegression(input_dim=pixels_per_image, num_classes=1000)
model.prepare_data = imagenet.prepare_data
trainer = pl.Trainer(gpus=2)
trainer.fit(model, imagenet.train_dataloader(batch_size=256), imagenet.val_dataloader(batch_size=256))
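The protocol itself, a frozen encoder followed by a trainable linear classifier, can be sketched on toy data. Everything below is an illustrative assumption: `frozen_encoder` stands in for a pretrained network, the data are four hand-picked points, and the classifier is a binary logistic regression trained with plain gradient descent rather than the LogisticRegression model from ConSSL.

```python
import math

def frozen_encoder(x):
    """Stand-in for a pretrained encoder; its 'weights' are never updated."""
    return [x[0] + x[1], x[0] - x[1]]

def predict(w, b, feats):
    """Linear classifier on top of the features, squashed to a probability."""
    z = sum(wi * fi for wi, fi in zip(w, feats)) + b
    return 1.0 / (1.0 + math.exp(-z))

data = [([2.0, 1.0], 1), ([1.5, 2.0], 1), ([-1.0, -2.0], 0), ([-2.0, -0.5], 0)]
# features are computed once because the encoder is frozen
features = [(frozen_encoder(x), y) for x, y in data]

w, b, lr = [0.0, 0.0], 0.0, 0.1
for epoch in range(200):
    for feats, y in features:
        error = predict(w, b, feats) - y          # gradient of the log loss w.r.t. z
        w = [wi - lr * error * fi for wi, fi in zip(w, feats)]
        b -= lr * error

accuracy = sum((predict(w, b, f) > 0.5) == bool(y) for f, y in features) / len(features)
```

Only `w` and `b` are updated, so the final accuracy measures how linearly separable the frozen features are, which is what the "Linear Evaluation" columns in the tables report for the real models.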

Semi-Supervised Learning

Use the ImageNet subset from https://github.com/tensorflow/datasets/tree/master/tensorflow_datasets/image_classification

  1. Unfrozen Finetuning
from ConSSL.models.self_supervised import SimCLR
from ConSSL.models.regression import LogisticRegression

weight_path = 'checkpoint/path'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
resnet50 = simclr.encoder
# don't call simclr.freeze()

classifier = LogisticRegression()
for (x, y) in own_data:
    feats = resnet50(x)
    y_hat = classifier(feats)
  2. Freeze then Unfreeze
from torch.nn.functional import cross_entropy
from ConSSL.models.self_supervised import SimCLR
from ConSSL.models.regression import LogisticRegression

weight_path = 'checkpoint/path'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
resnet50 = simclr.encoder
# freeze the encoder first (assumes SimCLR is a LightningModule)
simclr.freeze()

classifier = LogisticRegression()
for epoch in epochs:
    for (x, y) in own_data:
        feats = resnet50(x)
        y_hat = classifier(feats)
        # torch's cross_entropy stands in for the original cross_entropy_with_logits
        loss = cross_entropy(y_hat, y)

    # unfreeze after 10 epochs
    if epoch == 10:
        simclr.unfreeze()

Transfer Learning

import pytorch_lightning as pl
from ConSSL.models.regression import LogisticRegression
from ConSSL.datamodules import MNISTDataModule

# tmpdir is a placeholder for the directory where MNIST is stored
dm = MNISTDataModule(num_workers=0, data_dir=tmpdir)

model = LogisticRegression(input_dim=28*28, num_classes=10, learning_rate=0.001)
model.prepare_data = dm.prepare_data
model.train_dataloader = dm.train_dataloader
model.val_dataloader = dm.val_dataloader
model.test_dataloader = dm.test_dataloader

trainer = pl.Trainer(max_epochs=200)


  • I use the latest version of Python 3; Python 2 is not supported.
  • I use the latest version of PyTorch, though tensorflow-gpu is necessary to launch TensorBoard.


git clone --recurse-submodules (this repo)
cd $REPO_NAME/code
# use python >= 3.5
pip install -r requirements.txt

When using docker

build & push & run

sudo ./setup-docker.sh

directory structure


Data Folder Structure

 cli.py: executable script for checking data loading, training, and evaluating
 config.py: default configs
 ckpt.py: checkpoint saving & loading
 train.py: training python configuration file
 evaluate.py: evaluating python configuration file
 infer.py: make submission from checkpoint
 logger.py: tensorboard and commandline logger for scalars
 utils.py: other helper modules
 dataloader/: module that provides data loaders and various transformers
  load_dataset.py: dataloader for classification
  vision.py: image loading helper
 metric/: accuracy and loss logging


utils.prepare_batch: move to GPU and build target
ckpt.get_model_ckpt: load ckpt and substitute model weight and args
load_dataset.get_iterator: load data iterator {'train': , 'val': , 'test': }

How To Use

First check data loading

cd code
python3 cli.py check_dataloader


cd code
python3 cli.py train


cd code
python3 cli.py evaluate --ckpt_name=$CKPT_NAME
  • Substitute CKPT_NAME with your preferred checkpoint file, e.g., ckpt_name=model_name_simclr_ckpt_3/loss_0.4818_epoch_15
from ConSSL.callbacks.ssl_online import SSLOnlineEvaluator

'''Attaches an MLP for fine-tuning using the standard self-supervised protocol.'''
# Model is a placeholder for your own module; replace each ... with your values
model = Model()
model.z_dim = ...  # the representation dim
model.num_classes = ...  # the number of classes in the model

online_eval = SSLOnlineEvaluator(z_dim=model.z_dim, num_classes=model.num_classes, dataset='imagenet')
# if the dataset is stl10, you need to get the labeled batch


A callback is a self-contained program that can be intertwined into a training pipeline.

from ConSSL.callbacks import Callback

class MyCallback(Callback):
    def on_epoch_end(self, trainer, pl_module):
        # do something
        pass
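How a trainer weaves a callback into training can be sketched without any framework. `ToyCallback`, `LossHistory`, and `ToyTrainer` are hypothetical stand-ins; only the hook signature `on_epoch_end(self, trainer, pl_module)` is taken from the example above.

```python
class ToyCallback:
    """Base class: hooks do nothing unless a subclass overrides them."""
    def on_epoch_end(self, trainer, pl_module):
        pass

class LossHistory(ToyCallback):
    """Records the loss the trainer reports at the end of every epoch."""
    def __init__(self):
        self.history = []

    def on_epoch_end(self, trainer, pl_module):
        self.history.append((trainer.current_epoch, trainer.last_loss))

class ToyTrainer:
    """Runs a fake training loop and calls every callback hook after each epoch."""
    def __init__(self, max_epochs, callbacks=()):
        self.max_epochs = max_epochs
        self.callbacks = list(callbacks)
        self.current_epoch = 0
        self.last_loss = None

    def fit(self, model):
        for epoch in range(self.max_epochs):
            self.current_epoch = epoch
            self.last_loss = model(epoch)          # stand-in for one epoch of training
            for callback in self.callbacks:
                callback.on_epoch_end(self, model)

recorder = LossHistory()
ToyTrainer(max_epochs=3, callbacks=[recorder]).fit(lambda epoch: 1.0 / (epoch + 1))
```

The training loop never needs to know what a callback does, so logging, verification, or weight updates can be added or removed by editing the `callbacks` list alone.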

The data monitoring callbacks allow you to log and inspect the distribution of data that passes through the training step and layers of the model.

from ConSSL.callbacks import TrainingDataMonitor
from pytorch_lightning import Trainer

monitor = TrainingDataMonitor(log_every_n_steps=25)

model = YourLightningModule()
trainer = Trainer(callbacks=[monitor])


