/SimCLR-newmask

Implementation of SimCLR paper"A Simple Framework for Contrastive Learning of Visual Representations" plus a new masking approach.

Primary LanguagePython

SimCLR

PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations by T. Chen et al. Including support for:

  • Distributed data parallel training
  • Global batch normalization
  • LARS (Layer-wise Adaptive Rate Scaling) optimizer.

Link to paper

One additional feature in this implementation is another mask function wich also eliminates hardest unequal example from los function.

Quickstart (fine-tune linear classifier)

git clone https://github.com/spijkervet/SimCLR.git && cd SimCLR
wget https://github.com/Spijkervet/SimCLR/releases/download/1.2/checkpoint_100.tar
sh setup.sh || python3 -m pip install -r requirements.txt || exit 1
conda activate simclr
python linear_evaluation.py --dataset=STL10 --model_path=. --epoch_num=100 --resnet resnet50

Training ResNet encoder:

Simply run the following to pre-train a ResNet encoder using SimCLR on the CIFAR-10 dataset:

python main.py --dataset CIFAR10

Distributed Training

With distributed data parallel (DDP) training:

CUDA_VISIBLE_DEVICES=0 python main.py --nodes 2 --nr 0
CUDA_VISIBLE_DEVICES=1 python main.py --nodes 2 --nr 1
CUDA_VISIBLE_DEVICES=2 python main.py --nodes 2 --nr 2
CUDA_VISIBLE_DEVICES=N python main.py --nodes 2 --nr 3

python linear_evaluation.py --model_path=. --epoch_num=100

LARS optimizer

The LARS optimizer is implemented in modules/lars.py. It can be activated by adjusting the config/config.yaml optimizer setting to: optimizer: "LARS". It is still experimental and has not been thoroughly tested.

For distributed training (DDP), use for every process in nodes, in which N is the GPU number you would like to dedicate the process to:

CUDA_VISIBLE_DEVICES=0 python main.py --nodes 2 --nr 0
CUDA_VISIBLE_DEVICES=1 python main.py --nodes 2 --nr 1
CUDA_VISIBLE_DEVICES=2 python main.py --nodes 2 --nr 2
CUDA_VISIBLE_DEVICES=N python main.py --nodes 2 --nr 3

--nr corresponds to the process number of the N nodes we make available for training.

Testing

To test a trained model, make sure to set the model_path variable in the config/config.yaml to the log ID of the training (e.g. logs/0). Set the epoch_num to the epoch number you want to load the checkpoints from (e.g. 40).

python linear_evaluation.py

or in place:

python linear_evaluation.py --model_path=./save --epoch_num=40

Configuration

The configuration of training can be found in: config/config.yaml. An example config.yaml file:

# train options
batch_size: 256
workers: 16
start_epoch: 0
epochs: 40
dataset_dir: "./datasets"

# model options
resnet: "resnet18"
normalize: True
projection_dim: 64

# loss options
temperature: 0.5

# reload options
model_path: "logs/0" # set to the directory containing `checkpoint_##.tar` 
epoch_num: 40 # set to checkpoint number

# logistic regression options
logistic_batch_size: 256
logistic_epochs: 100

Dependencies

torch
torchvision
tensorboard
pyyaml