MAML-Pytorch

This is a PyTorch implementation of Model-Agnostic Meta-Learning (MAML).

For Official Tensorflow Implementation, please visit HERE.

For a good tutorial of meta learning (with tensorflow), see HERE.

requirements

  1. python: 3.x
  2. pytorch 1.0+

files & dirs

  1. encoder.py

    模型定义文件,需要手动定义支持参数传递的forward()函数,更换模型时这个过程得重新进行一遍,此外须手动设置BN的running statistics,代码迁移性差

  2. encoder_general.py

    相对普适的模型定义文件,迁移性稍好,更换模型时不用从零开始定义支持参数传递的forward()函数,目前仅适于VGG类串行网络结构

  3. MAML_1st.py

    简洁的MAML一阶近似版实现(基于pytorch内建函数)

  4. MAML_2nd.py

    相对普适的MAML二阶版实现(基于手动更新网络参数); 通过设置参数,可退化为MAML_1st.py中的一阶版近似;

    本文最初目的是实现普适版的MAML二阶更新代码,普适是指对于任意模型结构,在不用修改原模型定义文件,或者仅需要添加一个类函数的前提下,即可方便地将MAML二阶更新用于该模型;

    经过若干尝试,目前仅实现对于VGG类串行网络结构的普适代码,对于ResNet这类含分支的网络结构,需要手动重构其build block [一个不太简洁的示例],使支持带参数的前向传播 (pytorch框架可能不支持本文对模型普适的需求,欢迎讨论交流xiongkai4925@cvte.com / bearkai1992@qq.com

  5. reptile.py

    简洁的reptile实现,reptile的伪代码如下:

     Init param W
     for iteration 1,2,3,... do
         Randomly sample a task T
         Perform k>1 steps of SGD on different minibatches of task T, starting with W, resulting in W1
         Update: W := W + lr*(W1-W)
     end for
     Return W
    
  6. test_grad_20200408.py

    编写的一些tiny_test,帮助理解pytorch的backward()grad(), .data和.detach(), retain_graphcreate_graph等概念,帮助验证MAML一阶/二阶梯度更新中的一些操作是否符合预期。

    若只需要实现MAML梯度更新,可以忽略该文件。

  7. test_param_20191016.md

    编写的一些tiny_test,帮助理解pytorch的参数更新机制,并排查实现深度随机策略时的bug。

    若只需要实现MAML梯度更新,可以忽略该文件。对pytorch框架感兴趣的同学,建议阅读文件最后的总结内容。

  8. discard_scripts

    一些失败的尝试,涉及hook_function, name_modules()等概念;若只需要实现MAML梯度更新,可以忽略该文件。

Note

  1. 上述实现普适版MAML二阶,更多是为了研究上的完备性,顺便加深对深度框架的理解;
  2. 实际上,对于小模型,参考encoder.py重新定义模型的forward(),工作量较小,对于大模型,比如ResNet50,需要重构其build block,使能进行指定参数的前向传播,此外大模型可能不太用到二阶梯度更新,因为Hessian矩阵太占内存,计算速度太慢。

To do

Reptile [1, 2] and training tricks of MAML [3, 4]