
An implementation of "Prototypical Networks for Few-shot Learning" in PyTorch, written as a Jupyter notebook.


Prototypical Networks on the Omniglot Dataset


I. Prototypical Networks

Prototypical Networks were introduced by Snell et al. in 2017 (https://arxiv.org/abs/1703.05175). They build on an earlier architecture, Matching Networks, introduced by Vinyals et al. in a previous paper (https://arxiv.org/abs/1606.04080). Both belong to a broader family of algorithms called metric learning algorithms, whose success relies on learning a measure of similarity between samples.

As the authors of the original paper, "Prototypical Networks for Few-shot Learning", put it: "Our approach is based on the idea that there exists an embedding in which points cluster around a single prototype representation for each class."

In other words, there exists a mathematical representation of the images, called the embedding space, in which images of the same class gather in clusters. The main advantage of working in that space is that two images that look alike are close to each other, while two images that are completely different are far apart.

[Figure: Clusters in the embedding space]

Here, "close" refers to a distance metric that needs to be defined. The cosine distance and the Euclidean distance are the usual choices.
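To make the two metrics concrete, here is a small standard-library sketch. The function names are illustrative and not taken from the notebook; cosine distance is computed here as 1 minus the cosine similarity.

```python
import math


def euclidean_distance(a, b):
    """Straight-line distance between two embedding vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def cosine_distance(a, b):
    """1 - cosine similarity: 0 for vectors pointing the same way, 2 for opposite ones."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


u = [1.0, 0.0]
v = [3.0, 0.0]  # same direction as u, but three times longer
print(euclidean_distance(u, v))  # 2.0
print(cosine_distance(u, v))     # 0.0
```

The two vectors point the same way, so their cosine distance is zero even though they are two units apart: cosine distance ignores vector length, while Euclidean distance does not.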

Unlike typical deep learning classifiers, Prototypical Networks do not classify the image directly; instead, they learn a mapping of the images into the embedding space. To do so, training runs through many iterations called episodes, each designed to mimic a few-shot task. Let's describe one training episode in detail:

Notation:

In few-shot classification, we are given a dataset with few images per class. In each episode, Nc classes are picked at random, and for each class we draw two sets of images: the support set (Ns images) and the query set (Nq images).
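A minimal sketch of how one episode could be drawn, using only the standard library. The dict-of-lists dataset layout and the name `sample_episode` are assumptions made for illustration, not the notebook's code.

```python
import random


def sample_episode(dataset, n_classes, n_support, n_query, rng=random):
    """Draw one few-shot episode from `dataset`, a dict mapping each class
    label to the list of its images.

    Returns the sampled class labels plus the support and query sets,
    each a list with one inner list of images per sampled class.
    """
    classes = rng.sample(sorted(dataset), n_classes)  # Nc distinct classes
    support, query = [], []
    for label in classes:
        # Draw Ns + Nq distinct images, so no image is both support and query
        images = rng.sample(dataset[label], n_support + n_query)
        support.append(images[:n_support])
        query.append(images[n_support:])
    return classes, support, query


# Toy dataset: 10 classes with 20 images each (images are just strings here)
toy = {f"char_{i}": [f"char_{i}_img_{j}" for j in range(20)] for i in range(10)}
classes, support, query = sample_episode(toy, n_classes=5, n_support=1,
                                         n_query=5, rng=random.Random(0))
print(len(classes), len(support[0]), len(query[0]))  # 5 1 5
```

Keeping support and query images disjoint is a design choice of this sketch: it ensures the query images are never classified against a prototype built from themselves.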

[Figure: Representation of one sample]

Step 1: embed the images

First, we need to transform the images into vectors. This step, called embedding, is performed by an "Image2Vector" model, which is based on a Convolutional Neural Network (CNN).

Step 2: compute class prototypes

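This step is similar to K-means clustering (unsupervised learning), where each cluster is represented by its centroid. For each class, the embeddings of the support set images are averaged to form the class prototype:

$$v(k) = \frac{1}{N_s} \sum_{i=1}^{N_s} f\left(x_i^{(k)}\right)$$

where f is the Image2Vector embedding, $x_1^{(k)}, \dots, x_{N_s}^{(k)}$ are the support images of class k, and v(k) is the prototype of class k.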
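In PyTorch, if the embedded support set is stored as a tensor of shape (Nc, Ns, D), with D the embedding size, the prototypes come from a single mean over the support dimension. The shape convention and the function name are assumptions made for this sketch.

```python
import torch


def compute_prototypes(support_embeddings: torch.Tensor) -> torch.Tensor:
    """support_embeddings: (Nc, Ns, D) embeddings of the support images.
    Returns an (Nc, D) tensor whose k-th row is the prototype v(k)."""
    return support_embeddings.mean(dim=1)


# Two classes, two support images each, 2-dimensional embeddings
support = torch.tensor([[[0.0, 0.0], [2.0, 2.0]],
                        [[4.0, 0.0], [6.0, 0.0]]])
prototypes = compute_prototypes(support)  # rows: [1., 1.] and [5., 0.]
```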

Step 3: compute distance between queries and prototypes

This step classifies the query images. To do so, we compute the distance between each query embedding and each prototype. The choice of metric is crucial, and the authors of Prototypical Networks deserve credit for theirs: they observed that both their algorithm and Matching Networks perform better with the Euclidean distance than with the cosine distance.

[Figure: Cosine distance vs. Euclidean distance]

Once the distances are computed, a softmax over the negative distances to the prototypes gives a probability for each class: the closer the prototype, the higher the probability.
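A sketch of Step 3, assuming the squared Euclidean distance used in the original paper and the (Nc, D) prototype layout from the previous sketch. The minus sign turns "small distance" into "large logit" before the softmax; the function name is hypothetical.

```python
import torch


def class_probabilities(query_embeddings: torch.Tensor,
                        prototypes: torch.Tensor) -> torch.Tensor:
    """query_embeddings: (Q, D); prototypes: (Nc, D).
    Returns a (Q, Nc) tensor of class probabilities for each query."""
    # Squared Euclidean distance between every query and every prototype: (Q, Nc)
    diff = query_embeddings[:, None, :] - prototypes[None, :, :]
    dists = (diff ** 2).sum(dim=2)
    # Softmax over the NEGATIVE distances: the closest prototype gets the highest probability
    return torch.softmax(-dists, dim=1)


prototypes = torch.tensor([[1.0, 1.0], [5.0, 0.0]])
queries = torch.tensor([[1.0, 2.0],   # squared distances: 1 and 20, so class 0 wins
                        [5.0, 1.0]])  # squared distances: 16 and 1, so class 1 wins
probs = class_probabilities(queries, prototypes)
```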

Step 4: classify queries

The class with the highest probability is assigned to the query image.
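Classifying the queries is an argmax over the class dimension. Comparing the predictions with the true labels gives the accuracy of the episode, which is what gets averaged over test episodes. The probabilities below are made up for illustration.

```python
import torch

# Hypothetical probabilities for 3 queries over Nc = 3 classes, and their true labels
probs = torch.tensor([[0.7, 0.2, 0.1],
                      [0.1, 0.3, 0.6],
                      [0.2, 0.5, 0.3]])
labels = torch.tensor([0, 1, 1])

predictions = probs.argmax(dim=1)                           # tensor([0, 2, 1])
accuracy = (predictions == labels).float().mean().item()    # 2 of 3 queries correct
```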

Step 5: compute the loss and backpropagate

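Only in training mode. Prototypical Networks are trained with the log-softmax loss, that is, the negative log of the softmax probability assigned to the true class (the cross-entropy loss). Since -log p grows without bound as the probability p of the correct class goes to 0, this loss heavily penalizes the model when it fails to predict the correct class, which is what we need.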
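A sketch of the loss, taking a (Q, Nc) tensor of query-to-prototype distances as input (a shape assumed for illustration). It applies the log-softmax over negative distances, then takes the negative log-probability of each query's true class; this is the same value as PyTorch's `cross_entropy` on the logits `-dists`, which the last line checks.

```python
import torch
import torch.nn.functional as F


def prototypical_loss(dists: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """dists: (Q, Nc) distances from each query to each prototype.
    targets: (Q,) index of the true class of each query.
    Returns the mean negative log-probability of the true classes."""
    log_probs = F.log_softmax(-dists, dim=1)  # log of the softmax over negative distances
    return F.nll_loss(log_probs, targets)     # mean of -log p(true class) over queries


dists = torch.tensor([[1.0, 20.0],
                      [16.0, 1.0]])
targets = torch.tensor([0, 1])
loss = prototypical_loss(dists, targets)
# Same value as the built-in cross-entropy applied to the logits -dists
assert torch.allclose(loss, F.cross_entropy(-dists, targets))
```

In training, `loss.backward()` then propagates this value back through the distances and the prototypes into the embedding network.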

Pros and Cons of Prototypical Networks

Pros:

  • Easy to understand
  • Very "visual"
  • Noise resistant thanks to mean prototypes
  • Can be adapted to the zero-shot setting

Cons:

  • Lack of generalization
  • Uses only the mean to build prototypes, ignoring the variance in the support set

II. The Omniglot Dataset

The Omniglot dataset is a benchmark dataset in Few-shot Learning. It contains 1,623 different handwritten characters from 50 different alphabets. The dataset can be found in this repository. I used the images_background.zip and the images_evaluation.zip files.

III. Implementation of ProtoNet for Omniglot

As suggested in the official paper, all the images are rotated by 90, 180 and 270 degrees to increase the number of classes. Each rotation yields an additional class, so the total number of classes becomes 6,492 (1,623 × 4). The training set contains images of 4,200 classes, while the test set contains images of 2,292 classes.
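A sketch of the rotation augmentation with `torch.rot90`. It assumes each character's images are stacked in an (N, H, W) grayscale tensor; the notebook's actual storage format and the name `add_rotated_classes` are assumptions.

```python
import torch


def add_rotated_classes(images_by_class):
    """images_by_class: dict mapping a character to a (N, H, W) tensor of its images.
    Returns a dict with 4x as many classes: each (character, angle) pair,
    for angle in 0, 90, 180, 270 degrees, is treated as a separate class."""
    augmented = {}
    for character, images in images_by_class.items():
        for k in range(4):  # k quarter turns in the (H, W) plane
            augmented[(character, 90 * k)] = torch.rot90(images, k, dims=(1, 2))
    return augmented


# Toy data: 3 characters with 20 random 28x28 images each
toy = {f"char_{i}": torch.rand(20, 28, 28) for i in range(3)}
augmented = add_rotated_classes(toy)
print(len(augmented))  # 12 classes = 3 characters x 4 rotations
print(1623 * 4)        # 6492, the class count quoted above for the full dataset
```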

The embedding part takes a (28x28x3) image and returns a column vector of length 64. The image2vector function is composed of 4 modules, each consisting of:

  • A convolutional layer
  • A batch normalization layer
  • A ReLU activation function
  • A 2x2 max pooling layer

[Figure: Embedding CNNs]
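The four-module embedding network described above can be sketched in PyTorch as follows. The 3x3 kernels, 64 filters per layer and padding of 1 are assumptions borrowed from the paper's description of its embedding network, not read from the notebook. PyTorch expects channels first, so a 28x28x3 image becomes a (3, 28, 28) tensor. The four poolings shrink the spatial size 28 → 14 → 7 → 3 → 1, which is why the output has exactly 64 values.

```python
import torch
import torch.nn as nn


def conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
    # One module: convolution -> batch normalization -> ReLU -> 2x2 max pooling
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


class Image2Vector(nn.Module):
    def __init__(self, in_channels: int = 3, hidden: int = 64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(in_channels, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, 28, 28) -> (batch, 64, 1, 1) -> (batch, 64)
        return self.encoder(x).flatten(start_dim=1)


model = Image2Vector()
embeddings = model(torch.randn(10, 3, 28, 28))  # a batch of 10 images
print(embeddings.shape)  # torch.Size([10, 64])
```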

The chosen optimizer is Adam. The initial learning rate of 10^-3 is halved at every epoch.

The model was trained for 5 epochs of 2,000 episodes each and tested on 1,000 episodes. At each episode, a new sample was randomly drawn from the training set.
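A sketch of this training schedule with PyTorch's `Adam` optimizer and `StepLR` scheduler (`gamma=0.5`, stepped once per epoch). The `nn.Linear` model and the squared-output loss are hypothetical placeholders that stand in for the embedding network and the episode loss.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical placeholder for the embedding network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate each time scheduler.step() is called, i.e. once per epoch
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

n_epochs, episodes_per_epoch = 5, 2000
lr_per_epoch = []
for epoch in range(n_epochs):
    lr_per_epoch.append(scheduler.get_last_lr()[0])
    for episode in range(episodes_per_epoch):
        # Placeholder loss; in the real loop this is the loss of a freshly sampled episode
        loss = model(torch.randn(8, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()
# lr_per_epoch now holds 1e-3, 5e-4, 2.5e-4, 1.25e-4, 6.25e-5
```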

RESULTS

I tried to reproduce the results of the paper.
Training settings: 60 classes per episode, 1 or 5 support points and 5 query points per class.
Testing settings: 5-way and 20-way scenarios, with the same number of support and query points as during training.

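|          | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | 20-way 5-shot |
|----------|--------------|--------------|---------------|---------------|
| Obtained | 98.8%        | 99.8%        | 96.1%         | 99.2%         |
| Paper    | 98.8%        | 99.7%        | 96.0%         | 98.9%         |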

I obtained results similar to those of the original paper, slightly better in some cases. This may be due to the sampling strategy, which is not specified in the paper; I used random sampling at each episode.