This repository is a reimplementation of the MAML (Model-Agnostic Meta-Learning) algorithm. Differentiable optimizers are handled by the Higher library, and NN-template is used for structuring the project. The default settings are for training on the Omniglot (5-way, 5-shot) problem, and the project can easily be extended to other few-shot datasets thanks to the Torchmeta library.
## On Local Machine
- Download and install dependencies:

  ```bash
  git clone https://github.com/rcmalli/lightning-maml.git
  cd ./lightning-maml/
  pip install -r requirements.txt
  ```
- Create a `.env` file containing the info given below, using your own Wandb.ai account to track experiments (you can start from the provided `.env.template` file):

  ```bash
  export DATASET_PATH="/your/project/root/data/"
  export WANDB_ENTITY="USERNAME"
  export WANDB_API_KEY="KEY"
  ```
- Run the experiment:

  ```bash
  python3 src/run.py train.pl_trainer.gpus=1
  ```
## On Google Colab
## Results

Few-shot learning on this dataset is an easy task for the MAML algorithm to learn, and even to overfit.
| Algorithm | Model | inner_steps | Meta-train inner accuracy | Meta-train outer accuracy | Meta-val inner accuracy | Meta-val outer accuracy |
|---|---|---|---|---|---|---|
| MAML | OmniConv | 1 | 0.992 | 0.992 | 0.98 | 0.98 |
| MAML | OmniConv | 5 | 1.0 | 1.0 | 1.0 | 1.0 |
Inside the `conf` folder, you can change all the settings depending on your problem or dataset. The default parameters are set for the Omniglot dataset. Here are some examples of customization:
```bash
# Quick debug run on CPU (fast_dev_run runs a single batch through the pipeline)
python3 src/run.py train.pl_trainer.gpus=0 train.pl_trainer.fast_dev_run=true

# Train on GPU for up to 1000 epochs with 5 inner-loop steps
python3 src/run.py train.pl_trainer.gpus=1 train.pl_trainer.max_epochs=1000 \
  data.datamodule.num_inner_steps=5

# Hydra multirun (-m): sweep over several inner-step settings
python3 src/run.py train.pl_trainer.gpus=1 data.datamodule.num_inner_steps=5,10,20 -m
```
If you want to try a different dataset (e.g. MiniImageNet), you can copy the `default.yaml` file inside `conf/data` to `miniimagenet.yaml` and edit these lines:
```yaml
datamodule:
  _target_: pl.datamodule.MetaDataModule

  datasets:
    train:
      _target_: torchmeta.datasets.MiniImagenet
      root: ${env:DATASET_PATH}
      meta_train: True
      download: True

    val:
      _target_: torchmeta.datasets.MiniImagenet
      root: ${env:DATASET_PATH}
      meta_val: True
      download: True

    test:
      _target_: torchmeta.datasets.MiniImagenet
      root: ${env:DATASET_PATH}
      meta_test: True
      download: True

# you may also need to update the data augmentation and preprocessing steps!
```
Run the experiment as follows:

```bash
python3 src/run.py data=miniimagenet
```
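For context, Hydra resolves each `_target_` entry in these configs into a constructor call on the named class. The snippet below is a minimal, illustrative sketch of that mechanism (the literal `root` path stands in for the `${env:DATASET_PATH}` resolver to keep the example self-contained):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# A config block like the ones above; `_target_` names the class to build.
cfg = OmegaConf.create(
    """
    _target_: torchmeta.datasets.MiniImagenet
    root: /your/project/root/data/
    meta_train: True
    download: True
    """
)

# Equivalent to: torchmeta.datasets.MiniImagenet(root=..., meta_train=True, download=True)
train_dataset = instantiate(cfg)
```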
If you plan to implement a new variant of the MAML algorithm (for example, MAML++), you can start by extending the default lightning module and its step function, as sketched below.
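Here is a rough, hypothetical sketch of such an extension. The class name, import path, and `step()` signature are assumptions for illustration; check the actual module classes under `src/` and adapt accordingly:

```python
# Hypothetical sketch: the base class name, import path, and step()
# signature are assumptions, not this repository's exact API.
from src.pl_modules.model import MAMLModule


class MAMLPlusPlusModule(MAMLModule):
    """MAML++-style variant built on top of the default module."""

    def step(self, batch, batch_idx):
        # Reuse the default MAML inner/outer loop...
        outputs = super().step(batch, batch_idx)
        # ...then add variant-specific logic here, e.g. weighting the
        # query-set losses of every inner step (MAML++'s multi-step loss).
        return outputs
```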
There are a few modifications required to run a meta-learning algorithm with PyTorch Lightning as the high-level library (a combined sketch of all three follows the list):
- In supervised learning we have *M* mini-batches per epoch. In the meta-learning setting, however, a single meta-batch consists of *N* tasks. We have to set the dataloader length to 1; otherwise, the dataloader will sample from the dataset indefinitely.

- Unlike the traditional test phase of supervised learning, we also need gradient computation during the test phase. Currently, PyTorch Lightning does not allow you to enable gradient computation through a setting, so you have to add a single line at the beginning of your test and validation steps:

  ```python
  torch.set_grad_enabled(True)
  ```
- In the MAML algorithm we have two different optimizers to train the model. The inner optimizer must be differentiable, and the outer optimizer should update the model using both the weights produced by the inner iterations on the support set and the gradients computed on the query set. In PyTorch Lightning, optimizers and weight updates are normally handled automatically; to disable this behaviour we have to set `automatic_optimization=False` and add the following lines to handle the backward computation manually:

  ```python
  self.manual_backward(outer_loss, outer_optimizer)
  outer_optimizer.step()
  ```
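To make these three points concrete, here is a minimal sketch of how they fit together in a LightningModule. This is an illustrative assumption of the pattern, not this repository's actual implementation: the class names, the `outer_loss` placeholder, and the loader wrapper are all hypothetical.

```python
import torch
import pytorch_lightning as pl


class SingleMetaBatchLoader:
    """(1) Wrap an endless task sampler so one epoch is exactly one
    meta-batch: reporting a length of 1 stops Lightning after one batch."""

    def __init__(self, meta_loader):
        self.meta_loader = meta_loader

    def __len__(self):
        return 1

    def __iter__(self):
        yield next(iter(self.meta_loader))


class MAMLLitModule(pl.LightningModule):
    def __init__(self, model: torch.nn.Module, outer_lr: float = 1e-3):
        super().__init__()
        self.model = model
        self.outer_lr = outer_lr
        # (3) Take over optimization: Lightning no longer calls
        # backward()/step() on our behalf.
        self.automatic_optimization = False

    def outer_loss(self, batch):
        # Placeholder for the MAML loop: adapt on the support set with a
        # differentiable inner optimizer (e.g. via the higher library),
        # then evaluate the adapted weights on the query set.
        raise NotImplementedError

    def training_step(self, batch, batch_idx):
        outer_optimizer = self.optimizers()
        outer_optimizer.zero_grad()
        loss = self.outer_loss(batch)
        # (3) Manual backward pass and outer update; note that older
        # Lightning versions expect manual_backward(loss, optimizer).
        self.manual_backward(loss)
        outer_optimizer.step()
        return loss

    def validation_step(self, batch, batch_idx):
        # (2) Lightning disables gradients during evaluation, but the
        # inner loop still needs them, so re-enable them explicitly.
        torch.set_grad_enabled(True)
        return self.outer_loss(batch)

    def test_step(self, batch, batch_idx):
        torch.set_grad_enabled(True)
        return self.outer_loss(batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.outer_lr)
```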