This is an independent reimplementation of the Supervised Contrastive Learning paper.
Go here for a PyTorch implementation by one of the authors,
and here for the official TensorFlow implementation.
The goal of this repository is to provide a straight-to-the-point implementation and experiments that answer
specific questions.
Architecture | Cross-entropy | Cross-entropy + Auto-augment | SupContrast + Auto-augment |
---|---|---|---|
ResNet20 | 91.25 % * | 92.71 % | 93.51 % |
Architecture | Cross-entropy | Cross-entropy + Auto-augment | SupContrast + Auto-augment |
---|---|---|---|
ResNet20 | 66.05 % | 66.28 % | 68.42 % |
After creating a virtual environment with the tool of your choice, just run:
```bash
pip install -r requirements.txt
```
Running the following command lists the available script options. The default values should let you replicate my results:
```bash
python train.py -h
```
Does a contrastive epoch take 50 % more time than a cross-entropy one?
Yes. This seems in line with both my implementation and the official one.
Is heavy data augmentation necessary?
It seems so. A run without hyperparameter tuning and without AutoAugment, but with the same data augmentation as the original
ResNet paper, gave a 5 % drop in accuracy compared to cross-entropy. Although the paper reports other data augmentation policies close
behind AutoAugment, contrastive approaches do seem to need sophisticated data augmentation strategies (see the original SimCLR paper).
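In the paper, each image is turned into two independently augmented views before the contrastive loss is applied, so whatever augmentation policy is used runs twice per sample. A minimal sketch of such a wrapper, assuming a generic augmentation callable; `TwoViews` and `random_flip` are hypothetical names, and the toy flip only stands in for a real image pipeline:

```python
import random
from typing import Callable, List, Tuple


class TwoViews:
    """Hypothetical wrapper: apply the same random augmentation twice to one sample."""

    def __init__(self, augment: Callable) -> None:
        self.augment = augment

    def __call__(self, sample) -> Tuple:
        # Each call to `augment` draws fresh randomness, so the two views usually differ.
        return self.augment(sample), self.augment(sample)


def random_flip(pixels: List[int]) -> List[int]:
    """Toy stand-in for an image augmentation: random horizontal flip of a row of pixels."""
    return pixels[::-1] if random.random() < 0.5 else list(pixels)


make_views = TwoViews(random_flip)
view_a, view_b = make_views([1, 2, 3])
```

With torchvision, `augment` could for instance be a `transforms.Compose([...])` that includes `transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10)`; whether this repository builds its pipeline exactly that way is an assumption.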
Do you need only a few epochs to train the decoder on the embedding?
Yes, definitely. Only 1-2 epochs of cross-entropy on the embedding gave a model close to the best accuracy. Better
configurations were found after tens of epochs, but they were usually only better by around 0.1 % accuracy.
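Assuming the decoder is a single linear layer, as in the paper's linear evaluation, training it amounts to fitting a softmax classifier with cross-entropy while the encoder stays frozen. The sketch below illustrates this with plain-Python SGD on toy 2-D "embeddings"; the helper names, learning rate and synthetic data are assumptions for illustration, and the repository's own decoder training may use a different optimizer and schedule:

```python
import math
import random
from typing import List, Sequence, Tuple

Vector = List[float]


def softmax(logits: Sequence[float]) -> Vector:
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_for(W: List[Vector], b: Vector, x: Sequence[float]) -> Vector:
    return [sum(w * v for w, v in zip(row, x)) + bias for row, bias in zip(W, b)]


def train_linear_probe(features, labels, num_classes, epochs=2, lr=0.5, seed=0):
    """Fit a linear softmax classifier on frozen features with plain SGD on cross-entropy."""
    rng = random.Random(seed)
    dim = len(features[0])
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    order = list(range(len(features)))
    for _ in range(epochs):
        rng.shuffle(order)
        for i in order:
            x, y = features[i], labels[i]
            probs = softmax(logits_for(W, b, x))
            for k in range(num_classes):
                # d(cross-entropy)/d(logit_k) = p_k - 1[k == y]
                grad = probs[k] - (1.0 if k == y else 0.0)
                b[k] -= lr * grad
                W[k] = [w - lr * grad * v for w, v in zip(W[k], x)]
    return W, b


def accuracy(W, b, features, labels) -> float:
    hits = 0
    for x, y in zip(features, labels):
        scores = logits_for(W, b, x)
        hits += max(range(len(scores)), key=scores.__getitem__) == y
    return hits / len(labels)


def toy_embeddings(n_per_class=50, noise=0.1, seed=0) -> Tuple[List[Vector], List[int]]:
    """Hypothetical stand-in for frozen, L2-normalized embeddings: 3 noisy clusters on the unit circle."""
    rng = random.Random(seed)
    features, labels = [], []
    for k in range(3):
        angle = 2 * math.pi * k / 3
        for _ in range(n_per_class):
            x = math.cos(angle) + rng.gauss(0, noise)
            y = math.sin(angle) + rng.gauss(0, noise)
            norm = math.hypot(x, y)
            features.append([x / norm, y / norm])
            labels.append(k)
    return features, labels


features, labels = toy_embeddings()
W, b = train_linear_probe(features, labels, num_classes=3, epochs=2)
print(f"train accuracy after 2 epochs: {accuracy(W, b, features, labels):.2f}")
```

When the embedding already separates the classes well, a linear layer has very little left to learn, which is consistent with 1-2 epochs being enough.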
How many contrastive epochs are needed?
Getting a good embedding from the contrastive step takes more epochs than regular cross-entropy training.
I ran 400 to 500 epochs, while the official GitHub repository defaults to 1000 epochs and the paper mentions 700 epochs for
ILSVRC-2012. My cross-entropy runs used at most 700 epochs.
Why does the loss never reach zero?
The supervised contrastive loss defined in the paper converges to a constant value that depends on the batch size.
The loss, as described in the paper, is analogous to the Tammes problem,
where the clusters onto which the projections of each class land repel each other. Although this problem is unsolved in a
dimension as high as 128, an approximate solution based on dataset statistics can easily be calculated. This could be computationally
intensive when each batch has a random label configuration, but it could be avoided with a sampler that always returns the same
label configuration. I suspect this might be an easy avenue to reduce the number of epochs needed before convergence.
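To see why the loss plateaus above zero, one can evaluate the SupCon loss of an idealized embedding. The sketch below assumes (hypothetically) that all views of a class collapse onto one unit vector and that class centers form a regular simplex, so two different classes have cosine similarity -1/(K-1); it is an illustration, not a proven minimum. It also includes a hypothetical sampler that keeps the per-class counts of every batch fixed, so that this floor only needs to be computed once. The function names and the temperature of 0.1 are assumptions:

```python
import math
import random
from collections import Counter, defaultdict
from typing import Hashable, Iterator, List, Sequence


def supcon_loss_floor(batch_labels: Sequence[Hashable], num_classes: int,
                      temperature: float = 0.1, n_views: int = 2) -> float:
    """Mean SupCon loss of an *idealized* embedding for one batch label configuration.

    Assumed (hypothetical) geometry: all views of a class sit on the same unit vector,
    and class centers form a regular simplex, so two different classes have cosine
    similarity c = -1 / (num_classes - 1).
    """
    if n_views < 2 or num_classes < 2:
        raise ValueError("need n_views >= 2 and num_classes >= 2")
    m = n_views * len(batch_labels)                    # size of the multiviewed batch
    neg_term = math.exp((-1.0 / (num_classes - 1) - 1.0) / temperature)
    total = 0.0
    for count in Counter(batch_labels).values():
        n_y = n_views * count                          # anchors (views) of this class
        # Each anchor has n_y - 1 positives (similarity 1) and m - n_y negatives,
        # giving a per-anchor loss of log((n_y - 1) + (m - n_y) * exp((c - 1) / tau)).
        total += n_y * math.log((n_y - 1) + (m - n_y) * neg_term)
    return total / m


def fixed_composition_batches(labels: Sequence[Hashable], per_class: int,
                              classes_per_batch: int, num_batches: int,
                              seed: int = 0) -> Iterator[List[int]]:
    """Hypothetical sampler: every batch holds `per_class` indices from each of
    `classes_per_batch` randomly chosen classes, so its label counts never change."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    eligible = [y for y, idxs in by_class.items() if len(idxs) >= per_class]
    if len(eligible) < classes_per_batch:
        raise ValueError("not enough classes with per_class samples")
    for _ in range(num_batches):
        batch: List[int] = []
        for y in rng.sample(eligible, classes_per_batch):
            batch.extend(rng.sample(by_class[y], per_class))
        yield batch


# Example: CIFAR-10-like labels, 25 images per class x 10 classes per batch.
labels = [i % 10 for i in range(50_000)]
batch = next(fixed_composition_batches(labels, per_class=25, classes_per_batch=10, num_batches=1))
floor = supcon_loss_floor([labels[i] for i in batch], num_classes=10)
```

Under these assumptions the floor for that batch composition is about 3.89, close to log(49), because each anchor still has 49 positives sharing the softmax mass. It grows with the number of same-class samples per batch, which is one way to see why the plateau depends on batch size.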
Will this work for very small networks?
This approach also seems to work on small networks, which is one of the additions of this repo. As the ResNet-20 results above show,
this approach beat cross-entropy with a model of only
0.3 M parameters, drastically fewer than the 20+ M of the ResNet-50 used on ILSVRC-2012 in the paper and the official GitHub repository.
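The 0.3 M figure can be checked with a quick parameter count. This sketch assumes the CIFAR ResNet-20 layout from the original ResNet paper (a 16-channel stem, three stages of 16/32/64 channels with three basic blocks each, bias-free 3x3 convolutions followed by BatchNorm, and parameter-free shortcuts); the helper names are hypothetical, the repository's model may differ in details, and the projection head used during the contrastive step is not counted:

```python
def conv3x3_params(c_in: int, c_out: int) -> int:
    # 3x3 convolution without bias (a BatchNorm layer follows each conv)
    return c_in * c_out * 3 * 3


def bn_params(channels: int) -> int:
    # learnable scale and shift; running statistics are buffers, not parameters
    return 2 * channels


def cifar_resnet_params(n: int = 3, widths=(16, 32, 64), num_classes: int = 10) -> int:
    """Parameter count of a CIFAR ResNet with depth 6n + 2 (n = 3 gives ResNet-20)."""
    total = conv3x3_params(3, widths[0]) + bn_params(widths[0])   # stem conv + BN
    c_in = widths[0]
    for width in widths:                   # three stages
        for _ in range(n):                 # n basic blocks per stage
            total += conv3x3_params(c_in, width) + bn_params(width)
            total += conv3x3_params(width, width) + bn_params(width)
            c_in = width                   # shortcuts assumed parameter-free
    total += widths[-1] * num_classes + num_classes   # final linear layer (weights + bias)
    return total


print(cifar_resnet_params(num_classes=10))   # 269722
print(cifar_resnet_params(num_classes=100))  # 275572
```

Almost all of the weights sit in the last 64-channel stage, which is why the total stays around 0.27 M whichever of the two datasets is used.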
Would I recommend this approach for your specific task? And will I use it?
What I like, and what is the main selling point of this technique, is that it exchanges the tedious process of
hyperparameter tuning for computation. All the contrastive results presented here needed only one training attempt.
You just need to decrease the learning rate along the way, whereas with cross-entropy I had to rerun the experiment
3 times on average with different learning rate strategies to get the best result shown.
The other thing that seems to emerge from this paper is that this method is one of the best in a tabula rasa
setting, but you should also look at GradAug, CutMix or Bag of Tricks. So it might be a great fit when you are dealing with non-standard images, i.e. no ILSVRC-2012-like
dataset is available to pretrain on, and it is also difficult to collect a lot of unlabelled data. If you can
gather a lot of unlabelled data, you might get better results with a self-supervised or semi-supervised approach like
SimCLRv2 or BYOL. But I guess if
you are here, you already know about them.