
Pytorch Lightning MAML Implementation

Built with PyTorch Lightning, using Hydra for configuration, Weights & Biases (wandb) for experiment logging, and black for code style.

This repository is a reimplementation of the MAML (Model-Agnostic Meta-Learning) algorithm. Differentiable optimizers are handled by the Higher library, and NN-template is used for structuring the project. The default settings train on the Omniglot (5-way 5-shot) problem, and the code can easily be extended to other few-shot datasets thanks to the Torchmeta library.

Quickstart

On Local Machine

  1. Download and install dependencies
git clone https://github.com/rcmalli/lightning-maml.git
cd ./lightning-maml/
pip install -r requirements.txt
  2. Create a .env file containing the information below, using your own Wandb.ai account to track experiments. You can start from the .env.template file.
export DATASET_PATH="/your/project/root/data/"
export WANDB_ENTITY="USERNAME"
export WANDB_API_KEY="KEY"
  3. Run the experiment
python3 src/run.py train.pl_trainer.gpus=1

On Google Colab

The Quickstart can also be run in the browser via the Google Colab notebook linked from the repository.

Results

Omniglot (5-way 5-shot)

Few-shot learning on this dataset is an easy task for MAML to learn, and even to overfit.

| Algorithm | Model | inner_steps | Meta-train inner accuracy | Meta-train outer accuracy | Meta-validation inner accuracy | Meta-validation outer accuracy |
|-----------|-------|-------------|---------------------------|---------------------------|--------------------------------|--------------------------------|
| MAML | OmniConv | 1 | 0.992 | 0.992 | 0.98 | 0.98 |
| MAML | OmniConv | 5 | 1.0 | 1.0 | 1.0 | 1.0 |

Customization

Inside the 'conf' folder, you can change all of the settings depending on your problem or dataset. The default parameters are set for the Omniglot dataset. Here are some examples of customization:

Debug on local machine without GPU

python3 src/run.py train.pl_trainer.gpus=0 train.pl_trainer.fast_dev_run=true

Running with more inner steps and more epochs

python3 src/run.py train.pl_trainer.gpus=1  train.pl_trainer.max_epochs=1000 \
data.datamodule.num_inner_steps=5

Running a sweep of multiple runs

python3 src/run.py train.pl_trainer.gpus=1 data.datamodule.num_inner_steps=5,10,20 -m

Using different dataset from Torchmeta

If you want to try a different dataset (e.g. MiniImageNet), you can copy the default.yaml file inside conf/data to miniimagenet.yaml and edit these lines:

datamodule:
  _target_: pl.datamodule.MetaDataModule

  datasets:
    train:
      _target_: torchmeta.datasets.MiniImagenet
      root: ${env:DATASET_PATH}
      meta_train: True
      download: True

    val:
      _target_: torchmeta.datasets.MiniImagenet
      root: ${env:DATASET_PATH}
      meta_val: True
      download: True

    test:
      _target_: torchmeta.datasets.MiniImagenet
      root: ${env:DATASET_PATH}
      meta_test: True
      download: True

# you may also need to update the data augmentation and preprocessing steps (see the sketch below)!

Run the experiment as follows:

python3 src/run.py data=miniimagenet
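
The comment in the config above matters in practice: MiniImageNet images are 84x84 RGB rather than Omniglot's 28x28 grayscale, so the transforms have to change. As a rough sketch of what the preprocessing could look like when instantiating the dataset directly with Torchmeta (the resize value and the shot/query counts here are illustrative assumptions, not the repository's defaults):

from torchvision.transforms import Compose, Resize, ToTensor
from torchmeta.datasets import MiniImagenet
from torchmeta.transforms import Categorical, ClassSplitter

# 5-way episodes with 5 support and 15 query examples per class (illustrative values)
dataset = MiniImagenet(
    root="data/",
    num_classes_per_task=5,
    meta_train=True,
    download=True,
    transform=Compose([Resize(84), ToTensor()]),
    target_transform=Categorical(5),  # map task labels to 0..4
    dataset_transform=ClassSplitter(
        shuffle=True, num_train_per_class=5, num_test_per_class=15
    ),
)

In this project the equivalent settings would go through the hydra config shown above rather than direct instantiation; the snippet is only meant to show which pieces need adjusting.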

Implementing a different meta learning algorithm

If you plan to implement a new variant of the MAML algorithm (for example, MAML++), you can start by extending the default lightning module and its step function, as sketched below.
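
As a minimal sketch of that idea (the import path and class names below are assumptions about the project layout, not the repository's actual code):

# hypothetical import path: the hydra configs target a "pl" package in src/,
# but the exact module and class names here are assumptions
from pl.model import MAMLModel

class MAMLPlusPlus(MAMLModel):
    """Sketch of a MAML++-style variant: reuse the data handling and optimizer
    wiring of the base module and override only the meta-step logic."""

    def step(self, batch, batch_idx, phase):
        # e.g. add per-step inner learning rates or multi-step loss weighting here
        return super().step(batch, batch_idx, phase)

With hydra, selecting the new class should then just be a matter of pointing the model _target_ in the config at it.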

Notes

There are a few modifications required to run a meta-learning algorithm using PyTorch Lightning as the high-level library:

  1. In supervised learning we have M mini-batches per epoch. In the meta-learning setting, however, a single meta-batch consists of N tasks. We have to set the dataloader length to 1; otherwise the dataloader will sample from the dataset indefinitely.

  2. Unlike the traditional test phase of supervised learning, we also need gradient computation during testing. Currently, PyTorch Lightning does not let you enable gradient computation through a setting, so you have to add a single line at the beginning of your test and validation steps:

     torch.set_grad_enabled(True)
  3. In the MAML algorithm, we have two different optimizers to train the model. The inner optimizer must be differentiable, and the outer optimizer should update the model using the weights produced by the inner iterations on the support set and the gradients computed on the query set. In PyTorch Lightning, optimizers and weight updates are normally handled automatically. To disable this behaviour, we have to set automatic_optimization=False and add the following lines to handle the backward computation manually (this point and point 2 are combined in the sketch after this list):

    self.manual_backward(outer_loss, outer_optimizer)
    outer_optimizer.step()
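
Putting notes 2 and 3 together, a training/validation step ends up looking roughly like the sketch below. It is illustrative only: attribute names such as self.model, self.inner_optimizer and self.num_inner_steps are assumptions about the module's internals, while higher.innerloop_ctx follows that library's documented API.

import torch
import torch.nn.functional as F
import higher
import pytorch_lightning as pl

class MAMLStepSketch(pl.LightningModule):
    # assumes: self.model is the classifier, self.inner_optimizer is a plain
    # torch.optim.SGD over self.model.parameters(), the outer optimizer comes
    # from configure_optimizers, and automatic_optimization is disabled (note 3)

    def training_step(self, batch, batch_idx):
        outer_optimizer = self.optimizers()
        support_x, support_y = batch["train"]   # support set: [num_tasks, shots*ways, ...]
        query_x, query_y = batch["test"]        # query set of the same tasks

        outer_loss = 0.0
        for task in range(support_x.size(0)):
            # higher yields a differentiable copy of the model and optimizer,
            # so the outer gradient can flow through the inner updates
            with higher.innerloop_ctx(
                self.model, self.inner_optimizer, copy_initial_weights=False
            ) as (fmodel, diffopt):
                for _ in range(self.num_inner_steps):
                    inner_loss = F.cross_entropy(fmodel(support_x[task]), support_y[task])
                    diffopt.step(inner_loss)
                outer_loss = outer_loss + F.cross_entropy(fmodel(query_x[task]), query_y[task])

        # manual optimization, exactly as in note 3
        outer_optimizer.zero_grad()
        self.manual_backward(outer_loss, outer_optimizer)
        outer_optimizer.step()
        return outer_loss

    def validation_step(self, batch, batch_idx):
        # note 2: gradients must be re-enabled by hand, because the inner-loop
        # adaptation still needs them at validation/test time
        torch.set_grad_enabled(True)
        # ... run the same inner loop as above, without the outer optimizer step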

References