
meta_update with a single task and meta-loss calculated with current weight?

hwijeen opened this issue · 4 comments

Hi Kate, thanks for the PyTorch code of MAML!

I have two questions about your implementation (in both cases I suspect a bug).

Line 10 of Algorithm 2 in the original paper indicates that meta_update is performed using every D'_i. To do so with your code, I think the function meta_update needs access to every sampled task, since each task holds its own D'_i in your implementation.

self.meta_update(task, grads)

However, it seems that you perform meta_update with a single task, so only the D'_i of that one task is used.

Line 10 also states that meta-loss is calculated with adapted parameters.

loss, out = forward_pass(self.net, in_, target)

You seem to have calculated the meta-loss with self.net, which I think holds the "original parameters" (\theta_i) instead of the adapted parameters.

Am I missing something?

Thanks for your interest in my repo!

  1. The gradients are accumulated across tasks in this line: https://github.com/katerakelly/pytorch-maml/blob/master/src/maml.py#L164
  2. The line you linked to is a bit of a hack that uses a hook to replace the grad fields with the grads from the adapted parameters. See here for actual computation of meta-gradients: https://github.com/katerakelly/pytorch-maml/blob/master/src/inner_loop.py#L47

Thank you for the quick reply!

You've referenced https://github.com/katerakelly/pytorch-maml/blob/master/src/maml.py#L164, in which you accumulate gradients across tasks (using data D). This is related to lines 6-7 in Algorithm 2 of the original paper.
However, my question was about using the D' of each task to perform the meta-update, which is in line 10!
I think your implementation uses one single D', instead of each task's D', when calculating the "meta-loss" across tasks.

The line I referenced is accumulating meta-gradients.
In Algorithm 2 in the MAML paper, lines 5-8 are implemented in inner_loop.py, and the gradients used in line 10 are also computed there. The actual update is applied in maml.py after these gradients have been accumulated across tasks.
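
For readers following the thread, here is a minimal sketch of the flow described above: each task computes its own meta-gradient from its own D'_i at the adapted parameters, and a single update is applied after those gradients are summed. It is a toy scalar model in plain Python with hand-derived gradients, not the code from pytorch-maml; every name in it (mse_grad, task_meta_grad, meta_step, alpha, beta, the example data) is hypothetical and chosen only for illustration.

```python
# Toy scalar MAML: the model is y_hat = w * x and the loss is mean squared error.
# All names are hypothetical; this is not the pytorch-maml implementation.

def mse_grad(w, data):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * x * (w * x - y) for x, y in data) / len(data)

def mse_hessian(data):
    """Second derivative of the MSE with respect to w (constant in w)."""
    return sum(2 * x * x for x, _ in data) / len(data)

def task_meta_grad(w, support, query, alpha):
    """Inner adaptation on D_i, then this task's meta-gradient from D'_i."""
    w_adapted = w - alpha * mse_grad(w, support)       # one inner step on D_i
    dw_adapted_dw = 1 - alpha * mse_hessian(support)   # chain rule through the inner step
    # Loss on D'_i is evaluated at the adapted weight, differentiated w.r.t. the original w.
    return mse_grad(w_adapted, query) * dw_adapted_dw

def meta_step(w, tasks, alpha=0.1, beta=0.05):
    """Sum per-task meta-gradients, then apply one update to the original w."""
    total = 0.0
    for support, query in tasks:
        total += task_meta_grad(w, support, query, alpha)
    return w - beta * total

# Each task is a (D_i, D'_i) pair of (x, y) samples.
tasks = [
    ([(1.0, 2.0)], [(2.0, 4.0)]),   # task with true slope 2
    ([(1.0, 3.0)], [(2.0, 6.0)]),   # task with true slope 3
]
w_new = meta_step(0.0, tasks)
```

In this sketch the update function never sees the tasks themselves, only the accumulated gradients, which is one way a meta-update can use every D'_i without receiving every task as an argument.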