katerakelly/pytorch-maml

Why do we need to use hooks?

Interesting6 opened this issue · 1 comment

Thanks for your helpful codes.

I want to update the original net's weights with the task-averaged gradients computed on each task's query set, and then call opt.step() to apply the update. Inspired by the optim.SGD source code, I wonder: can your code
```python
hooks = []
for (k, v) in self.net.named_parameters():
    def get_closure():
        key = k
        def replace_grad(grad):
            return gradients[key]
        return replace_grad
    hooks.append(v.register_hook(get_closure()))

# Compute grads for current step, replace with summed gradients as defined by hook
self.opt.zero_grad()
loss.backward()

# Update the net parameters with the accumulated gradient according to optimizer
self.opt.step()

# Remove the hooks before next training phase
for h in hooks:
    h.remove()
```

be replaced by:
```python
for (k, v), (k, g) in zip(self.net.named_parameters(), gradients):
    v.grad.data = g.data
self.opt.step()
self.opt.zero_grad()
```

or just by:
```python
for (k, v), (k, g) in zip(self.net.named_parameters(), gradients):
    v.data.add_(-meta_lr, g.data)
```

Thanks for your time!

I do not think that will work, as the gradients of intermediate variables are not retained for memory reasons. Hooks are the preferred way to read and write grads of variables. See this thread: https://discuss.pytorch.org/t/why-cant-i-see-grad-of-an-intermediate-variable/94
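For illustration, here is a minimal sketch of that point (not code from this repo; the tensor names are made up): PyTorch frees the gradients of intermediate (non-leaf) variables after backward, so their `.grad` stays `None`, whereas a hook can still read or even replace the gradient as it flows through them.

```python
# Minimal sketch, assuming a recent PyTorch; names are illustrative only.
import torch

x = torch.ones(3, requires_grad=True)  # leaf variable
y = x * 2                              # intermediate (non-leaf) variable
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.]) -- leaves keep their grad
print(y.grad)  # None -- intermediate grads are freed (recent PyTorch also warns here)

# A hook can still observe or replace the gradient flowing through y:
x2 = torch.ones(3, requires_grad=True)
y2 = x2 * 2
h = y2.register_hook(lambda grad: torch.full_like(grad, 5.0))  # overwrite dL/dy2
y2.sum().backward()
print(x2.grad)  # tensor([10., 10., 10.]) -- the replaced grad propagated to the leaf
h.remove()
```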