metaopt/torchopt

Does explicit gradient support a self-defined autograd function?

happpyosu opened this issue · 2 comments

Questions

Thank you for your project. I am solving a bilevel optimization problem with TorchOpt, but in my problem I have to implement a custom autograd function in torch. Here is my example code:

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
import torchopt


class DebugLossFunction(Function):
    """Custom autograd function with a hand-written first-order backward."""

    @staticmethod
    def forward(ctx, x, r):
        loss = torch.mean(x * r)
        ctx.save_for_backward(x, r)
        return loss

    @staticmethod
    def backward(ctx, grad):
        x, r = ctx.saved_tensors
        # Hand-written placeholder gradients for debugging.
        grad_x = torch.ones_like(x)
        grad_r = torch.ones_like(r)
        return grad * grad_x, grad * grad_r


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.x = None

    def set_initial_x(self, r):
        self.x = nn.Parameter(torch.tensor(r, dtype=torch.float32), requires_grad=True)

    def outer_loss(self):
        temp = 2 * self.x
        loss = torch.mean(temp)
        return loss

    def inner_loss(self, r):
        loss = DebugLossFunction.apply(self.x, r)
        return loss

initial_x = np.random.rand(100, 100)
initial_x = 2 * (initial_x - 0.5)

net = Net()
net.set_initial_x(initial_x)

r = nn.Parameter(torch.tensor(initial_x), requires_grad=True)

# High-level API: MetaAdam performs differentiable inner-loop updates of net.x
optim = torchopt.MetaAdam(net, lr=1.0)

for i in range(10):
    inner_loss = net.inner_loss(r)
    optim.step(inner_loss)


outer_loss = net.outer_loss()
outer_loss.backward()

print(f'r.grad = {r.grad!r}')

The printed gradient of r is None:

r.grad = None

I am still new to bilevel optimization. Are there any mistakes in my code, or does TorchOpt simply not support a self-defined autograd function?

Greetings,

I believe this is not supported. First, the bilevel algorithms need an explicit formula for the higher-order gradient, and providing a customized first-order backward is not sufficient to compute it. Second, TorchOpt does not support customized higher-order functions so far.
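
For intuition, here is a minimal sketch in plain PyTorch (not TorchOpt-specific; ConstGrad is just a renamed copy of your DebugLossFunction) of why a hand-written first-order backward blocks the higher-order gradient: the gradient it returns is a constant with respect to the graph, so it no longer depends on r.

import torch

x = torch.randn(3, requires_grad=True)
r = torch.randn(3, requires_grad=True)

# Native ops: the first-order gradient of mean(x * r) w.r.t. x is r / 3,
# which still depends on r, so a higher-order gradient can be taken.
(gx,) = torch.autograd.grad(torch.mean(x * r), x, create_graph=True)
print(gx.requires_grad)  # True -> still differentiable w.r.t. r


class ConstGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, r):
        ctx.save_for_backward(x, r)
        return torch.mean(x * r)

    @staticmethod
    def backward(ctx, grad):
        x, r = ctx.saved_tensors
        # Hand-written first-order gradients: constants w.r.t. the graph.
        return grad * torch.ones_like(x), grad * torch.ones_like(r)


(gx,) = torch.autograd.grad(ConstGrad.apply(x, r), x, create_graph=True)
print(gx.requires_grad)  # False -> the chain to r is cut, no hypergradient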

Best,
Jie

Thank you for your answer. I will try to implement my loss function directly in torch instead.
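
For reference, this is roughly the rewrite I have in mind: a minimal sketch of the snippet above with the inner loss expressed in native torch ops (untested against a particular TorchOpt version):

import torch
import torch.nn as nn
import torchopt


class Net(nn.Module):
    def __init__(self, initial_x):
        super().__init__()
        self.x = nn.Parameter(torch.as_tensor(initial_x, dtype=torch.float32))

    def inner_loss(self, r):
        # Native torch ops keep the loss twice-differentiable, so the
        # hypergradient w.r.t. r can flow through the inner updates.
        return torch.mean(self.x * r)

    def outer_loss(self):
        return torch.mean(2 * self.x)


initial_x = 2 * (torch.rand(100, 100) - 0.5)
net = Net(initial_x)
r = nn.Parameter(initial_x.clone())

optim = torchopt.MetaAdam(net, lr=1.0)  # differentiable inner optimizer
for _ in range(10):
    optim.step(net.inner_loss(r))

net.outer_loss().backward()
print(r.grad is not None)  # expected: True once the graph reaches r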