Memory-Augmented Neural Networks: a black-box meta-learner that uses LSTMs for few-shot classification

Black-Box-Meta-Learning

This repo implements and trains a memory-augmented neural network, a black-box meta-learner that uses a recurrent neural network for few-shot classification.

This repository contains:

  1. The Python code (BlackBox.py)
  2. The config file (config.json)
  3. The CS330 HW1 file
  4. This README

About

Dataset

The Omniglot data set is designed for developing more human-like learning algorithms. It contains 1623 different handwritten characters from 50 different alphabets, each drawn online via Amazon's Mechanical Turk by 20 different people. The data set is split into a background set of 30 alphabets and an evaluation set of 20 alphabets.

Model

A stacked two-layer LSTM model is employed. Inputs from the support set are concatenated with their one-hot encoded true labels, whereas inputs from the query set are concatenated with all zeros. The model is expected to predict the true labels of the query set. Shown below is a stacked LSTM model.
[Figure: stacked LSTM model]
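
The input construction described above can be sketched roughly as follows. This is a minimal illustration, not the code in BlackBox.py: the function name `build_sequence`, the argument names, and the flattened `[num_images, num_pixels]` layout are all assumptions made for the example.

```python
import numpy as np

def build_sequence(support_x, support_y, query_x, num_classes):
    """Concatenate images with label vectors to form one input sequence.

    Hypothetical helper for illustration only.
    support_x: [S, D] flattened support images
    support_y: [S] integer class labels in [0, num_classes)
    query_x:   [Q, D] flattened query images
    Returns an array of shape [S + Q, D + num_classes].
    """
    # Support images carry their true labels as one-hot vectors.
    one_hot = np.eye(num_classes, dtype=support_x.dtype)[support_y]
    support = np.concatenate([support_x, one_hot], axis=1)

    # Query images carry all zeros in the label slots, so the model
    # has to predict the labels itself.
    zeros = np.zeros((query_x.shape[0], num_classes), dtype=query_x.dtype)
    query = np.concatenate([query_x, zeros], axis=1)

    # Support items come first, query items last.
    return np.concatenate([support, query], axis=0)
```

Placing the query items at the end of the sequence means the recurrent state has already seen every labelled support example by the time a prediction is required.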

More information on the training procedure can be found in HW1 of CS330. The hyper-parameters can be changed in the config file.

Update 1 (05-08-2021): Added support for a bidirectional LSTM. Set 'bi_dir' to "true" in the config file to enable the BiLSTM.
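
As a rough sketch of how the 'bi_dir' switch might translate into LSTM settings, the snippet below turns a config dictionary into keyword arguments that `torch.nn.LSTM` accepts. Only 'bi_dir' comes from this README; the `hidden_size` key, its default of 128, the helper name `lstm_kwargs`, and the input size of 784 + 5 (28x28 images, 5 classes) are assumptions for illustration. The helper accepts both a JSON boolean and the string "true", since the exact type used in config.json is not stated here.

```python
import json

def lstm_kwargs(config, input_size):
    """Build keyword arguments for a two-layer LSTM (hypothetical helper)."""
    bi_dir = config.get("bi_dir", False)
    if isinstance(bi_dir, str):
        # Accept "true"/"false" written as strings in the config file.
        bi_dir = bi_dir.strip().lower() == "true"

    kwargs = {
        "input_size": input_size,
        "hidden_size": config.get("hidden_size", 128),  # assumed key
        "num_layers": 2,
        "batch_first": True,
        "bidirectional": bool(bi_dir),
    }
    # A bidirectional LSTM concatenates forward and backward states,
    # so its output has twice as many features per time step.
    out_features = kwargs["hidden_size"] * (2 if kwargs["bidirectional"] else 1)
    return kwargs, out_features

# Example config contents (assumed, not copied from config.json).
config = json.loads('{"bi_dir": "true", "hidden_size": 128}')
kwargs, out_features = lstm_kwargs(config, input_size=784 + 5)
```

The resulting dictionary could be passed as `torch.nn.LSTM(**kwargs)`; any layer that consumes the LSTM output then needs `out_features` inputs rather than `hidden_size`.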

To Run

Download the Omniglot data set and save the downloaded folders in a folder titled 'omniglot'. Save the Python code and config file in the same directory as the 'omniglot' folder.


BlackBox.py
config.json
omniglot
│___  images_background
│___  images_evaluation    
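
Before running, it may help to confirm that the layout above is in place. The following is a small optional check, not part of the repository; the function name `missing_omniglot_dirs` is made up for this example.

```python
from pathlib import Path

def missing_omniglot_dirs(root="omniglot"):
    """Return the expected sub-folders that are absent under `root`."""
    expected = ["images_background", "images_evaluation"]
    return [name for name in expected if not (Path(root) / name).is_dir()]

if __name__ == "__main__":
    missing = missing_omniglot_dirs()
    if missing:
        print("Missing folders under 'omniglot':", ", ".join(missing))
    else:
        print("Omniglot folder layout looks complete.")
```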

Install the following libraries to run the code

torch
numpy
glob (part of the Python standard library, no installation needed)
PIL (installed via the Pillow package)
matplotlib

Run BlackBox.py

python3 BlackBox.py

References

  1. This work is inspired by Stanford's CS 330: Deep Multi-Task and Meta Learning
  2. A similar implementation can be found here
  3. An awesome blog on Meta-Learning
  4. Stanford's lecture series on CS 330: Deep Multi-Task and Meta Learning on YouTube