triplet-loss-pytorch

Highly efficient PyTorch version of the Semi-hard Triplet loss ⚡️

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

Triplet SemiHardLoss

PyTorch semi-hard triplet loss, based on the TensorFlow Addons version (tfa.losses.TripletSemiHardLoss). There is no need to create a Siamese architecture with this implementation: it is as simple as following the CNN creation process in main_train_triplet.py!
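For reference, the loss being ported can be summarized as follows. This is a summary of the semi-hard rule as the TF Addons documentation describes it, not a derivation taken from this repository. Here d is the Euclidean distance between embeddings, α the margin, y the labels, and P the set of anchor-positive pairs (same label, different sample) in the batch:

```latex
\ell(a, p, n) = \max\bigl(0,\; d(a, p) - d(a, n) + \alpha\bigr)

n^*(a, p) =
\begin{cases}
\arg\min_{n \,:\, y_n \neq y_a,\; d(a, n) > d(a, p)} d(a, n) & \text{if such an } n \text{ exists,} \\
\arg\max_{n \,:\, y_n \neq y_a} d(a, n) & \text{otherwise,}
\end{cases}

\mathcal{L} = \frac{1}{|P|} \sum_{(a, p) \in P} \ell\bigl(a, p, n^*(a, p)\bigr)
```

In words: each anchor-positive pair is matched with the closest negative that is still farther away than the positive, and only when no such negative exists does it fall back to the anchor's farthest negative.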

The triplet loss is a great choice for classification problems where N_CLASSES >> N_SAMPLES_PER_CLASS, for example face recognition.
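To make the semi-hard mining concrete, here is a minimal, deliberately unoptimized PyTorch sketch modeled on the TF Addons behavior described above. It is not the code in this repository: the function name semi_hard_triplet_loss and the default margin of 1.0 are assumptions, and it builds a B×B×B mask, so memory grows cubically with the batch size.

```python
import torch


def semi_hard_triplet_loss(embeddings: torch.Tensor, labels: torch.Tensor,
                           margin: float = 1.0) -> torch.Tensor:
    """Batch semi-hard triplet loss for embeddings (B, D) and integer labels (B,)."""
    # Pairwise Euclidean distances (B, B); the clamp keeps sqrt's gradient finite at 0.
    diff = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)
    dist = diff.pow(2).sum(dim=-1).clamp_min(1e-12).sqrt()

    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    pos_mask = same & ~eye  # (anchor, positive) pairs: same label, different sample
    neg_mask = ~same        # (anchor, negative) pairs: different label

    # semi_hard[a, p, n]: n is a negative of a and lies farther from a than p does.
    d_ap = dist.unsqueeze(2)  # (B, B, 1)
    d_an = dist.unsqueeze(1)  # (B, 1, B)
    semi_hard = neg_mask.unsqueeze(1) & (d_an > d_ap)  # (B, B, B)

    # Closest semi-hard negative for each (a, p); +inf where there is none.
    inf = torch.full_like(d_an, float("inf"))
    closest_semi_hard = torch.where(semi_hard, d_an, inf).min(dim=2).values  # (B, B)

    # Fallback: the anchor's farthest negative (distances are >= 0, so 0 is a safe fill).
    farthest_neg = dist.masked_fill(~neg_mask, 0.0).max(dim=1, keepdim=True).values  # (B, 1)
    neg_dist = torch.where(semi_hard.any(dim=2), closest_semi_hard, farthest_neg)

    # Hinge on d(a, p) - d(a, n) + margin, averaged over valid anchor-positive pairs.
    losses = (dist - neg_dist + margin).clamp_min(0.0) * pos_mask.float()
    return losses.sum() / pos_mask.sum().clamp_min(1)
```

Because the negative is chosen per anchor-positive pair inside the batch, no Siamese or triplet-sampling pipeline is needed: an ordinary labeled batch is enough.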

The CNN used with the triplet loss must be cut off before the classification layer, and an L2 normalization layer has to be added.
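A sketch of that setup, assuming a small hypothetical MNIST CNN (not the one in main_train_triplet.py) and a hypothetical 64-dimensional embedding layer placed where the classifier used to be. The L2Norm module is an illustrative name built on torch.nn.functional.normalize.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class L2Norm(nn.Module):
    """L2-normalizes each embedding so it lies on the unit hypersphere."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.normalize(x, p=2, dim=1)


# Hypothetical MNIST classifier; its last Linear is the classification layer.
classifier = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(64 * 7 * 7, 10),  # classification layer (10 MNIST classes)
)

# Drop the classification layer, then add an embedding layer and L2 normalization.
embedding_net = nn.Sequential(
    *list(classifier.children())[:-1],
    nn.Linear(64 * 7 * 7, 64),  # hypothetical 64-d embedding layer
    L2Norm(),
)

embeddings = embedding_net(torch.randn(8, 1, 28, 28))  # shape (8, 64), each row has unit norm
```

Since embedding_net reuses the classifier's module objects, any pretrained weights in the feature layers carry over unchanged.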

Results on MNIST

I tested the triplet loss on the MNIST dataset. I can't compare it directly with TF Addons because I didn't run that experiment, although such a comparison could be interesting from a performance point of view. Here are the training logs if you want to compare results. The accuracy shown in the logs is not meaningful and shouldn't be there, since we are not training a classification model.

Phase 1

First, we train only the last layer and the batch normalization layers, reaching a validation loss of about 0.079.

Phase 2

Finally, after unfreezing all the layers, the validation loss can get close to 0.05 with enough training and hyperparameter tuning.
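One way to express the two phases in plain PyTorch, assuming the last layer is passed in explicitly. The helpers freeze_for_phase1, unfreeze_all and make_optimizer, the choice of Adam and the learning rates are hypothetical placeholders, not this repository's training code.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_for_phase1(model: nn.Module, head: nn.Module) -> None:
    """Phase 1: only `head` (the last layer) and BatchNorm layers stay trainable."""
    for p in model.parameters():
        p.requires_grad = False
    for m in model.modules():
        if isinstance(m, BN_TYPES) or m is head:
            for p in m.parameters():
                p.requires_grad = True


def unfreeze_all(model: nn.Module) -> None:
    """Phase 2: fine-tune every layer."""
    for p in model.parameters():
        p.requires_grad = True


def make_optimizer(model: nn.Module, lr: float) -> torch.optim.Optimizer:
    # Only trainable parameters are passed, so rebuild the optimizer after each phase change.
    return torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)


# Toy model standing in for the embedding network (hypothetical).
model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 8))
freeze_for_phase1(model, head=model[-1])
opt_phase1 = make_optimizer(model, lr=1e-3)
# ... phase 1 training loop ...
unfreeze_all(model)
opt_phase2 = make_optimizer(model, lr=1e-4)
# ... phase 2 training loop ...
```

Rebuilding the optimizer between phases is a design choice: it keeps frozen parameters out of the optimizer state entirely, at the cost of discarding phase 1's Adam moment estimates.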

Test

To test the embeddings, there are two interesting options: training a classification model on top of them, and plotting the train and test embeddings to see whether samples of the same category cluster together. The following figure shows the original 10,000 validation samples.

Figure: t-SNE plot of the 10,000 validation sample embeddings.

We get a validation accuracy of around 99.3% by training a linear SVM or a simple kNN on the embeddings. This repository is not focused on maximizing this accuracy by tweaking data augmentation, architecture and hyperparameters, but on providing an effective implementation of the triplet loss in PyTorch. For more info on state-of-the-art results on MNIST, check out this amazing Kaggle discussion.
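For the kNN option, a minimal PyTorch sketch of k-nearest-neighbour classification on the embeddings. The names knn_predict and knn_accuracy and the default k = 5 are assumptions; the linear SVM alternative is not shown.

```python
import torch


@torch.no_grad()
def knn_predict(train_emb: torch.Tensor, train_labels: torch.Tensor,
                query_emb: torch.Tensor, k: int = 5) -> torch.Tensor:
    """Majority vote over the k nearest training embeddings (Euclidean distance)."""
    dist = torch.cdist(query_emb, train_emb)                # (Q, N)
    nn_idx = dist.topk(k, dim=1, largest=False).indices    # (Q, k) closest training samples
    return torch.mode(train_labels[nn_idx], dim=1).values  # most frequent label per query


def knn_accuracy(train_emb: torch.Tensor, train_labels: torch.Tensor,
                 val_emb: torch.Tensor, val_labels: torch.Tensor, k: int = 5) -> float:
    pred = knn_predict(train_emb, train_labels, val_emb, k)
    return (pred == val_labels).float().mean().item()
```

With L2-normalized embeddings, Euclidean and cosine neighbours give the same ranking, so no extra normalization step is needed here.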

Contact me with any questions: alfonmedela@gmail.com | alfonsomedela.com

ENJOY IT!

Donations ₿

BTC Wallet: 1DswCAGmXYQ4u2EWVJWitySM7Xo7SH4Wdf

IMPORTANT

If you're using the fastai library, learn.predict will return an error when predicting the embeddings. Internally, fastai knows that your data has N classes, while the embedding vector has M dimensions with M > N. Whenever the index of the highest predicted value falls outside those N classes, it points to a class that does not exist, and an error is raised. So either write your own prediction function or make a simple modification to the source code that changes the length of the self.classes list.
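A minimal sketch of such a prediction function in plain PyTorch, which sidesteps learn.predict entirely. predict_embeddings is a hypothetical name, and it assumes the loader yields either input tensors or (input, target) tuples. With fastai, the underlying PyTorch module is available as learn.model.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def predict_embeddings(model: torch.nn.Module, loader: DataLoader,
                       device: str = "cpu") -> torch.Tensor:
    """Runs the embedding network over a loader and stacks the outputs (no class lookup)."""
    model.to(device)
    model.eval()  # BatchNorm uses running statistics, Dropout is disabled
    chunks = []
    for batch in loader:
        x = batch[0] if isinstance(batch, (list, tuple)) else batch  # drop targets if present
        chunks.append(model(x.to(device)).cpu())
    return torch.cat(chunks)


# Usage with a toy network and random inputs (both hypothetical).
net = torch.nn.Sequential(torch.nn.Linear(4, 3))
loader = DataLoader(TensorDataset(torch.randn(10, 4)), batch_size=4)
emb = predict_embeddings(net, loader)  # shape (10, 3)
```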