Does explicit gradient support a self-defined autograd function?
happpyosu opened this issue · 2 comments
Required prerequisites
- I have read the documentation https://torchopt.readthedocs.io.
- I have searched the Issue Tracker and Discussions that this hasn't already been reported. (+1 or comment there if it has.)
- Consider asking first in a Discussion.
Questions
First of all, thank you for your project. I am solving a bilevel optimization problem using torchopt, but in my problem I have to implement a custom autograd function in torch. Here is my example code:
```python
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function

import torchopt


class DebugLossFunction(Function):
    @staticmethod
    def forward(ctx, x, r):
        loss = torch.mean(x * r)
        ctx.save_for_backward(x, r)
        return loss

    @staticmethod
    def backward(ctx, grad):
        x, r = ctx.saved_tensors
        grad_x = torch.ones_like(x)
        grad_r = torch.ones_like(r)
        return grad * grad_x, grad * grad_r


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.x = None

    def set_initial_x(self, r):
        self.x = nn.Parameter(torch.tensor(r, dtype=torch.float32), requires_grad=True)

    def outer_loss(self):
        temp = 2 * self.x
        loss = torch.mean(temp)
        return loss

    def inner_loss(self, r):
        loss = DebugLossFunction.apply(self.x, r)
        return loss


initial_x = np.random.rand(100, 100)
initial_x = 2 * (initial_x - 0.5)
net = Net()
net.set_initial_x(initial_x)
r = nn.Parameter(torch.tensor(initial_x, dtype=torch.float32), requires_grad=True)
# High-level API
optim = torchopt.MetaAdam(net, lr=1.0)

for i in range(10):
    inner_loss = net.inner_loss(r)
    optim.step(inner_loss)

outer_loss = net.outer_loss()
outer_loss.backward()
print(f'r.grad = {r.grad!r}')
```
and the printed gradient of r is None:
```text
r.grad = None
```
I am still new to bilevel optimization. Are there any mistakes in my code, or does torchopt simply not support a self-defined autograd function?
Greetings,
I believe this is not supported. Firstly, the explicit-gradient bilevel algorithms need the explicit formula of the higher-order gradient (the outer gradient has to be differentiated through the inner update), and providing a customized first-order function is not sufficient to calculate that higher-order gradient. Secondly, torchopt does not support customized higher-order functions so far.
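One way to see the first point with the `DebugLossFunction` above (a standalone plain-PyTorch sketch, not a torchopt API example): the gradient returned by the custom backward has no differentiable dependence on `r`, so an outer loss built from an inner update has no path back to `r`.
```python
# Standalone illustration (assumption: plain PyTorch, using the
# DebugLossFunction defined in the issue above).
x = torch.randn(5, requires_grad=True)
r = torch.randn(5, requires_grad=True)

loss = DebugLossFunction.apply(x, r)
# First-order gradient of the inner loss w.r.t. x, keeping the graph alive.
(gx,) = torch.autograd.grad(loss, x, create_graph=True)

# gx equals ones_like(x) and carries no differentiable dependence on r,
# so an outer loss built from the updated x has no path back to r.
outer = (x - 0.1 * gx).mean()
print(torch.autograd.grad(outer, r, allow_unused=True))  # (None,)
```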
Best,
Jie
Thank you for your answer. I will try to implement my loss function directly in torch instead.
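For reference, a minimal sketch of that workaround (assuming the real loss can be expressed with built-in differentiable torch ops) would replace the custom Function call in `Net.inner_loss`:
```python
def inner_loss(self, r):
    # Built-in ops record a backward graph that is itself differentiable,
    # so the explicit-gradient pipeline (MetaAdam inner step followed by
    # outer_loss.backward()) can propagate gradients back to r.
    return torch.mean(self.x * r)
```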