Why do we need to use hooks?
Interesting6 opened this issue · 1 comment
Thanks for your helpful code.
I want to update the original net's weights using the task-averaged gradients computed on each task's query set, and then call opt.step() to apply the update.
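For concreteness, the task-averaged gradient dictionary I mean could be built roughly like this (a minimal sketch meant to live in the same method as the code below; `query_losses` and the simple averaging are illustrative assumptions, not code from this repository):

```python
import torch

# Illustrative sketch: query_losses is assumed to hold one query-set loss per task.
gradients = {k: torch.zeros_like(v) for k, v in self.net.named_parameters()}
for query_loss in query_losses:
    # Gradient of this task's query-set loss w.r.t. the net's parameters
    task_grads = torch.autograd.grad(query_loss, self.net.parameters(), retain_graph=True)
    for (k, _), g in zip(self.net.named_parameters(), task_grads):
        gradients[k] += g / len(query_losses)  # accumulate the task average
```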
Inspired by the optim.SGD source code, can this block in your code:
```python
hooks = []
for (k, v) in self.net.named_parameters():
    def get_closure():
        key = k
        def replace_grad(grad):
            return gradients[key]
        return replace_grad
    hooks.append(v.register_hook(get_closure()))

# Compute grads for the current step; the hooks replace them with the summed gradients
self.opt.zero_grad()
loss.backward()

# Update the net parameters with the accumulated gradient according to the optimizer
self.opt.step()

# Remove the hooks before the next training phase
for h in hooks:
    h.remove()
```
be replaced by:
```python
for (k, v), (k, g) in zip(self.net.named_parameters(), gradients.items()):
    v.grad.data = g.data
self.opt.step()
self.opt.zero_grad()
```
or just by:
```python
for (k, v), (k, g) in zip(self.net.named_parameters(), gradients.items()):
    v.data.add_(g.data, alpha=-meta_lr)
```
Thanks for your time!
I do not think that will work, as the gradients of intermediate variables are not retained for memory reasons. Hooks are the preferred way to read and write grads of variables. See this thread: https://discuss.pytorch.org/t/why-cant-i-see-grad-of-an-intermediate-variable/94
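For completeness, here is a minimal, self-contained sketch of the hook mechanism described above; the toy model, the constant gradient values, and the learning rate are made up purely for illustration:

```python
import torch
import torch.nn as nn

# Toy setup: a single linear layer and a precomputed "summed" gradient per parameter.
net = nn.Linear(2, 1)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
gradients = {k: torch.ones_like(v) for k, v in net.named_parameters()}  # made-up values

# Register hooks that discard the locally computed grad and return the stored one instead.
hooks = []
for k, v in net.named_parameters():
    def get_closure():
        key = k
        def replace_grad(grad):
            return gradients[key]
        return replace_grad
    hooks.append(v.register_hook(get_closure()))

loss = net(torch.randn(4, 2)).sum()
opt.zero_grad()
loss.backward()   # hooks fire here and substitute gradients[key] for each parameter
opt.step()        # parameters are updated with the substituted gradients

# Remove the hooks before the next phase
for h in hooks:
    h.remove()
```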