PyTorch-Lightning Implementation of Self-Supervised Learning Methods

This is a PyTorch Lightning implementation of several self-supervised representation learning methods, including MoCo v2, VICReg, BYOL, and SimCLR, each of which is covered in its own section below.

Supported datasets: ImageNet, STL-10, and CIFAR-10.

During training, the top1/top5 accuracies (out of 1+K examples) are reported where possible. During validation, an sklearn linear classifier is trained on half the test set and validated on the other half. The top1 accuracy is logged as train_class_acc / valid_class_acc.
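
As a rough illustration of what "top1/top5 out of 1+K examples" means, the sketch below treats each query as having one positive score followed by K negative scores and counts a hit when the positive ranks within the top k. The function name topk_accuracy and the convention that the positive sits at index 0 are assumptions made for this example; they are not taken from the repository.

```python
from typing import Sequence


def topk_accuracy(logits: Sequence[Sequence[float]], k: int) -> float:
    """Fraction of rows whose positive score (index 0) ranks within the top k.

    Each row holds 1 + K scores: the positive first, then K negatives.
    Placing the positive at index 0 is an assumption made for this sketch.
    """
    hits = 0
    for row in logits:
        positive = row[0]
        # Rank = number of negatives that score strictly higher than the positive.
        rank = sum(1 for score in row[1:] if score > positive)
        if rank < k:
            hits += 1
    return hits / len(logits)


# Three queries, each with 1 positive and K = 5 negatives.
batch = [
    [0.9, 0.1, 0.2, 0.3, 0.0, 0.4],  # positive is the best score
    [0.2, 0.5, 0.1, 0.0, 0.3, 0.1],  # two negatives beat the positive
    [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],  # every negative beats the positive
]
top1 = topk_accuracy(batch, k=1)
top5 = topk_accuracy(batch, k=5)
```

With a MoCo-style queue, K is the queue size, so chance level for top1 is 1 / (1 + K).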

Installing

Make sure you're in a fresh conda or venv environment, then run:

git clone https://github.com/untitled-ai/self_supervised
cd self_supervised
pip install -r requirements.txt

Replicating our BYOL blog post

We found some surprising results about the role of batch norm in BYOL. See the blog post Understanding self-supervised and contrastive learning with "Bootstrap Your Own Latent" (BYOL) for more details about our experiments.

You can replicate the results of our blog post by running python train_blog.py. The cosine similarity between z and z' is reported as step_neg_cos (for negative examples) and step_pos_cos (for positive examples). Classification accuracy is reported as valid_class_acc.
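
For readers unfamiliar with the metric, here is a minimal sketch of the cosine similarity that step_pos_cos and step_neg_cos are described as reporting. The vectors and variable names are made up for illustration, and how the repository averages these values over a batch is not shown here.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors: dot(a, b) / (|a| * |b|)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Hypothetical projections: z for the query, z' for a positive and a negative key.
z = [1.0, 0.0]
z_prime_positive = [2.0, 0.0]  # same direction as z, different length
z_prime_negative = [0.0, 3.0]  # orthogonal to z

pos_cos = cosine_similarity(z, z_prime_positive)
neg_cos = cosine_similarity(z, z_prime_negative)
```

Because the measure ignores vector length, a well-trained model should push the positive value toward 1 while the negative value stays near 0.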

Getting started with MoCo v2

To get started with training a ResNet-18 with MoCo v2 on STL-10 (the default configuration):

import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import ModelParams

os.environ["DATA_PATH"] = "~/data"

params = ModelParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")

For convenience, you can instead pass these parameters as keyword args, for example with model = SelfSupervisedMethod(batch_size=128).
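
The snippet above sets DATA_PATH to "~/data". Python does not expand ~ in path strings on its own, and this README does not say whether the library expands it, so a cautious sketch is to resolve the path yourself before exporting it. resolve_data_path is a hypothetical helper, not part of this repository.

```python
import os


def resolve_data_path(raw_path: str) -> str:
    """Expand a leading ~ and return an absolute path."""
    return os.path.abspath(os.path.expanduser(raw_path))


# Export an absolute path so downstream code does not depend on ~ expansion.
os.environ["DATA_PATH"] = resolve_data_path("~/data")
```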

VICReg

To train VICReg rather than MoCo v2, use the following parameters:

import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import VICRegParams

os.environ["DATA_PATH"] = "~/data"

params = VICRegParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")

Note that we have not tuned these parameters for STL-10, and the parameters used for ImageNet are slightly different. See the comment on VICRegParams for details.

BYOL

To train BYOL rather than MoCo v2, use the following parameters:

import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import BYOLParams

os.environ["DATA_PATH"] = "~/data"

params = BYOLParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")

SimCLR

To train SimCLR rather than MoCo v2, use the following parameters:

import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import SimCLRParams

os.environ["DATA_PATH"] = "~/data"

params = SimCLRParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")

Note for multi-GPU setups: this currently only uses negatives on the same GPU, and will not sync negatives across multiple GPUs.
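
To get a feel for what this note implies, the sketch below counts negatives using the SimCLR paper's convention that, for a batch of N images, each augmented view has 2(N - 1) negatives. The function name, the per-GPU batch of 32, and the 8 GPUs are illustrative assumptions; the exact count in this repository may differ.

```python
def simclr_negatives_per_view(images_per_batch: int) -> int:
    """Negatives per augmented view under the SimCLR paper's counting: 2 * (N - 1)."""
    return 2 * (images_per_batch - 1)


per_gpu_batch = 32  # hypothetical per-GPU batch size
num_gpus = 8        # hypothetical GPU count

# Negatives seen when each GPU only uses its own batch (the behavior noted above).
local_negatives = simclr_negatives_per_view(per_gpu_batch)

# Negatives that would be available if they were synced across all GPUs.
synced_negatives = simclr_negatives_per_view(per_gpu_batch * num_gpus)
```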

Evaluating a trained model

To train a linear classifier on the result:

import pytorch_lightning as pl
from linear_classifier import LinearClassifierMethod

# Build the linear classifier from the checkpoint saved after self-supervised training
linear_model = LinearClassifierMethod.from_moco_checkpoint("example.ckpt")
trainer = pl.Trainer(gpus=1, max_epochs=100)
trainer.fit(linear_model)

Results on STL-10 and ImageNet

Training a ResNet-18 for 320 epochs on STL-10 achieved 85% linear classification accuracy on the test set (1 fold of 5000). This used all default parameters.

Training a ResNet-50 for 200 epochs on ImageNet achieved 65.6% linear classification accuracy on the test set. This used 8 GPUs with DDP and the following parameters:

hparams = ModelParams(
   encoder_arch="resnet50",
   shuffle_batch_norm=True,
   embedding_dim=2048,
   mlp_hidden_dim=2048,
   dataset_name="imagenet",
   batch_size=32,
   lr=0.03,
   max_epochs=200,
   transform_crop_size=224,
   num_data_workers=32,
   gather_keys_for_queue=True,
)

The batch_size differs from the MoCo documentation because of the way PyTorch Lightning handles multi-GPU training in DDP: batch_size applies to each process, so with 8 GPUs the effective batch size is 8 × 32 = 256. For ImageNet, we suggest passing val_percent_check=0.1 to pl.Trainer to reduce the time spent fitting the sklearn model.
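
The arithmetic behind the effective batch size can be sketched as follows, assuming one DDP process per GPU with each process drawing its own batch_size examples per step. effective_batch_size is a hypothetical helper for illustration only.

```python
def effective_batch_size(per_process_batch_size: int, num_processes: int) -> int:
    """Total examples per optimizer step when every DDP process loads its own batch."""
    return per_process_batch_size * num_processes


# ImageNet configuration above: batch_size=32 on 8 GPUs.
effective = effective_batch_size(per_process_batch_size=32, num_processes=8)
```

When changing the number of GPUs, dividing the target effective batch size by the GPU count gives the batch_size to pass to ModelParams.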

All training options

All possible hparams for SelfSupervisedMethod, along with defaults:

class ModelParams:
    # encoder model selection
    encoder_arch: str = "resnet18"
    shuffle_batch_norm: bool = False
    embedding_dim: int = 512  # must match embedding dim of encoder

    # data-related parameters
    dataset_name: str = "stl10"
    batch_size: int = 256

    # MoCo parameters
    K: int = 65536  # number of examples in queue
    dim: int = 128
    m: float = 0.996
    T: float = 0.2

    # eqco parameters
    eqco_alpha: int = 65536
    use_eqco_margin: bool = False
    use_negative_examples_from_batch: bool = False

    # optimization parameters
    lr: float = 0.5
    momentum: float = 0.9
    weight_decay: float = 1e-4
    max_epochs: int = 320
    final_lr_schedule_value: float = 0.0

    # transform parameters
    transform_s: float = 0.5
    transform_apply_blur: bool = True

    # Change these to make more like BYOL
    use_momentum_schedule: bool = False
    loss_type: str = "ce"
    use_negative_examples_from_queue: bool = True
    use_both_augmentations_as_queries: bool = False
    optimizer_name: str = "sgd"
    lars_warmup_epochs: int = 1
    lars_eta: float = 1e-3
    exclude_matching_parameters_from_lars: List[str] = []  # set to [".bias", ".bn"] to match paper
    loss_constant_factor: float = 1

    # Change these to make more like VICReg
    use_vicreg_loss: bool = False
    use_lagging_model: bool = True
    use_unit_sphere_projection: bool = True
    invariance_loss_weight: float = 25.0
    variance_loss_weight: float = 25.0
    covariance_loss_weight: float = 1.0
    variance_loss_epsilon: float = 1e-04

    # MLP parameters
    projection_mlp_layers: int = 2
    prediction_mlp_layers: int = 0
    mlp_hidden_dim: int = 512

    mlp_normalization: Optional[str] = None
    prediction_mlp_normalization: Optional[str] = "same"  # if same will use mlp_normalization
    use_mlp_weight_standardization: bool = False

    # data loader parameters
    num_data_workers: int = 4
    drop_last_batch: bool = True
    pin_data_memory: bool = True
    gather_keys_for_queue: bool = False

A few options require more explanation:

  • encoder_arch can be any torchvision model, or can be one of the ResNet models with weight standardization defined in ws_resnet.py.

  • dataset_name can be imagenet, stl10, or cifar10. os.environ["DATA_PATH"] will be used as the path to the data. STL-10 and CIFAR-10 will be downloaded if they do not already exist.

  • loss_type can be ce (cross entropy), combined with one of the use_negative_examples_from_batch / use_negative_examples_from_queue options, to correspond to MoCo; or ip (inner product), with both of those options set to False, to correspond to BYOL. It can also be bce, which is similar to ip but applies the binary cross entropy loss function to the result, or vic for the VICReg loss.

  • optimizer_name can currently be sgd or lars.

  • exclude_matching_parameters_from_lars excludes matching parameters from weight decay and from the LARS learning rate adaptation. Set it to [".bias", ".bn"] to match the BYOL paper's implementation.

  • mlp_normalization can be None for no normalization, bn for batch normalization, ln for layer norm, gn for group norm, or br for batch renormalization.

  • prediction_mlp_normalization defaults to same to use the same normalization as above, but can be given any of the above parameters to use a different normalization.

  • shuffle_batch_norm and gather_keys_for_queue both relate to multi-GPU training. shuffle_batch_norm shuffles the key images among GPUs, which is needed for training if batch norm is used. gather_keys_for_queue gathers key projections (z' in the blog post) from all GPUs to add to the MoCo queue.
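
As a rough sketch of the difference between the ce and ip settings of loss_type described in the list above, the code below computes a cross entropy over one positive and K negative similarities, and a plain negative inner product that uses no negatives. Both functions are simplified, hypothetical stand-ins written for illustration; the repository's actual loss code, scaling, and normalization may differ. The temperature default of 0.2 mirrors the T default listed in ModelParams.

```python
import math
from typing import Sequence


def ce_contrastive_loss(
    pos_sim: float, neg_sims: Sequence[float], temperature: float = 0.2
) -> float:
    """Cross entropy over 1 + K similarities with the positive as the target class."""
    logits = [pos_sim / temperature] + [s / temperature for s in neg_sims]
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[0]


def ip_loss(pos_sim: float) -> float:
    """Negative inner product of the positive pair; no negatives are involved."""
    return -pos_sim


# A positive pair that is more similar than the negatives yields a small ce loss.
easy = ce_contrastive_loss(pos_sim=0.9, neg_sims=[0.1, 0.2])
# A positive pair that is less similar than a negative yields a large ce loss.
hard = ce_contrastive_loss(pos_sim=0.1, neg_sims=[0.9, 0.2])
```

The ce form needs negatives to be meaningful (with none, the loss is always zero), which is why it is paired with one of the use_negative_examples options, while the ip form depends only on the positive pair.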

Training with custom options

You can train using any settings of the above parameters. This configuration represents the settings from BYOL:

from model_params import ModelParams

hparams = ModelParams(
   prediction_mlp_layers=2,
   mlp_normalization="bn",
   loss_type="ip",
   use_negative_examples_from_queue=False,
   use_both_augmentations_as_queries=True,
   use_momentum_schedule=True,
   optimizer_name="lars",
   exclude_matching_parameters_from_lars=[".bias", ".bn"],
   loss_constant_factor=2,
)
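
To illustrate what exclude_matching_parameters_from_lars=[".bias", ".bn"] selects, the sketch below splits parameter names into two groups by substring match. Substring matching, the helper name split_parameters_by_name, and the example parameter names are all assumptions made for illustration; the repository's matching rule may differ.

```python
from typing import Dict, List, Sequence


def split_parameters_by_name(
    parameter_names: Sequence[str], exclude_patterns: Sequence[str]
) -> Dict[str, List[str]]:
    """Group names into 'excluded' (any pattern is a substring) and 'regular'."""
    groups: Dict[str, List[str]] = {"regular": [], "excluded": []}
    for name in parameter_names:
        matched = any(pattern in name for pattern in exclude_patterns)
        groups["excluded" if matched else "regular"].append(name)
    return groups


# Hypothetical parameter names in the style of model.named_parameters().
names = [
    "encoder.conv1.weight",
    "encoder.bn1.weight",
    "encoder.bn1.bias",
    "projection.fc.weight",
    "projection.fc.bias",
]
groups = split_parameters_by_name(names, [".bias", ".bn"])
```

Under this rule, the batch norm parameters and all biases land in the excluded group, while the convolution and linear weights keep weight decay and the LARS adaptation.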

Here is our recommended way to modify VICReg for CIFAR-10:

from model_params import VICRegParams

hparams = VICRegParams(
   dataset_name="cifar10",
   transform_apply_blur=False,
   mlp_hidden_dim=2048,
   dim=2048,
   batch_size=256,
   lr=0.3,
   final_lr_schedule_value=0,
   weight_decay=1e-4,
   lars_warmup_epochs=10,
   lars_eta=0.02
)