PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations by T. Chen et al., including support for:
- Distributed data parallel training
- Global batch normalization
- LARS (Layer-wise Adaptive Rate Scaling) optimizer
This implementation also provides an alternative mask function that additionally removes the hardest negative example (the most similar non-matching pair) from the loss function.
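The alternative mask can be illustrated with the NT-Xent loss that SimCLR uses. The sketch below is framework-free Python, not this repository's PyTorch code. It assumes that the "hardest unequal example" is, for each anchor, the negative with the highest cosine similarity, and that `z[k]` and `z[k + N]` are the two augmented views of sample `k`. The names `nt_xent` and `cosine` are hypothetical.

```python
import math


def cosine(a, b):
    """Cosine similarity of two non-zero vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def nt_xent(z, temperature=0.5, drop_hardest_negative=False):
    """NT-Xent loss over 2N embeddings, where z[k] and z[k + N] are two views of sample k.

    If drop_hardest_negative is True, the negative most similar to each anchor is
    removed from that anchor's denominator (an assumed reading of the
    "hardest unequal example" mask). Dropping requires N >= 2.
    """
    two_n = len(z)
    n = two_n // 2
    total = 0.0
    for i in range(two_n):
        pos = (i + n) % two_n  # index of the other view of the same sample
        # Temperature-scaled similarities to every other embedding (self excluded).
        sims = {k: cosine(z[i], z[k]) / temperature for k in range(two_n) if k != i}
        negatives = [k for k in sims if k != pos]
        if drop_hardest_negative:
            hardest = max(negatives, key=lambda k: sims[k])
            negatives.remove(hardest)
        denominator = math.exp(sims[pos]) + sum(math.exp(sims[k]) for k in negatives)
        # Per-anchor term: -log(exp(s_pos) / denominator).
        total += math.log(denominator) - sims[pos]
    return total / two_n  # mean over all 2N anchors
```

Removing the hardest negative shrinks every denominator, so for the same batch the masked loss is always smaller than plain NT-Xent. A tensor version would vectorize the similarity matrix and use log-sum-exp for numerical stability.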
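Quickstart: download a pre-trained checkpoint and train a linear classifier on top of it on STL-10:
git clone https://github.com/spijkervet/SimCLR.git && cd SimCLR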
wget https://github.com/Spijkervet/SimCLR/releases/download/1.2/checkpoint_100.tar
sh setup.sh || python3 -m pip install -r requirements.txt || exit 1
conda activate simclr
python linear_evaluation.py --dataset=STL10 --model_path=. --epoch_num=100 --resnet resnet50
Simply run the following to pre-train a ResNet encoder using SimCLR on the CIFAR-10 dataset:
python main.py --dataset CIFAR10
To evaluate the pre-trained encoder with a linear classifier:
python linear_evaluation.py --model_path=. --epoch_num=100
The LARS optimizer is implemented in modules/lars.py. It can be activated by setting the optimizer option in config/config.yaml to: optimizer: "LARS". It is still experimental and has not been thoroughly tested.
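For intuition, the core of LARS is a per-layer trust ratio that rescales the global learning rate by the ratio of the weight norm to the gradient norm. The following is a minimal single-layer sketch in plain Python, following the formulation of You et al. (2017); it is not the code in modules/lars.py, which may differ in details. The function name `lars_step` and the default hyperparameters are assumptions.

```python
import math


def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))


def lars_step(weights, grads, velocity, lr=0.1, momentum=0.9,
              weight_decay=1e-6, trust_coef=0.001):
    """One LARS update for a single layer (hypothetical helper).

    weights, grads and velocity are equal-length lists of floats.
    Returns (new_weights, new_velocity).
    """
    w_norm = l2_norm(weights)
    g_norm = l2_norm(grads)
    # Layer-wise trust ratio: the local learning rate grows with the weight norm
    # and shrinks with the (weight-decay-augmented) gradient norm.
    if w_norm > 0 and g_norm > 0:
        local_lr = trust_coef * w_norm / (g_norm + weight_decay * w_norm)
    else:
        local_lr = 1.0  # fall back to the plain global learning rate
    new_weights, new_velocity = [], []
    for w, g, v in zip(weights, grads, velocity):
        # Momentum SGD step scaled by both the global and the layer-wise rate.
        v = momentum * v + lr * local_lr * (g + weight_decay * w)
        new_velocity.append(v)
        new_weights.append(w - v)
    return new_weights, new_velocity
```

With weight_decay=0 the step depends on the weight norm and the gradient direction but not the gradient magnitude: multiplying the gradient by a constant leaves the update unchanged, which is what keeps very large batch sizes stable.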
For distributed data parallel (DDP) training, launch one process per node and set CUDA_VISIBLE_DEVICES to the GPU you want to dedicate to that process (N is a placeholder for a GPU number):
CUDA_VISIBLE_DEVICES=0 python main.py --nodes 4 --nr 0
CUDA_VISIBLE_DEVICES=1 python main.py --nodes 4 --nr 1
CUDA_VISIBLE_DEVICES=2 python main.py --nodes 4 --nr 2
CUDA_VISIBLE_DEVICES=N python main.py --nodes 4 --nr 3
--nr is the index of the process (0, 1, 2, ...) among the --nodes processes made available for training.
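The launch commands can also be generated programmatically. A hypothetical helper, assuming one process per GPU and --nodes set to the total number of processes, so that --nr runs from 0 to --nodes minus 1:

```python
def launch_commands(gpu_ids, script="main.py"):
    """Return one shell command per process (hypothetical helper).

    Assumes one process per GPU and that --nodes equals the total number of
    processes, so --nr runs from 0 to len(gpu_ids) - 1.
    """
    world_size = len(gpu_ids)
    return [
        f"CUDA_VISIBLE_DEVICES={gpu} python {script} --nodes {world_size} --nr {rank}"
        for rank, gpu in enumerate(gpu_ids)
    ]
```

For example, launch_commands([0, 1, 2, 3]) returns four commands with --nr 0 through 3, each pinned to its own GPU, which avoids mismatches between the number of launched processes and the --nodes value.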
To test a trained model, set the model_path variable in config/config.yaml to the log directory of the training run (e.g. logs/0). Set epoch_num to the epoch number of the checkpoint you want to load (e.g. 40).
Then run:
python linear_evaluation.py
or pass the values directly on the command line:
python linear_evaluation.py --model_path=./save --epoch_num=40
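Checkpoints follow the checkpoint_##.tar naming pattern (for example checkpoint_100.tar in the quickstart), so model_path and epoch_num together identify one file. A minimal sketch for resolving it, assuming the file sits directly inside model_path; `checkpoint_path` and `resolve_checkpoint` are hypothetical names.

```python
import os


def checkpoint_path(model_path, epoch_num):
    """Path of the checkpoint saved at a given epoch, e.g. logs/0/checkpoint_40.tar."""
    return os.path.join(model_path, f"checkpoint_{epoch_num}.tar")


def resolve_checkpoint(model_path, epoch_num):
    """Return the checkpoint path, failing early with a readable error if it is missing."""
    path = checkpoint_path(model_path, epoch_num)
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"no checkpoint for epoch {epoch_num} in {model_path!r} (expected {path})"
        )
    # The returned path could then be passed to torch.load(...) to restore the weights.
    return path
```

Checking for the file up front turns a mismatched model_path or epoch_num into a clear error message before any model is built.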
The training configuration is stored in config/config.yaml. An example config.yaml file:
# train options
batch_size: 256
workers: 16
start_epoch: 0
epochs: 40
dataset_dir: "./datasets"
# model options
resnet: "resnet18"
normalize: True
projection_dim: 64
# loss options
temperature: 0.5
# reload options
model_path: "logs/0" # set to the directory containing `checkpoint_##.tar`
epoch_num: 40 # set to checkpoint number
# logistic regression options
logistic_batch_size: 256
logistic_epochs: 100
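The command-line form of linear_evaluation.py suggests that flags override the values in config/config.yaml. One way to implement such an override is to turn every config key into a --key flag whose default is the config value. The sketch below assumes that pattern; DEFAULTS is a hand-copied subset of the example file (a real script would load it with PyYAML's yaml.safe_load), and `build_parser` and `str2bool` are hypothetical names.

```python
import argparse

# Hand-copied subset of the example config.yaml (assumption: a real script
# would read these values from config/config.yaml instead).
DEFAULTS = {
    "batch_size": 256,
    "epochs": 40,
    "resnet": "resnet18",
    "normalize": True,
    "projection_dim": 64,
    "temperature": 0.5,
    "model_path": "logs/0",
    "epoch_num": 40,
}


def str2bool(value):
    """Parse booleans explicitly; argparse's type=bool treats "False" as True."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser(defaults):
    parser = argparse.ArgumentParser()
    for key, value in defaults.items():
        # Each config key becomes a --key flag typed after its default value.
        kind = str2bool if isinstance(value, bool) else type(value)
        parser.add_argument(f"--{key}", default=value, type=kind)
    return parser
```

For example, build_parser(DEFAULTS).parse_args(["--model_path=./save", "--epoch_num=40"]) keeps every other value from the config while overriding those two.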
Required Python packages:
torch
torchvision
tensorboard
pyyaml