A question about the gradients
yurunsheng1 opened this issue · 1 comments
Hi,
First, thank you for sharing such nice work!
But I have a question and would really appreciate your help:
In your MeLU.py lines 71-79:
grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
# local update
for i in range(self.weight_len):
    if self.weight_name[i] in self.local_update_target_weight_name:
        self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
    else:
        self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]
self.model.load_state_dict(self.fast_weights)
query_set_y_pred = self.model(query_set_x)
I understand that this is the inner loop of the standard MAML approach.
However, the function load_state_dict() erases (breaks) the gradient history (https://discuss.pytorch.org/t/loading-a-state-dict-seems-to-erase-grad/56676), so the global update no longer takes the local-update gradient into account in the final optimization. As a result, create_graph=True may have no effect, and the algorithm may no longer be standard MAML. Am I missing something here?
Looking forward to your reply!
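The concern above can be checked on a toy model. The following is only a minimal sketch, not the MeLU code: the `nn.Linear` model, the random data, and the learning rate of 0.1 are assumptions made for illustration. It shows that the fast weights carry a graph back to the inner-loop gradient, while the parameters after load_state_dict() are plain leaf tensors again.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the MeLU model and a support set.
model = nn.Linear(2, 1)
support_x, support_y = torch.randn(4, 2), torch.randn(4, 1)

loss = ((model(support_x) - support_y) ** 2).mean()
grads = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)

# Fast weights are non-leaf tensors: they remember how they were computed.
fast_weights = {
    name: p - 0.1 * g
    for (name, p), g in zip(model.named_parameters(), grads)
}
fast_has_graph = all(w.grad_fn is not None for w in fast_weights.values())

# load_state_dict() copies the values into the existing parameters,
# so the parameters stay leaves with no link to the inner-loop graph.
model.load_state_dict(fast_weights)
params_are_leaves = all(
    p.is_leaf and p.grad_fn is None for p in model.parameters()
)

print("fast weights carry a graph:", fast_has_graph)
print("parameters are leaves after load_state_dict:", params_are_leaves)
```

Because the parameters are leaves, a backward pass through a later forward call stops at the copied values and never reaches the inner-loop gradient.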
I believe you are right and the original code is wrong.
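One possible way to keep the inner-loop graph, sketched here under assumptions rather than taken from the repository, is to run the query forward pass with the fast weights directly instead of loading them into the model. The sketch assumes PyTorch 2.0 or later for `torch.func.functional_call`; the toy `nn.Linear` model, the random data, and `local_lr = 0.1` are hypothetical.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical stand-ins for the model, support set, and query set.
model = nn.Linear(2, 1)
support_x, support_y = torch.randn(4, 2), torch.randn(4, 1)
query_x, query_y = torch.randn(4, 2), torch.randn(4, 1)
local_lr = 0.1

params = dict(model.named_parameters())

# Inner loop: one gradient step on the support set, keeping the graph.
support_loss = ((model(support_x) - support_y) ** 2).mean()
grads = torch.autograd.grad(
    support_loss, list(params.values()), create_graph=True
)
fast_weights = {
    name: p - local_lr * g for (name, p), g in zip(params.items(), grads)
}

# Query forward pass with the fast weights; the model itself is unchanged.
query_pred = functional_call(model, fast_weights, (query_x,))
query_loss = ((query_pred - query_y) ** 2).mean()

# Gradient w.r.t. the original parameters (flows through the inner step).
full_grads = torch.autograd.grad(
    query_loss, list(params.values()), retain_graph=True
)
# Gradient w.r.t. the fast weights only (the first-order approximation).
first_order_grads = torch.autograd.grad(
    query_loss, list(fast_weights.values())
)

differs = not torch.allclose(full_grads[0], first_order_grads[0])
print("second-order term changes the weight gradient:", differs)
```

In this sketch the two gradients differ, which suggests the outer update does see the local update when the fast weights stay in the graph.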