
Semi-supervised Learning with Generative Adversarial Networks (GANs)

Modern deep learning classifiers require a large volume of labeled samples to generalize well. GANs have shown a lot of potential in semi-supervised learning, where a classifier can achieve good performance with very few labeled samples (Salimans et al., 2016).

Overview

To train a $K$-class classifier with a small number of labeled samples, the discriminator ($D$) in the GAN game is replaced with a $(K+1)$-class classifier that receives a data point $x$ as input and outputs a $(K+1)$-dimensional vector of logits $\{l_1, \dots, l_{K+1}\}$. These logits can then be turned into class probabilities

$$p_{\text{model}}(y = j \mid x) = \frac{\exp(l_j)}{\sum_{k=1}^{K+1} \exp(l_k)},$$

where:

$p_{\text{model}}(y = K+1 \mid x)$ provides the probability that $x$ is fake, and

$p_{\text{model}}(y = i \mid x)$, for $i \in \{1, \dots, K\}$, provides the probability that $x$ is real and belongs to class $i$. Now, the discriminator loss can be written as

$$L_D = L_{\text{supervised}} + L_{\text{unsupervised}},$$

where:

$$L_{\text{supervised}} = -\,\mathbb{E}_{x, y \sim p_{\text{data}}(x, y)} \log p_{\text{model}}(y \mid x, y < K+1)$$

is the standard supervised learning loss function given that the data is real, and:

$$L_{\text{unsupervised}} = -\left\{ \mathbb{E}_{x \sim p_{\text{data}}(x)} \log D(x) + \mathbb{E}_{z \sim p_z(z)} \log\bigl(1 - D(G(z))\bigr) \right\}$$

is the standard GAN game value, where:

$$D(x) = 1 - p_{\text{model}}(y = K+1 \mid x).$$
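For concreteness, here is a minimal sketch of this discriminator loss. PyTorch is an assumption here (the repository's notebook may use a different framework), and the function and argument names are illustrative:

```python
import torch
import torch.nn.functional as F

def discriminator_loss(logits_labeled, labels, logits_unlabeled, logits_fake):
    """L_D = L_supervised + L_unsupervised for a (K+1)-class discriminator.

    Each logits tensor has shape (batch, K+1); index K holds the "fake" class.
    """
    K = logits_labeled.size(1) - 1

    # L_supervised: cross-entropy over the K real classes only,
    # i.e. p_model(y | x, y < K+1) renormalized over the first K logits.
    l_supervised = F.cross_entropy(logits_labeled[:, :K], labels)

    # D(x) = 1 - p_model(y = K+1 | x)
    d_real = 1.0 - F.softmax(logits_unlabeled, dim=1)[:, K]
    p_fake_gen = F.softmax(logits_fake, dim=1)[:, K]  # 1 - D(G(z))

    # L_unsupervised = -{ E_x log D(x) + E_z log(1 - D(G(z))) }
    l_unsupervised = -(torch.log(d_real + 1e-8).mean()
                       + torch.log(p_fake_gen + 1e-8).mean())
    return l_supervised + l_unsupervised
```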

Now, let $f(x)$ denote the activations on an intermediate layer of the discriminator. The feature matching loss of the generator can be defined as

$$L_{\text{FM}} = \left\| \mathbb{E}_{x \sim p_{\text{data}}} f(x) - \mathbb{E}_{z \sim p_z(z)} f(G(z)) \right\|_2^2.$$

Feature matching has shown a lot of potential in semi-supervised learning. Its goal is to push the generator to produce data that matches the statistics of real data. The discriminator is used to specify those statistics, as it naturally learns to find the features that are most discriminative of real data versus data generated by the current model.
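A minimal sketch of the feature matching loss under the same PyTorch assumption, estimating the two expectations with minibatch means:

```python
import torch

def feature_matching_loss(f_real: torch.Tensor, f_fake: torch.Tensor) -> torch.Tensor:
    """L_FM = || E_x f(x) - E_z f(G(z)) ||_2^2, estimated over a minibatch.

    f_real / f_fake: intermediate discriminator activations f(.) for a batch
    of real and generated samples, shape (batch, features).
    """
    return (f_real.mean(dim=0) - f_fake.mean(dim=0)).pow(2).sum()
```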

In this code, I combined $L_{\text{FM}}$ with the well-known generator cost that maximizes the log-probability of the discriminator being mistaken:

$$-\,\mathbb{E}_{z \sim p_z(z)} \log D(G(z)).$$

So, the generator loss can be written as

$$L_G = L_{\text{FM}} - \mathbb{E}_{z \sim p_z(z)} \log D(G(z)).$$
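And a sketch of the combined generator loss, reusing `feature_matching_loss` from above (again illustrative, not the repository's exact code):

```python
def generator_loss(f_real, f_fake, logits_fake):
    """L_G = L_FM - E_z log D(G(z)), with D(G(z)) = 1 - p_model(y = K+1 | G(z))."""
    K = logits_fake.size(1) - 1
    d_fake = 1.0 - torch.softmax(logits_fake, dim=1)[:, K]  # D(G(z))
    # Non-saturating term: maximize log D(G(z)).
    return feature_matching_loss(f_real, f_fake) - torch.log(d_fake + 1e-8).mean()
```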

Results

The table below shows the cross-validation accuracy of the semi-supervised learning GAN after 1000 epochs of training, when 10% and when 100% of the MNIST data is labeled.

| 10% labeled data | 100% labeled data |
| --- | --- |
| 0.9255 | 0.945 |

The figure below shows the cross-validation accuracy over 1000 epochs when 10% of the data is labeled. As can be seen, training has not yet reached a plateau, which indicates that further training could yield higher accuracy.

The figures below show some generated samples at different epochs of training when 10% of the data is labeled:

Reference:

Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., and Chen, X. (2016). Improved Techniques for Training GANs. In Advances in Neural Information Processing Systems (NIPS), pages 2226–2234. http://papers.nips.cc/paper/6125-improved-techniques-for-training-gans.pdf