
Keras model trained using semi-hard triplet loss (TensorFlow function) on MNIST

Primary language: Jupyter Notebook. License: MIT.

Keras-Triplet-loss-MNIST

Train a Keras model on the MNIST dataset using TensorFlow's semi-hard triplet loss function.


Disclaimer: the main contribution of this script is combining the TensorFlow function with the Keras Model API. For this reason, I had to define the function (and its support functions) locally.

Please check out the TensorFlow documentation page for the function (4).

Make sure you are using the same TensorFlow version as listed under Resources!

The Python code can be found in the Jupyter notebook 'Triplet_loss_KERAS_semi_hard_from_TF.ipynb'.

Resources

  • Keras 2.1.6
  • TensorFlow 1.8.0

Triplet Loss explained:

[Figure: Figures 2 and 3 from the FaceNet paper (1)]

Figures taken from the paper introducing FaceNet (1). Figure 2 shows the general idea of encoding an image into an embedding: a vector of numbers much smaller than the image itself.

Figure 3 shows how the network is trained to distinguish intra-class from inter-class cases. By grouping images into triplets (an Anchor-Positive pair and an Anchor-Negative pair sharing the same Anchor), the network learns where the images of each class lie relative to all other classes.
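To make the triplet idea concrete, here is a small sketch (not code from the notebook) that enumerates every valid (anchor, positive, negative) index triplet in a batch, given only the labels. `all_valid_triplets` is a hypothetical helper: anchor and positive are different samples with the same label, and the negative has a different label.

```python
import itertools

import numpy as np


def all_valid_triplets(labels):
    """Return every (anchor, positive, negative) index triplet in a batch."""
    labels = np.asarray(labels)
    triplets = []
    # Ordered pairs of distinct samples: candidate (anchor, positive) pairs
    for a, p in itertools.permutations(range(len(labels)), 2):
        if labels[a] != labels[p]:
            continue  # not an anchor-positive pair
        # Every sample with a different label can serve as the negative
        for n in np.flatnonzero(labels != labels[a]):
            triplets.append((a, p, int(n)))
    return triplets


# Two classes with two samples each: 4 anchor-positive pairs x 2 negatives
print(len(all_valid_triplets([0, 0, 1, 1])))  # → 8
```

The number of triplets grows roughly cubically with the batch size, which is why methods like the one used here select only some of them (see below) instead of using all.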

The loss function is defined as:

$$\mathcal{L}(A, P, N) = \max\big(d(A,P) - d(A,N) + \text{margin},\ 0\big)$$

where d(A,P) and d(A,N) are the Euclidean distances between the Anchor and the Positive, and between the Anchor and the Negative, respectively. The margin is a parameter that forces the network to keep the Anchor-Negative distance larger than the Anchor-Positive distance by at least that amount.
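A direct NumPy transcription of this formula for a single triplet, as a minimal sketch; the function name and margin values are illustrative choices, not taken from the notebook. Note that FaceNet (1) itself uses squared Euclidean distances; this sketch follows the plain-Euclidean definition given above.

```python
import numpy as np


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge triplet loss max(d(A,P) - d(A,N) + margin, 0) with Euclidean d."""
    d_ap = np.linalg.norm(anchor - positive)
    d_an = np.linalg.norm(anchor - negative)
    return max(d_ap - d_an + margin, 0.0)


a = np.array([0.0, 0.0])
p = np.array([1.0, 0.0])  # d(A,P) = 1
n = np.array([0.0, 1.5])  # d(A,N) = 1.5
print(triplet_loss(a, p, n, margin=1.0))  # 1 - 1.5 + 1 → 0.5
print(triplet_loss(a, p, n, margin=0.2))  # 1 - 1.5 + 0.2 < 0 → 0.0
```

Once the negative is farther than the positive by more than the margin, the loss (and its gradient) is zero, so that triplet no longer contributes to training.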

The choice of Positive and Negative pairs is important for training the network correctly. The most informative triplets are those where the Anchor-Positive distance is large(r) while the Anchor-Negative distance is small(er). These represent HARD examples.
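This repository uses the semi-hard variant. Following FaceNet (1), semi-hard negatives are farther from the anchor than the positive but still within the margin: d(A,P) < d(A,N) < d(A,P) + margin. Below is a small sketch of the common three-way split into hard, semi-hard and easy negatives. `categorize_negative` is a hypothetical helper, and how ties at the boundaries are assigned is an arbitrary choice here.

```python
def categorize_negative(d_ap, d_an, margin):
    """Label a triplet by how 'hard' its negative is (common convention)."""
    if d_an < d_ap:
        return "hard"       # negative is closer to the anchor than the positive
    if d_an < d_ap + margin:
        return "semi-hard"  # farther than the positive, but inside the margin
    return "easy"           # margin already satisfied: zero loss


for d_an in (0.5, 1.2, 2.0):
    print(d_an, categorize_negative(d_ap=1.0, d_an=d_an, margin=0.5))
# 0.5 hard
# 1.2 semi-hard
# 2.0 easy
```

Easy negatives give zero loss, while FaceNet reports that always picking the very hardest negatives early in training can lead to bad local minima; semi-hard negatives are a compromise between the two.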

Preparing the network

First, the base network architecture is defined, with an Input layer of the same shape as the input image (28x28, single channel) and an Output layer of size 64, representing the embedding.

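from keras.layers import Input, Flatten, Dense, Dropout
from keras.models import Model
from keras.utils import plot_model  # requires pydot and graphviz

def create_base_network(image_input_shape, embedding_size):
    """
    Base network to be shared (i.e. the feature extractor).
    Maps an input image to an embedding vector of length embedding_size.
    """
    input_image = Input(shape=image_input_shape)

    x = Flatten()(input_image)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.1)(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.1)(x)
    x = Dense(embedding_size)(x)  # linear output layer: the embedding

    base_network = Model(inputs=input_image, outputs=x)
    plot_model(base_network, to_file='base_network.png', show_shapes=True, show_layer_names=True)
    return base_network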
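As a quick sanity check on this architecture (Flatten, two Dense(128) layers and a Dense(embedding_size) output, with embedding_size = 64), the number of trainable parameters can be worked out by hand; it should match the "Total params" line printed by `base_network.summary()`. Flatten and Dropout contribute no parameters. `dense_params` is a hypothetical helper.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected (Dense) layer."""
    return n_in * n_out + n_out


embedding_size = 64
# Flatten (28*28 = 784) -> Dense(128) -> Dense(128) -> Dense(embedding_size)
layer_sizes = [28 * 28, 128, 128, embedding_size]
total = sum(dense_params(n_in, n_out)
            for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))
print(total)  # 100480 + 16512 + 8256 → 125248
```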

[Figure: base network architecture, as saved to base_network.png by plot_model]

We then define the Model so that the triplet loss function receives all the embeddings in each batch, together with their corresponding labels (which are needed to select the triplets). This is done by defining an Input layer for the labels and concatenating it with the embeddings.

from keras.layers import Input, concatenate
from keras.models import Model

input_image_shape = (28, 28, 1)  # MNIST images with a single channel
embedding_size = 64

base_network = create_base_network(input_image_shape, embedding_size)

input_images = Input(shape=input_image_shape, name='input_image')  # input layer for images
input_labels = Input(shape=(1,), name='input_label')               # input layer for labels
embeddings = base_network(input_images)                            # output of network -> embeddings
labels_plus_embeddings = concatenate([input_labels, embeddings])   # concatenating the labels + embeddings

# Defining a model with inputs (images, labels) and outputs (labels_plus_embeddings)
model = Model(inputs=[input_images, input_labels],
              outputs=labels_plus_embeddings)
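Because the labels are concatenated first, column 0 of the model output holds the label and the remaining columns hold the embedding. The NumPy sketch below shows that layout and how the two parts can be separated again by slicing; the notebook's loss presumably does the equivalent with TensorFlow ops, but this is an illustration, not its actual code. The batch size and random embeddings are arbitrary stand-ins.

```python
import numpy as np

embedding_size = 64
batch = 4

# Stand-in for the model output: [label | embedding] per row
labels = np.array([[3.0], [3.0], [7.0], [1.0]])
embeddings = np.random.rand(batch, embedding_size)
y_pred = np.concatenate([labels, embeddings], axis=1)  # shape (4, 65)

# Inside the loss, the two parts can be recovered by slicing
recovered_labels = y_pred[:, 0].astype(np.int32)  # shape (4,)
recovered_embeddings = y_pred[:, 1:]              # shape (4, 64)
print(y_pred.shape, recovered_labels, recovered_embeddings.shape)
```

Casting the label column back to integers is safe for MNIST, since the digit classes 0-9 are represented exactly as floats.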

[Figure: full model architecture]

Training

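Keras expects a ground-truth (y) array for every model output, so to train we pass 'dummy' targets with the same shape as the output (embedding_size + 1 columns). The loss function ignores them and uses the labels concatenated to the embeddings instead.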

import numpy as np
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint

# Assumes x_train, y_train, x_val, y_val, batch_size, epochs and the loss
# function triplet_loss_lol are defined earlier (see the notebook).
opt = Adam(lr=0.0001)  # choose optimiser. RMSprop is good too!
model.compile(loss=triplet_loss_lol,
              optimizer=opt)

filepath = "semiH_trip_MNIST_v13_test_ep{epoch:02d}_BS%d.hdf5" % batch_size
# Save a checkpoint every 25 epochs
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=False, period=25)
callbacks_list = [checkpoint]

# Uses 'dummy' embeddings + dummy gt labels; removed as soon as they enter the loss function...
dummy_gt_train = np.zeros((len(x_train), embedding_size + 1))
dummy_gt_val = np.zeros((len(x_val), embedding_size + 1))

# Add the channel axis expected by the Input layer: (N, 28, 28) -> (N, 28, 28, 1)
x_train = np.reshape(x_train, (len(x_train), x_train.shape[1], x_train.shape[2], 1))
x_val = np.reshape(x_val, (len(x_val), x_val.shape[1], x_val.shape[2], 1))

H = model.fit(x=[x_train, y_train],
              y=dummy_gt_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=([x_val, y_val], dummy_gt_val),
              callbacks=callbacks_list)
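Keras calls a loss as loss(y_true, y_pred), so the zero arrays above only satisfy Keras' shape bookkeeping; everything the loss needs travels in y_pred. Triplets are mined inside each batch, so a class only yields anchor-positive pairs if at least two of its images land in the same batch, which is one reason larger batch sizes tend to help. The NumPy sketch below is a simplified re-implementation for illustration only. It follows the semi-hard mining rule as I understand `triplet_semihard_loss` (squared Euclidean distances, fall back to the farthest negative when no semi-hard one exists, average over anchor-positive pairs), but it is not the notebook's loss, and edge cases may differ from TensorFlow's.

```python
import numpy as np


def semihard_triplet_loss(labels, embeddings, margin=1.0):
    """Simplified NumPy sketch of a batch semi-hard triplet loss.

    For every anchor-positive pair, pick the closest negative that is farther
    than the positive (semi-hard); if none exists, fall back to the farthest
    negative. Distances are squared Euclidean.
    """
    labels = np.asarray(labels)
    emb = np.asarray(embeddings, dtype=float)
    diff = emb[:, None, :] - emb[None, :, :]
    dist = np.sum(diff ** 2, axis=-1)          # (B, B) squared distances
    same = labels[:, None] == labels[None, :]  # (B, B) same-class mask
    losses = []
    for a in range(len(labels)):
        neg_d = dist[a][~same[a]]              # anchor-to-negative distances
        if neg_d.size == 0:
            continue                           # no negative for this anchor in the batch
        for p in np.flatnonzero(same[a]):
            if p == a:
                continue                       # the anchor is not its own positive
            d_ap = dist[a, p]
            farther = neg_d[neg_d > d_ap]
            d_an = farther.min() if farther.size else neg_d.max()
            losses.append(max(d_ap - d_an + margin, 0.0))
    return float(np.mean(losses)) if losses else 0.0


labels = [0, 0, 1, 1]
print(semihard_triplet_loss(labels, [[0.0], [1.0], [3.0], [4.0]]))  # separated → 0.0
print(semihard_triplet_loss(labels, [[0.0], [2.0], [1.0], [3.0]]))  # interleaved → 2.0
```

In the interleaved case, two of the four anchors have no negative farther than their positive, so the farthest negative is used and each of those pairs contributes a loss of 4.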

Visualizing separation of classes

We need to:

  1. Make an empty network
# creating an empty network
testing_embeddings = create_base_network(input_image_shape,
                                         embedding_size=embedding_size)
# embeddings before training...
x_embeddings_before_train = testing_embeddings.predict(np.reshape(x_test, (len(x_test), 28, 28, 1)))
  2. Loop over the trained model and copy its weights
# Grabbing the weights from the trained network
# model.layers[2] is assumed to be the shared base network (check with model.summary())
for layer_target, layer_source in zip(testing_embeddings.layers, model.layers[2].layers):
    weights = layer_source.get_weights()
    layer_target.set_weights(weights)
    del weights
  3. Obtain predictions (embeddings) for the test set
x_embeddings = testing_embeddings.predict(np.reshape(x_test, (len(x_test), 28, 28, 1)))
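  4. Obtain the PCA decomposition
from sklearn.decomposition import PCA

dict_embeddings = {}
dict_gray = {}
test_class_labels = np.unique(np.array(y_test))

pca = PCA(n_components=no_of_components)  # no_of_components is set earlier in the notebook
decomposed_embeddings = pca.fit_transform(x_embeddings)
# Note: this re-fits the same PCA object on the embeddings from before training
decomposed_gray = pca.fit_transform(x_embeddings_before_train)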
  5. Visualize the separation:
[Figure: separation of the test-set classes in the embedding space]
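For readers who want to see what the PCA step does, it can be approximated in plain NumPy: center the embeddings, then project them onto the top right-singular vectors. scikit-learn's `PCA.fit_transform` does essentially this, although individual components may come out with flipped signs. `pca_project` is a hypothetical helper, and random data stands in for the real 64-dimensional test-set embeddings.

```python
import numpy as np


def pca_project(x, n_components=2):
    """Project rows of x onto their top principal components."""
    x_centered = x - x.mean(axis=0)
    # Rows of vt are the principal axes, sorted by decreasing singular value
    _, _, vt = np.linalg.svd(x_centered, full_matrices=False)
    return x_centered @ vt[:n_components].T


rng = np.random.default_rng(0)
emb = rng.normal(size=(100, 64))  # stand-in for 64-d test-set embeddings
proj = pca_project(emb)
print(proj.shape)  # → (100, 2)
```

The first column always carries at least as much variance as the second, which is why a 2-D scatter of the first two components gives the most faithful flat view of the embedding space that a linear projection can offer.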

I hope this script will be helpful to anyone who wants to use Triplet Loss with Keras.

References and Other resources:

(1) F. Schroff, D. Kalenichenko, and J. Philbin, “FaceNet: A Unified Embedding for Face Recognition and Clustering,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2015, pp. 815–823. arXiv:1503.03832.

(2) A blog post that explains triplet loss very well, and its accompanying GitHub page.

(3) A. Hermans, L. Beyer, and B. Leibe, “In Defense of the Triplet Loss for Person Re-Identification,” arXiv:1703.07737, 2017.

(4) TensorFlow documentation page for the semi-hard triplet loss function (tf.contrib.losses.metric_learning.triplet_semihard_loss).