meta_update with a single task, and meta-loss calculated with the current weights?
hwijeen opened this issue · 4 comments
Hi Kate, thanks for the PyTorch implementation of MAML!
I have two questions about your implementation (in which I suspect a bug?).
Line 10 of Algorithm 2 in the original paper indicates that meta_update is performed using each D'_i. To do so with your code, I think the function meta_update needs access to every sampled task, since in your implementation each task holds its own D'_i.
Line 172 in 75907ac
However, it seems that you perform meta_update with a single task, so only the D'_i of one specific task is used.
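For concreteness, here is a minimal, self-contained sketch of what I understand line 10 to require (this is my own toy code with a made-up linear model and task tuples, not your repo's API): the post-adaptation losses on every task's D'_i are summed before the single outer update.

```python
import torch

def forward(w, b, x):
    # tiny functional linear model standing in for the real network
    return x @ w + b

def maml_step(w, b, tasks, alpha, beta):
    meta_loss = 0.0
    for x_tr, y_tr, x_val, y_val in tasks:         # (D_i, D'_i) for each T_i
        # inner adaptation (lines 5-8): one gradient step on D_i
        loss_tr = ((forward(w, b, x_tr) - y_tr) ** 2).mean()
        gw, gb = torch.autograd.grad(loss_tr, (w, b), create_graph=True)
        w_i, b_i = w - alpha * gw, b - alpha * gb   # adapted parameters theta'_i
        # line 10's summand: loss of theta'_i on this task's D'_i
        meta_loss = meta_loss + ((forward(w_i, b_i, x_val) - y_val) ** 2).mean()
    gw, gb = torch.autograd.grad(meta_loss, (w, b))  # gradient w.r.t. theta
    with torch.no_grad():
        # a single outer update that has seen every D'_i
        return (w - beta * gw).requires_grad_(), (b - beta * gb).requires_grad_()

# toy usage: a meta-batch of four regression tasks
w = torch.randn(3, 1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
tasks = [(torch.randn(8, 3), torch.randn(8, 1),
          torch.randn(8, 3), torch.randn(8, 1)) for _ in range(4)]
w, b = maml_step(w, b, tasks, alpha=0.01, beta=0.001)
```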
Line 10 also states that the meta-loss is calculated with the adapted parameters.
Line 71 in 75907ac
You seem to calculate the meta-loss with self.net, which I think holds the "original parameters" (\theta) instead of the adapted parameters (\theta'_i).
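For reference, line 10 of Algorithm 2 in the paper is

\theta \leftarrow \theta - \beta \, \nabla_\theta \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})

so each summand should be evaluated with the adapted \theta'_i rather than with \theta itself.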
Am I missing something?
Thanks for your interest in my repo!
- The gradients are accumulated across tasks in this line: https://github.com/katerakelly/pytorch-maml/blob/master/src/maml.py#L164
- The line you linked to is a bit of a hack: it uses a hook to replace the .grad fields with the gradients from the adapted parameters (a rough sketch of the trick is below). See here for the actual computation of the meta-gradients: https://github.com/katerakelly/pytorch-maml/blob/master/src/inner_loop.py#L47
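Roughly, the hook trick looks like this (a simplified, self-contained sketch rather than the repo code verbatim; the placeholder meta_grads dict stands in for the per-task meta-gradients summed in the inner loop):

```python
import torch
import torch.nn as nn

net = nn.Linear(3, 1)
opt = torch.optim.SGD(net.parameters(), lr=1e-3)

# placeholder: pretend these were returned by the inner loop and
# already summed across tasks
meta_grads = {name: torch.ones_like(p) for name, p in net.named_parameters()}

# register hooks that discard whatever gradient the dummy backward
# pass produces and return the stored meta-gradient instead
hooks = [p.register_hook(lambda g, n=name: meta_grads[n])
         for name, p in net.named_parameters()]

# the dummy forward/backward exists only to populate .grad via the hooks
opt.zero_grad()
net(torch.randn(4, 3)).sum().backward()
opt.step()                  # an ordinary optimizer step now applies line 10
for h in hooks:
    h.remove()
```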
Thank you for the quick reply!
You've referenced https://github.com/katerakelly/pytorch-maml/blob/master/src/maml.py#L164, in which you accumulate gradients across tasks (using the data D). This relates to lines 6-7 of Algorithm 2 in the original paper.
However, my question was about using the D'_i of each task to perform the meta-update, which is line 10!
I think your implementation uses one single D', instead of each task's D'_i, when calculating the "meta-loss" across tasks.
The line I referenced is accumulating meta-gradients.
In Algorithm 2 of the MAML paper, lines 5-8 are implemented in inner_loop.py, and the gradients used in line 10 are computed there as well. The actual update is applied in maml.py after these gradients have been accumulated across tasks.
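So the flow, condensed into a self-contained sketch (made-up names and a toy model, not the repo verbatim): each call to the inner loop adapts on D_i and returns the gradient, taken at the original \theta, of the loss of \theta'_i on D'_i. Summing these per-task gradients is equivalent to differentiating the summed meta-loss, so every D'_i does enter the update.

```python
import torch

def forward(w, x):
    return x @ w

def inner_loop(w, x_tr, y_tr, x_val, y_val, alpha):
    # Algorithm 2, lines 5-8: one adaptation step on D_i
    loss = ((forward(w, x_tr) - y_tr) ** 2).mean()
    (g,) = torch.autograd.grad(loss, w, create_graph=True)
    w_adapted = w - alpha * g                     # theta'_i
    # this task's term of line 10: loss of theta'_i on D'_i,
    # differentiated back through the adaptation step to theta
    meta_loss = ((forward(w_adapted, x_val) - y_val) ** 2).mean()
    (meta_g,) = torch.autograd.grad(meta_loss, w)
    return meta_g

w = torch.randn(3, 1, requires_grad=True)
tasks = [(torch.randn(8, 3), torch.randn(8, 1),
          torch.randn(8, 3), torch.randn(8, 1)) for _ in range(4)]
total = sum(inner_loop(w, *t, alpha=0.01) for t in tasks)  # accumulated across tasks
with torch.no_grad():
    w -= 1e-3 * total   # the single outer update, applied after accumulation
```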