This repository contains PyTorch implementations of several meta-learning algorithms for the few-shot learning problem, including:
- Model-Agnostic Meta-Learning (MAML)
- Probabilistic Model-Agnostic Meta-Learning (PLATIPUS)
- Prototypical Networks (protonet)
- Bayesian Model-Agnostic Meta-Learning (BMAML) (without Chaser loss)
- Amortized Bayesian Meta-Learning (ABML)
- Uncertainty in Model-Agnostic Meta-Learning using Variational Inference (VAMPIRE)
Requirements:
- PyTorch 1.8.1 or above (PyTorch 1.8 introduced "lazy" modules such as `torch.nn.LazyLinear`, which infer their input size at the first forward pass, similar to the `Dense` layer in TensorFlow)
- `higher`
What does "functional" mean? It is similar to the module `torch.nn.functional`, where the parameters are passed explicitly rather than held implicitly inside a module such as `torch.nn.Sequential()`. For example:
```python
# conventional form with implicitly-handled parameters
y = net(x)  # parameters are handled by PyTorch implicitly
# functional form
y = functional_net(x, params=theta)  # theta is the parameter
```
With plain PyTorch, one needs to manually implement the "functional" form of every component of the model of interest via `torch.nn.functional`. This is, however, inconvenient whenever the network architecture changes.
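To make the inconvenience concrete, here is a minimal sketch of a hand-written functional form for a small `Linear -> ReLU -> Linear` network. `functional_mlp` is a hypothetical helper, not part of this repository. Every layer must be re-expressed with `torch.nn.functional`, and the parameter list must follow the layer order exactly:

```python
import torch
import torch.nn.functional as F


def functional_mlp(x, params):
    """Hand-written functional form of Linear(4, 8) -> ReLU -> Linear(8, 3).

    params = [w1, b1, w2, b2] must match the layer order of `net` exactly.
    """
    w1, b1, w2, b2 = params
    h = F.relu(F.linear(x, w1, b1))
    return F.linear(h, w2, b2)


net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 3))
theta = list(net.parameters())  # [w1, b1, w2, b2]
x = torch.randn(5, 4)
y = functional_mlp(x, params=theta)  # same output as net(x), parameters passed explicitly
```

Adding or removing a layer in `net` would require rewriting `functional_mlp` by hand, which is exactly the burden that `higher` removes.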
Fortunately, Facebook Research has developed `higher`, a library that can easily convert any "conventional" neural network into its "functional" form so that the parameters can be handled explicitly. For example:
```python
import torch
import torchvision
import higher

# define a network
resnet18 = torchvision.models.resnet18(pretrained=False)
# get its parameters
params = list(resnet18.parameters())
# convert the network to its functional form
f_resnet18 = higher.patch.make_functional(resnet18)
# forward with the functional form, handling the parameters explicitly
x1 = torch.randn(2, 3, 224, 224)
y1 = f_resnet18(x1, params=params)
# update the parameters (update_parameter is a placeholder for any update rule, e.g. a gradient step)
new_params = update_parameter(params)
# forward on different data with the new parameters
x2 = torch.randn(2, 3, 224, 224)
y2 = f_resnet18(x2, params=new_params)
```
Hence, we only need to load or specify the "conventional" model written in PyTorch, without manually re-implementing its "functional" form. A few common models are implemented in `CommonModels.py`.
Although `higher` provides convenient APIs to track gradients, it does not support the "first-order" approximation, which results in higher memory usage and longer training time. I have created a workaround that enables the first-order approximation; it is controlled by setting `--first-order=True` when running the code.
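The first-order approximation itself can be illustrated in plain PyTorch. The sketch below is not the repository's workaround for `higher`. It uses a hypothetical linear-regression base-net, and `predict`, `adapt` and `meta_gradient` are illustrative names. The sketch shows that disabling `create_graph` in the inner step drops the second-order terms, so the meta-gradient equals the query-loss gradient taken at the adapted parameters:

```python
import torch
import torch.nn.functional as F


def predict(params, x):
    # hypothetical functional base-net: linear regression
    w, b = params
    return x @ w + b


def adapt(params, x, y, inner_lr=0.1, first_order=False):
    loss = F.mse_loss(predict(params, x), y)
    # first_order=True: the inner gradients are computed without a graph, so they act
    # as constants and the Hessian terms vanish from the meta-gradient
    grads = torch.autograd.grad(loss, params, create_graph=not first_order)
    return [p - inner_lr * g for p, g in zip(params, grads)]


def meta_gradient(params, support, query, first_order=False):
    adapted = adapt(params, *support, first_order=first_order)
    query_loss = F.mse_loss(predict(adapted, query[0]), query[1])
    return torch.autograd.grad(query_loss, params)
```

Because the inner-step gradients carry no graph when `first_order=True`, the backward pass does not need to store or differentiate through them, which is where the memory and time savings come from.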
Most of the implementation is based on the abstract base class in `MLBaseClass.py`, and each algorithm is written in a separate class. The main program is in `main.py`. PLATIPUS is slightly different because the algorithm mixes the `training` and `validation` subsets, so it is implemented in a separate file.
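The implementation is mainly in the abstract base class `MLBaseClass.py`, with some auxiliary classes and functions in `_utils.py`. The operation principle of the implementation can be divided into 3 steps: initializing the hyper-net and the base-net, adapting to each task, and predicting on the validation subset.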
Recall the graphical model of meta-learning:
θ → w → y ← x,
where θ denotes the parameter of the hyper-net, w is the base-net parameter, and (x, y) is the data.
The implementation is designed to follow this generative process, where the hyper-net generates the parameters of the base-net. It can be summarized in the following pseudo-code:
```python
# initialization
base_net = ResNet18()  # base-net
# convert the conventional base-net to its functional form
f_base_net = torch_to_functional_module(module=base_net)
# make the hyper-net from the base-net
hyper_net = hyper_net_cls(base_net=base_net)
# the hyper-net generates the parameters of the base-net
base_net_params = hyper_net.forward()
# make prediction
y = f_base_net(x, params=base_net_params)
```
- MAML: the hyper-net is the initialization of the base-net, so the generative process is the identity operator. Hence, `hyper_net_cls` is the class `IdentityNet` in `_utils.py`.
- ABML and VAMPIRE: the base-net parameter is a sample drawn from a diagonal Gaussian distribution parameterized by the meta-parameter, and the hyper-net is designed to simulate this sampling process. In this case, `hyper_net_cls` is the class `NormalVariationalNet` in `_utils.py`.
- Prototypical network differs from the above algorithms due to its metric-learning nature. In the implementation, only one network is used as `hyper_net`, while `base_net` is set to `None`.
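The two kinds of hyper-net described above can be sketched as follows. This is a simplified illustration, not the code of `IdentityNet` or `NormalVariationalNet` in `_utils.py`. `IdentityHyperNet`, `GaussianHyperNet`, `init_base_params` and the single-layer `base_net` are hypothetical names. The Gaussian version assumes a log-standard-deviation parameterization sampled with the reparameterization trick:

```python
import torch
import torch.nn.functional as F


def base_net(x, params):
    # hypothetical functional base-net: a single linear layer
    w, b = params
    return F.linear(x, w, b)


def init_base_params(in_dim=4, out_dim=3):
    layer = torch.nn.Linear(in_dim, out_dim)
    return [layer.weight.detach().clone(), layer.bias.detach().clone()]


class IdentityHyperNet(torch.nn.Module):
    """MAML-like: the meta-parameter theta *is* the base-net parameter."""

    def __init__(self, base_params):
        super().__init__()
        self.theta = torch.nn.ParameterList([torch.nn.Parameter(p) for p in base_params])

    def forward(self):
        return list(self.theta)


class GaussianHyperNet(torch.nn.Module):
    """ABML/VAMPIRE-like: w ~ N(mean, diag(exp(log_std)^2)), via reparameterization."""

    def __init__(self, base_params, init_log_std=-4.0):
        super().__init__()
        self.mean = torch.nn.ParameterList([torch.nn.Parameter(p) for p in base_params])
        self.log_std = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.full_like(p, init_log_std)) for p in base_params]
        )

    def forward(self):
        # one sample of the base-net parameters; gradients flow to mean and log_std
        return [m + torch.exp(s) * torch.randn_like(m) for m, s in zip(self.mean, self.log_std)]


x = torch.randn(5, 4)
for hyper_net in (IdentityHyperNet(init_base_params()), GaussianHyperNet(init_base_params())):
    logits = base_net(x, params=hyper_net())  # logits has shape (5, 3)
```

Because both hyper-nets expose the same `forward()` interface returning a list of base-net parameters, the rest of the training loop does not need to know which algorithm is being run.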
Why such a complicated implementation? It allows many meta-learning algorithms to share common procedures via the abstract base class `MLBaseClass`. If this is not clear to you, please open an issue or send me an email. I am happy to discuss how to further improve the readability of the code.
For adapting the hyper-net to each task, there are 2 sub-functions: one for MAML-like algorithms and one for protonet.
For MAML-like algorithms, the idea is simple:
- Generate the parameter(s) of the base-net from the hyper-net
- Use the generated base-net parameter(s) to calculate the loss on the training (also known as support) data
- Minimize the loss w.r.t. the parameter of the hyper-net
- Return the (task-specific) hyper-net, assigned to `f_hyper_net`, for that particular task
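For a MAML-like algorithm with an identity hyper-net, the four steps above can be sketched in plain PyTorch. The sketch assumes a hypothetical single-layer classifier `base_net`, plain gradient descent in the inner loop, and illustrative defaults (`num_steps=5`, `inner_lr=0.1`). The repository itself relies on `higher` to build the task-specific `f_hyper_net`:

```python
import torch
import torch.nn.functional as F


def base_net(x, params):
    # hypothetical functional base-net: a single linear classifier
    w, b = params
    return F.linear(x, w, b)


def adapt_to_task(hyper_net_params, x_support, y_support, num_steps=5, inner_lr=0.1):
    """MAML-like adaptation with an identity hyper-net (theta is the base-net init)."""
    # identity hyper-net: the generated base-net parameters start at theta
    task_params = list(hyper_net_params)
    for _ in range(num_steps):
        # loss of the generated base-net on the support (training) data
        loss = F.cross_entropy(base_net(x_support, task_params), y_support)
        # gradient step; create_graph=True keeps second-order terms for the meta-update
        grads = torch.autograd.grad(loss, task_params, create_graph=True)
        task_params = [p - inner_lr * g for p, g in zip(task_params, grads)]
    # task-specific parameters, playing the role of f_hyper_net
    return task_params
```

Since `theta` is never modified in place, the validation loss computed with the returned parameters can be back-propagated to `theta` for the meta-update.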
For protonet, the adaptation calculates and returns the prototypes in the embedding space.
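The prototype computation amounts to averaging the support embeddings of each class. Below is a minimal sketch, assuming the embeddings have already been produced by the embedding network and that labels are integers in `[0, num_classes)`. `compute_prototypes` is a hypothetical name:

```python
import torch
import torch.nn.functional as F


def compute_prototypes(embeddings, labels, num_classes):
    """Mean support embedding of each class, shape (num_classes, embedding_dim)."""
    one_hot = F.one_hot(labels, num_classes).to(embeddings.dtype)  # (N, C)
    counts = one_hot.sum(dim=0, keepdim=True).T  # (C, 1): number of shots per class
    return one_hot.T @ embeddings / counts  # per-class sum divided by per-class count
```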
For prediction, the task-specific hyper-net (`f_hyper_net`) in the case of MAML-like algorithms, or the prototypes in the case of prototypical networks, are used to predict the labels of the data in the validation (also known as query) subset.
- In training, the predicted labels are used to calculate the loss, and the parameter of the hyper-net is updated to minimize that loss.
- In testing, the predicted labels are used to compute the prediction accuracy.
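For prototypical networks, prediction and evaluation on the validation subset can be sketched as below. The sketch assumes, as in the original Prototypical Networks formulation, that the logits are negative squared Euclidean distances to the prototypes. The function names are hypothetical, and the repository may use a different distance:

```python
import torch
import torch.nn.functional as F


def prototype_logits(query_embeddings, prototypes):
    # negative squared Euclidean distance to each prototype: (num_query, num_classes)
    return -torch.cdist(query_embeddings, prototypes) ** 2


def loss_and_accuracy(logits, labels):
    loss = F.cross_entropy(logits, labels)  # training: minimized w.r.t. the hyper-net
    accuracy = (logits.argmax(dim=1) == labels).float().mean()  # testing: reported metric
    return loss, accuracy
```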
Note that ABML is slightly different: it also includes the loss of the task-specific hyper-net on the training subset, and it places a prior on the parameter of the hyper-net. These are implemented in the methods `loss_extra()` and `loss_prior()`, respectively.
Regression has not been implemented yet.
Omniglot and mini-ImageNet are the two datasets considered. They are organized following the folder structure expected by `torchvision.datasets.ImageFolder`:
```
Dataset
|__ alphabet1_character1 (or class1)
|__ alphabet2_character2 (or class2)
...
|__ alphabetn_characterm (or classz)
```
You can modify the `transformations` in `main.py` to fit your needs regarding image size or image normalization.
The implementation relies on `torch.utils.data.DataLoader` with a customized sampler in `EpisodeSampler.py` to generate data for each task. It also supports loading multiple datasets by appending `--datasource dataset_name --datasource another_dataset_name` to the input arguments.
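Episode sampling can be illustrated with a simple N-way batch sampler for `DataLoader`. This is a hypothetical sketch, and `SimpleEpisodeSampler` is not the class in `EpisodeSampler.py`. It assumes that every class has at least `num_samples_per_class` examples, where `num_samples_per_class` covers both the support (training) and query (validation) samples of a task:

```python
import random
from collections import defaultdict

from torch.utils.data import Sampler


class SimpleEpisodeSampler(Sampler):
    """Hypothetical N-way sampler: yields one list of dataset indices per episode."""

    def __init__(self, targets, num_ways, num_samples_per_class, num_episodes, seed=None):
        self.by_class = defaultdict(list)  # class label -> list of dataset indices
        for idx, label in enumerate(targets):
            self.by_class[label].append(idx)
        self.num_ways = num_ways
        self.k = num_samples_per_class
        self.num_episodes = num_episodes
        self.rng = random.Random(seed)

    def __len__(self):
        return self.num_episodes

    def __iter__(self):
        for _ in range(self.num_episodes):
            # pick N distinct classes, then k distinct samples from each class
            classes = self.rng.sample(sorted(self.by_class), self.num_ways)
            yield [i for c in classes for i in self.rng.sample(self.by_class[c], self.k)]
```

With an `ImageFolder` dataset, it could be used as `DataLoader(dataset, batch_sampler=SimpleEpisodeSampler(dataset.targets, num_ways=5, num_samples_per_class=16, num_episodes=100))`, so that each batch holds the images of one task.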
If the original structure of Omniglot (train -> alphabets -> characters) is desired, you may need to append the list of all alphabet names to `config['datasource']`.
To run, copy and paste the command at the beginning of each algorithm script and change the configurable parameters (if needed).
To test, specify which saved model to use via the variable `resume_epoch` and replace `--train` with `--test` at the end of the commands found at the top of `main.py`.
TensorBoard is also integrated into the implementation, so you can monitor the training in your favourite browser. Launch it with:
```bash
tensorboard --logdir=<your destination folder>
```
Then open the browser and view the training progress at:
http://localhost:6006/
If you only need MAML and find my implementation complicated, torch-meta is a repository worth looking at. The main difference is that my implementation extends to other algorithms, such as VAMPIRE and ABML.
If you find this repository useful, please give it a ⭐ to motivate my work.
In addition, please consider giving a ⭐ to the `higher` repository developed by Facebook. Without it, we would still suffer from the arduous re-implementation of the "functional" form of every model.