Pytorch-LEO: A PyTorch Implementation of Meta-Learning with Latent Embedding Optimization (LEO)

Running the code

Prerequisites

  • torch==1.4.0
  • PyYAML==3.13

Getting the data

We borrow the pretrained embeddings from the deepmind/leo repo.
To download and unpack them, run:

$ wget http://storage.googleapis.com/leo-embeddings/embeddings.zip
$ unzip embeddings.zip
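
If wget or unzip is not available, the same two steps can be done from Python with only the standard library. This is a minimal sketch: the helper name fetch_embeddings and its dest_dir argument are made up for illustration and are not part of this repo.

```python
import os
import urllib.request
import zipfile

# URL taken from the wget command above.
EMBEDDINGS_URL = "http://storage.googleapis.com/leo-embeddings/embeddings.zip"


def fetch_embeddings(url=EMBEDDINGS_URL, dest_dir="."):
    """Download a zip archive and extract it into dest_dir (hypothetical helper)."""
    os.makedirs(dest_dir, exist_ok=True)
    archive_path = os.path.join(dest_dir, "embeddings.zip")
    # Reuse an archive that was already downloaded instead of fetching it again.
    if not os.path.exists(archive_path):
        urllib.request.urlretrieve(url, archive_path)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(dest_dir)
        # Return the archive's entry names so the caller can see what was unpacked.
        return archive.namelist()
```

Calling fetch_embeddings() with no arguments should leave the extracted files in the current directory, which can then be passed to -embedding_dir.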

Run Training

python3 main.py -train \
                -verbose \
                -N 5 \
                -K 1 \
                -embedding_dir "$EMBEDDING_DIR" \
                -dataset miniImageNet \
                -exp_name toy-example \
                -save_checkpoint

where

  • -N and -K set N-way K-shot training,
  • -exp_name names the run so you can keep track of your experiments,
  • -save_checkpoint saves the model for later testing.

For the full list of arguments, see main.py.
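
To make the N-way K-shot setting concrete: each training episode draws N classes and K labelled support examples per class, plus some held-out query examples. The sketch below is a generic illustration, not the sampling code used in this repo. The function name, the data_by_class layout, and the default of 15 query examples per class are all assumptions.

```python
import random


def sample_episode(data_by_class, n_way, k_shot, n_query=15, rng=None):
    """Sample one N-way K-shot episode from {class_name: [items]} (illustrative)."""
    rng = rng or random.Random()
    # Pick N distinct classes for this episode.
    classes = rng.sample(sorted(data_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        # Draw K support items and n_query query items without overlap.
        items = rng.sample(data_by_class[cls], k_shot + n_query)
        support += [(x, label) for x in items[:k_shot]]
        query += [(x, label) for x in items[k_shot:]]
    return support, query
```

With -N 5 -K 1, an episode built this way would hold 5 support examples (one per class), and the model is adapted on those before being scored on the query examples.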

Run Testing

python3 main.py -test \
                -N 5 \
                -K 1 \
                -embedding_dir "$EMBEDDING_DIR" \
                -dataset miniImageNet \
                -verbose \
                -load "$model_path"

The test results are printed to the console.

Monitor Training

This project comes with Comet.ml support. To enable monitoring, set COMET_PROJECT_NAME and COMET_WORKSPACE in config.yml. To disable logging, add -disable_comet as an argument.

*If your Comet API key is not saved in .comet.config, you will have to specify it on line 147 of solver.py.
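
Before starting a long run, it can help to check whether an API key is discoverable at all. The sketch below rests on assumptions not stated in this README: that .comet.config lives in the home directory, that it is an INI file with a [comet] section and an api_key entry, and that a COMET_API_KEY environment variable takes precedence. The helper name find_comet_api_key is hypothetical.

```python
import configparser
import os


def find_comet_api_key(config_path="~/.comet.config"):
    """Return the Comet API key from the environment or an INI file, else None."""
    # Assumption: an environment variable overrides the config file.
    key = os.environ.get("COMET_API_KEY")
    if key:
        return key
    parser = configparser.ConfigParser()
    # ConfigParser.read silently skips files that do not exist.
    parser.read(os.path.expanduser(config_path))
    # Assumption: the key is stored as api_key under a [comet] section.
    return parser.get("comet", "api_key", fallback=None)
```

If this returns None, either pass -disable_comet or set the key in solver.py as described above.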

Hyperparameters

You can modify the hyperparameters in config.yml, where detailed descriptions are also provided.
The hyperparameters that gave the best results with this code are as follows:

| Hyperparameters | miniImageNet 1-shot | miniImageNet 5-shot | tieredImageNet 1-shot | tieredImageNet 5-shot |
| --- | --- | --- | --- | --- |
| outer_lr | 0.0005 | 0.0006 | 0.0006 | 0.0006 |
| l2_penalty_weight | 0.0001 | 8.5e-6 | 3.6e-10 | 3.6e-10 |
| orthogonality_penalty_weight | 303.0 | 0.00152 | 0.188 | 0.188 |
| dropout | 0.3 | 0.3 | 0.3 | 0.3 |
| kl_weight | 0 | 0.001 | 0.001 | 0.001 |
| encoder_penalty_weight | 1e-9 | 2.66e-7 | 5.7e-6 | 5.7e-6 |
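
For scripted sweeps, the same values can be kept as a Python mapping keyed by dataset and shot count. The numbers below are copied from the table above. The dict layout and the helper name best_hyperparameters are hypothetical, and config.yml may organize these keys differently.

```python
# Values copied from the hyperparameter table; keys are (dataset, K-shot).
BEST_HYPERPARAMETERS = {
    ("miniImageNet", 1): {
        "outer_lr": 0.0005, "l2_penalty_weight": 0.0001,
        "orthogonality_penalty_weight": 303.0, "dropout": 0.3,
        "kl_weight": 0, "encoder_penalty_weight": 1e-9,
    },
    ("miniImageNet", 5): {
        "outer_lr": 0.0006, "l2_penalty_weight": 8.5e-6,
        "orthogonality_penalty_weight": 0.00152, "dropout": 0.3,
        "kl_weight": 0.001, "encoder_penalty_weight": 2.66e-7,
    },
    ("tieredImageNet", 1): {
        "outer_lr": 0.0006, "l2_penalty_weight": 3.6e-10,
        "orthogonality_penalty_weight": 0.188, "dropout": 0.3,
        "kl_weight": 0.001, "encoder_penalty_weight": 5.7e-6,
    },
    ("tieredImageNet", 5): {
        "outer_lr": 0.0006, "l2_penalty_weight": 3.6e-10,
        "orthogonality_penalty_weight": 0.188, "dropout": 0.3,
        "kl_weight": 0.001, "encoder_penalty_weight": 5.7e-6,
    },
}


def best_hyperparameters(dataset, k_shot):
    """Return a copy of the tabulated settings for one dataset and shot count."""
    return dict(BEST_HYPERPARAMETERS[(dataset, k_shot)])
```

For example, best_hyperparameters("miniImageNet", 1) returns the first column of the table as a dict that can be edited without touching the shared mapping.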

Result

| Implementation | miniImageNet 1-shot | miniImageNet 5-shot | tieredImageNet 1-shot | tieredImageNet 5-shot |
| --- | --- | --- | --- | --- |
| LEO Paper | 61.76 ± 0.08% | 77.59 ± 0.12% | 66.33 ± 0.05% | 81.44 ± 0.09% |
| this code | 59.46 ± 0.08% | 76.01 ± 0.09% | 66.62 ± 0.07% | 81.72 ± 0.09% |

*Our results may not be directly comparable to the paper's: in the paper the model is trained on both the training set and the validation set, while our model is trained on the training set only and validated on the validation set.
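
This README does not say how the ± figures are computed. Assuming they are 95% confidence intervals over per-episode test accuracies under a normal approximation, which is a common convention in few-shot learning, a number of that form could be produced as sketched below. The function name mean_and_ci95 is hypothetical and this is not taken from the repo's evaluation code.

```python
import math
import statistics


def mean_and_ci95(accuracies):
    """Return (mean, half-width of a normal-approximation 95% interval)."""
    mean = statistics.fmean(accuracies)
    # Standard error of the mean, using the sample standard deviation.
    stderr = statistics.stdev(accuracies) / math.sqrt(len(accuracies))
    # 1.96 is the two-sided 95% quantile of the standard normal distribution.
    return mean, 1.96 * stderr
```

Under this assumption the half-width shrinks with the square root of the number of test episodes, so intervals are only comparable between runs evaluated on a similar number of episodes.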

Note: This project is licensed under the terms of the MIT license.