metaopt/torchopt

[Question] Making the Learning Rate Learnable for the Implicit Gradient Component

XindiWu opened this issue · 7 comments

Questions

Hi @XuehaiPan @Benjamin-eecs @JieRen98,

I have recently been working with your codebase, trying to modify the implicit gradient component so that the learning rate becomes learnable as well.

As it stands now, the learning rate is a fixed parameter that we need to set before training. I have been trying to implement this change, but I have not been successful so far. Any guidance would be really helpful! Thank you very much for maintaining this project!

Hi @XindiWu, thanks for your attention and the question.

> As it stands now, the learning rate is a fixed parameter that we need to set before training.

I'm afraid I do not quite understand your question. Do you mean the inner-loop learning rate or the outer-loop learning rate?

If you mean the inner-loop learning rate, you can define an `lr` tensor as part of the meta-parameters. For example:

def stationary(params, meta_params, data):
    lr, *other_meta_params = meta_params
    # Stationary condition construction
    return stationary_condition  # should have the same structure as `params`, e.g., a tuple of tensors

@torchopt.diff.implicit.custom_root(stationary)
def solve(params, meta_params, data):
    lr, *other_meta_params = meta_params
    # Solve the inner loop optimization
    # find a solution for `stationary(params, meta_params, data) == zeros`
    return solution  # should have the same structure as `params`, e.g., a tuple of tensors


lr_tensor = torch.tensor(learning_rate, requires_grad=True)
meta_params = (lr_tensor, *other_meta_params)
solution = solve(init_params, meta_params, data)
outer_loss = loss_fn(solution, meta_params, data)
outer_loss.backward()  # backward pass in the outer loop
lr_tensor.grad  # gradient of the learning rate

You need to construct a stationary function such that

$$ \frac{\partial \text{stationary}}{\partial \text{lr}} \ne 0 $$

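For illustration, here is a minimal, self-contained sketch of this requirement. It is not from our examples, and the construction is an assumption on my side: the learnable scalar enters the inner objective as a proximal weight (a plain SGD step size would drop out of the stationarity condition at the optimum), so the condition depends on it directly:

import torch
import torchopt


def stationary(params, meta_params, data):
    # Closed-form gradient of the inner objective
    #   mean((x @ w - y) ** 2) + 0.5 * lr * ||w - anchor||^2
    # w.r.t. `w`; it vanishes at the inner optimum and depends on `lr_tensor`.
    (w,) = params
    lr_tensor, anchor = meta_params
    x, y = data
    grad_w = (2.0 / x.shape[0]) * (x.t() @ (x @ w - y)) + lr_tensor * (w - anchor)
    return (grad_w,)  # same structure as `params`


@torchopt.diff.implicit.custom_root(stationary, argnums=1)
def solve(params, meta_params, data):
    # Plain gradient descent on the inner objective. The gradient of the
    # solution w.r.t. `meta_params` comes from the implicit function theorem,
    # not from differentiating through this loop, so detached copies are fine.
    (w,) = params
    lr_tensor, anchor = meta_params
    x, y = data
    w, lr, a = w.detach().clone(), lr_tensor.detach(), anchor.detach()
    for _ in range(500):
        grad_w = (2.0 / x.shape[0]) * (x.t() @ (x @ w - y)) + lr * (w - a)
        w = w - 0.1 * grad_w
    return (w,)


x, y = torch.randn(32, 4), torch.randn(32)
lr_tensor = torch.tensor(0.5, requires_grad=True)
anchor = torch.zeros(4, requires_grad=True)

(w_star,) = solve((torch.zeros(4),), (lr_tensor, anchor), (x, y))
outer_loss = torch.mean((x @ w_star - y) ** 2)
outer_loss.backward()
print(lr_tensor.grad)  # nonzero because `stationary` depends on `lr_tensor`
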
For the OOP API:

class InnerNet(torchopt.nn.ImplicitMetaGradientModule):
    def __init__(self, lr, meta_param, n_inner_iter):
        super().__init__()
        self.lr = lr
        self.meta_param = meta_param
        self.n_inner_iter = n_inner_iter
        self.net = torch.nn.Sequential(
            ...
        )

    def forward(self, batch):
        # Forward process
        ...

    def optimality(self, batch, labels):
        # Define the optimality condition
        return optimality  # a tuple of tensors with the same structure as `tuple(self.parameters())`

    def objective(self, batch, labels):
        # Define the inner-loop optimization objective
        return inner_loss  # a scalar tensor

    def solve(self, batch, labels):
        # Conduct the inner-loop optimization
        params = tuple(self.parameters())
        inner_optim = torchopt.SGD(params, lr=self.lr)  # `self.lr` is a tensor with `requires_grad=True`
        with torch.enable_grad():
            # Temporarily enable gradient computation for conducting the optimization
            for _ in range(self.n_inner_iter):
                inner_loss = self.objective(batch, labels)
                inner_optim.zero_grad()
                # NOTE: The parameter inputs should be explicitly specified in `backward` function
                # as argument `inputs`. Otherwise, if not provided, the gradient is accumulated into
                # all the leaf Tensors (including the meta-parameters) that were used to compute the
                # objective output. Alternatively, please use `torch.autograd.grad` instead.
                inner_loss.backward(inputs=params)  # backward pass in inner-loop
                inner_optim.step()  # update inner parameters
        return self


lr_tensor = torch.tensor(learning_rate, requires_grad=True)
inner_net = InnerNet(lr_tensor, meta_param, n_inner_iter=10)
inner_net.solve(batch, labels)  # solve the inner loop with a differentiable learning rate
outer_loss = loss_fn(inner_net, meta_param, lr_tensor, ...)
outer_loss.backward()  # backward pass in the outer loop
lr_tensor.grad  # gradient of the learning rate; nonzero only if `optimality` depends on `self.lr`

You can check out our tutorial notebooks for more guidance.


If you mean the outer-loop learning rate, that does not affect the inner loop's optimization process at all; it only scales the meta-update.
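
For reference, a quick sketch of where the outer-loop learning rate lives, reusing the functional snippet above (`loader` and the `1e-3` value are placeholders):

meta_optimizer = torch.optim.Adam([lr_tensor, *other_meta_params], lr=1e-3)  # outer-loop lr

for data in loader:
    solution = solve(init_params, (lr_tensor, *other_meta_params), data)
    outer_loss = loss_fn(solution, (lr_tensor, *other_meta_params), data)
    meta_optimizer.zero_grad()
    outer_loss.backward()  # implicit gradients reach `lr_tensor` here
    meta_optimizer.step()  # the outer-loop lr only scales this meta-update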

Thanks a lot for the reply! I wonder if there should be any changes in @torchopt.diff.implicit.custom_root?

> I wonder if there should be any changes in @torchopt.diff.implicit.custom_root?

@XindiWu No. The only thing you need to provide is a stationary function such that

$$ \frac{\partial \text{stationary}}{\partial \text{lr}} \ne 0 $$

If you want to use a different linear-system solver, you can pass the `solve` argument to `custom_root`:

def stationary_fn(params, meta_params, data):
    ...

@torchopt.diff.implicit.custom_root(stationary_fn, solve=torchopt.linear_solve.solve_inv(ns=True))  # use the Neumann-series-based solver
def solve_inner_loop_fn(params, meta_params, data):
    ...
    return solution
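
For example, to use the conjugate-gradient solver instead (a sketch; I am assuming the `maxiter` keyword here, see the API reference of `torchopt.linear_solve.solve_cg` for the exact arguments):

@torchopt.diff.implicit.custom_root(stationary_fn, solve=torchopt.linear_solve.solve_cg(maxiter=20))
def solve_inner_loop_fn(params, meta_params, data):
    ...
    return solution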

You may also pass `argnums` for better code organization:

def stationary_fn(params, alpha, beta, lr, batch, labels):
    ...

# `argnums=(1, 3)` means to compute the gradient of the `solution` with respect to `alpha`
# and `lr` in the argument list (0-based, with `params` at position 0)
@torchopt.diff.implicit.custom_root(stationary_fn, argnums=(1, 3))
def solve_inner_loop_fn(params, alpha, beta, lr, batch, labels):
    ...
    return solution
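
Downstream, the outer loss then yields implicit gradients for exactly those arguments. A usage sketch (`init_params` and `outer_loss_fn` are hypothetical names):

solution = solve_inner_loop_fn(init_params, alpha, beta, lr, batch, labels)
outer_loss = outer_loss_fn(solution, batch, labels)
# Only `alpha` and `lr` (positions 1 and 3) receive implicit gradients
alpha_grad, lr_grad = torch.autograd.grad(outer_loss, (alpha, lr))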

Check out our documentation on Implicit Gradient Differentiation and the API reference for more details.

Thank you!

I also just checked the docs and saw an example similar to the one you gave above:

def solve(self, batch, labels):
    parameters = tuple(self.parameters())
    optimizer = torch.optim.Adam(parameters, lr=1e-3)
    with torch.enable_grad():
        for _ in range(100):
            loss = self.objective(batch, labels)
            optimizer.zero_grad()
            # Only update the `.grad` attribute for parameters
            # and leave the meta-parameters unchanged
            loss.backward(inputs=parameters)
            optimizer.step()
    return self

However, in the Colab notebook, the example is:

def inner_solver(params, meta_params, data):
    # Initialize a functional optimizer based on TorchOpt
    x, y, fmodel = data
    optimizer = torchopt.sgd(lr=2e-2)
    opt_state = optimizer.init(params)
    with torch.enable_grad():
        # Temporarily enable gradient computation for conducting the optimization
        for i in range(100):
            pred = fmodel(params, x)
            loss = F.mse_loss(pred, y)  # compute loss

            # Compute regularization loss
            regularization_loss = 0.0
            for p1, p2 in zip(params, meta_params):
                regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))
            final_loss = loss + regularization_loss

            grads = torch.autograd.grad(final_loss, params)  # compute gradients
            updates, opt_state = optimizer.update(grads, opt_state, inplace=True)  # get updates
            params = torchopt.apply_updates(params, updates, inplace=True)

    optimal_params = params
    return optimal_params

I wonder what the differences between these two are?

> I wonder what the differences between these two are?

@XindiWu The first example comes from the docstring of the torchopt.nn.ImplicitMetaGradientModule.solve method; it is for the OOP API of the implicit gradient, where the parameters live in the module and the optimizer updates them in place. The second example uses the functional API, where the functional model (`fmodel`), the parameters, and the optimizer state are passed around explicitly.

Awesome, thank you again for the reply!!

Closing now. Feel free to ask more questions or request a reopen.