hoyeoplee/MeLU

A question about the gradients

yurunsheng1 opened this issue · 1 comments

Hi,
First, thank you for sharing such nice work!

I have a question and would really appreciate your help.

In your MeLU.py, lines 71-79:

            grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
            # local update
            for i in range(self.weight_len):
                if self.weight_name[i] in self.local_update_target_weight_name:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
                else:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]
        self.model.load_state_dict(self.fast_weights)
        query_set_y_pred = self.model(query_set_x)

I understand this is the standard MAML approach (inner loop).

However, load_state_dict() erases (breaks) the gradient history of the tensors it loads (https://discuss.pytorch.org/t/loading-a-state-dict-seems-to-erase-grad/56676), so the global update no longer takes the local-update gradient into account in the final optimization. As a result, create_graph=True may have no effect, and the algorithm may no longer be standard MAML. Am I missing some insight behind this design?
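To make the concern concrete, here is a minimal sketch of the two variants on a toy model. It is not the MeLU code: the nn.Linear model, the random support/query tensors, and names such as probe and meta_grad are made up for illustration. It assumes that load_state_dict() copies values into the existing parameters without recording the copy in autograd, and it uses a hand-written functional forward (F.linear) as one possible way to keep the graph.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins (hypothetical, not from MeLU): a linear model and random data.
model = nn.Linear(3, 1)
local_lr = 0.1
support_x, support_y = torch.randn(4, 3), torch.randn(4, 1)
query_x, query_y = torch.randn(4, 3), torch.randn(4, 1)

names = [name for name, _ in model.named_parameters()]
params = list(model.parameters())

# Inner (local) step: the fast weights depend on params through grad.
loss = F.mse_loss(model(support_x), support_y)
grad = torch.autograd.grad(loss, params, create_graph=True)
fast_weights = {n: p - local_lr * g for n, p, g in zip(names, params, grad)}
fast_has_graph = fast_weights["weight"].grad_fn is not None

# Variant A: load the fast weights into a copy of the model.
# Only the values are copied, so the link back to params is gone.
probe = copy.deepcopy(model)
probe.load_state_dict(fast_weights)
loss_a = F.mse_loss(probe(query_x), query_y)
grad_a = torch.autograd.grad(loss_a, params, allow_unused=True)

# Variant B: functional forward that uses the fast weights directly,
# so the query loss stays connected to the original params.
query_pred = F.linear(query_x, fast_weights["weight"], fast_weights["bias"])
loss_b = F.mse_loss(query_pred, query_y)
meta_grad = torch.autograd.grad(loss_b, params)

print("fast weights carry a graph:", fast_has_graph)
print("variant A grads w.r.t. original params:", grad_a)
print("variant B grad shapes:", [tuple(g.shape) for g in meta_grad])
```

In variant A the query loss has no path back to the original parameters, which is the behaviour described above. Variant B is only one way to keep that path; torch.func.functional_call or a library such as higher would be alternatives.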

Looking forward to your reply!

I believe you are right and the original code is wrong.