[Question] Making the Learning Rate Learnable for the Implicit Gradient Component
XindiWu opened this issue · 7 comments
Required prerequisites
- I have read the documentation https://torchopt.readthedocs.io.
- I have searched the Issue Tracker and Discussions that this hasn't already been reported. (+1 or comment there if it has.)
- Consider asking first in a Discussion.
Questions
Hi @XuehaiPan @Benjamin-eecs @JieRen98,
I have been recently working with your codebase, and I've been trying to modify the implicit gradient component such that the learning rate becomes learnable as well.
As it stands now, the learning rate is a fixed parameter that we need to set before training. I have been trying to implement this change, but I have not been successful so far. Any guidance would be really helpful! Thank you very much for maintaining this project!
Hi @XindiWu, thanks for your attention and the question.
> As it stands now, the learning rate is a fixed parameter that we need to set before training.
I'm afraid that I do not quite understand your question. Do you mean the inner loop learning rate or the outer loop learning rate?
If you mean the inner loop learning rate, you can define an `lr` tensor as part of the meta-parameters. For example:
    import torch
    import torchopt

    def stationary(params, meta_params, data):
        lr, *other_meta_params = meta_params
        # Stationary condition construction
        return stationary_condition  # should be in the same structure as `params`, e.g., tuple of tensors

    # `argnums=1` differentiates the solution with respect to `meta_params`
    @torchopt.diff.implicit.custom_root(stationary, argnums=1)
    def solve(params, meta_params, data):
        lr, *other_meta_params = meta_params
        # Solve the inner loop optimization
        # find a solution for `stationary(params, meta_params, data) == zeros`
        return solution  # should be in the same structure as `params`, e.g., tuple of tensors

    lr_tensor = torch.tensor(learning_rate, requires_grad=True)
    meta_params = (lr_tensor, *other_meta_params)
    solution = solve(init_params, meta_params, data)
    outer_loss = loss_fn(solution, meta_params, data)
    outer_loss.backward()
    lr_tensor.grad  # gradient of the learning rate
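You need to construct a `stationary` function that evaluates to zero at the inner-loop solution, i.e., `stationary(solution, meta_params, data) == zeros`.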
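As a rough illustration of why the stationary condition is all that implicit differentiation needs, here is a minimal pure-Python sketch on a scalar problem. It does not use TorchOpt; the inner objective, the names `stationary`, `solve_inner`, and `implicit_grad`, and all numbers are assumptions made up for this example. For a condition F(x, m) = 0, the implicit function theorem gives dx*/dm = -(∂F/∂m) / (∂F/∂x), which the sketch compares against a finite difference taken through the solver.

```python
# Toy inner problem (an assumption for illustration, not TorchOpt code):
#   inner_loss(x, m) = 0.5 * (x - TARGET)**2 + 0.5 * (x - m)**2
# Its stationary condition is the gradient with respect to x being zero.
TARGET = 3.0


def stationary(x, m):
    """Gradient of the inner loss with respect to x; zero at the solution."""
    return (x - TARGET) + (x - m)


def solve_inner(m, lr=0.1, steps=500):
    """Plain gradient descent on x until the stationary condition holds."""
    x = 0.0
    for _ in range(steps):
        x -= lr * stationary(x, m)
    return x


def implicit_grad(x_star, m, eps=1e-6):
    """dx*/dm from the implicit function theorem, using central differences
    for the two partial derivatives of the stationary condition."""
    dF_dx = (stationary(x_star + eps, m) - stationary(x_star - eps, m)) / (2 * eps)
    dF_dm = (stationary(x_star, m + eps) - stationary(x_star, m - eps)) / (2 * eps)
    return -dF_dm / dF_dx


m = 1.0
x_star = solve_inner(m)
ift = implicit_grad(x_star, m)
# Reference value: differentiate through the whole solver numerically.
fd = (solve_inner(m + 1e-4) - solve_inner(m - 1e-4)) / 2e-4
print(x_star, ift, fd)
```

Under this formula, a meta-parameter picks up a nonzero gradient only if the stationary condition depends on it, which is why `lr` is unpacked from `meta_params` inside `stationary` in the example above.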
For the OOP API:
    class InnerNet(torchopt.nn.ImplicitMetaGradientModule):
        def __init__(self, lr, meta_param):
            super().__init__()
            self.lr = lr
            self.meta_param = meta_param
            self.n_inner_iter = 100  # number of inner-loop steps (illustrative value)
            self.net = torch.nn.Sequential(
                ...
            )

        def forward(self, batch):
            # Forward process
            ...

        def optimality(self, batch, labels):
            # Define the optimality condition
            return optimality  # a tuple of tensors same as `tuple(self.parameters())`

        def objective(self, batch, labels):
            # Define the inner-loop optimization objective
            return inner_loss  # a scalar tensor

        def solve(self, batch, labels):
            # Conduct the inner-loop optimization
            params = tuple(self.parameters())
            inner_optim = torchopt.SGD(params, lr=self.lr)  # `self.lr` is a tensor with `requires_grad=True`
            with torch.enable_grad():
                # Temporarily enable gradient computation for conducting the optimization
                for _ in range(self.n_inner_iter):
                    inner_loss = self.objective(batch, labels)
                    inner_optim.zero_grad()
                    # NOTE: The parameter inputs should be explicitly specified in `backward` function
                    # as argument `inputs`. Otherwise, if not provided, the gradient is accumulated into
                    # all the leaf Tensors (including the meta-parameters) that were used to compute the
                    # objective output. Alternatively, please use `torch.autograd.grad` instead.
                    inner_loss.backward(inputs=params)  # backward pass in inner-loop
                    inner_optim.step()  # update inner parameters
            return self

    lr_tensor = torch.tensor(learning_rate, requires_grad=True)
    inner_net = InnerNet(lr_tensor, meta_param)
    inner_net.solve(batch, labels)  # solve inner loop with differentiable learning rate
    outer_loss = loss_fn(inner_net, meta_param, lr_tensor, ...)
    outer_loss.backward()  # backward pass in outer-loop
    lr_tensor.grad  # gradient of learning rate
You can check out our tutorial notebooks for more guidance.
If you mean the outer loop learning rate, that will not affect the inner loop's optimization process.
Thank you a lot for the reply! I wonder if there should be any changes in the @torchopt.diff.implicit.custom_root?
> I wonder if there should be any changes in the `@torchopt.diff.implicit.custom_root`?
@XindiWu No. The only thing you need to provide is a `stationary` function that is zero at the inner-loop solution.
If you want to use other linear system solvers, you can pass the `solve` argument to `custom_root`:
    def stationary_fn(params, meta_params, data):
        ...

    @torchopt.diff.implicit.custom_root(
        stationary_fn,
        argnums=1,
        # matrix-inversion solver; pass `ns=True` for the Neumann-series approximation
        solve=torchopt.linear_solve.solve_inv(),
    )
    def solve_inner_loop_fn(params, meta_params, data):
        ...
        return solution
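For intuition about what a Neumann-series linear solver does, here is a small pure-Python sketch. It is not TorchOpt's implementation; the function name `neumann_solve`, the step size `alpha`, and the 2×2 system are assumptions for illustration. It approximates the solution of A x = b with the truncated series x ≈ alpha * sum_k (I - alpha A)^k b, which converges when the spectral radius of I - alpha A is below 1.

```python
def matvec(A, v):
    """Multiply a dense matrix (list of rows) by a vector."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def neumann_solve(A, b, alpha=0.2, num_terms=200):
    """Approximate A^{-1} b by alpha * sum_k (I - alpha*A)^k b."""
    term = list(b)   # current term (I - alpha*A)^k b, starting at k = 0
    total = list(b)  # running sum of the terms
    for _ in range(num_terms):
        Av = matvec(A, term)
        term = [t - alpha * av for t, av in zip(term, Av)]
        total = [s + t for s, t in zip(total, term)]
    return [alpha * s for s in total]


A = [[4.0, 1.0], [1.0, 3.0]]  # symmetric positive definite (hypothetical system)
b = [1.0, 2.0]
x = neumann_solve(A, b)
residual = [ax - bi for ax, bi in zip(matvec(A, x), b)]
print(x, residual)
```

The series only needs matrix-vector products, which is the appeal for implicit gradients: the Jacobian of the stationary condition never has to be materialized.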
You may also pass `argnums` for better code organization:

    def stationary_fn(params, alpha, beta, lr, batch, labels):
        ...

    # argnums=(1, 3) means to calculate the gradient for the `solution` with respect to `alpha` and `lr` in the argument list
    @torchopt.diff.implicit.custom_root(stationary_fn, argnums=(1, 3))
    def solve_inner_loop_fn(params, alpha, beta, lr, batch, labels):
        ...
        return solution
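To make the index convention concrete, here is a pure-Python sketch of selecting arguments by position. It does not call TorchOpt; `grad_wrt` is a hypothetical helper that uses central finite differences, and `toy_solution` is a made-up stand-in for a solver's output. Indices are zero-based and count `params` as 0, so `(1, 3)` picks `alpha` and `lr`.

```python
def grad_wrt(fn, argnums, eps=1e-6):
    """Return a function computing d fn / d args[i] for each i in `argnums`,
    using central finite differences on scalar arguments."""
    def wrapped(*args):
        grads = []
        for i in argnums:
            up = list(args)
            down = list(args)
            up[i] += eps
            down[i] -= eps
            grads.append((fn(*up) - fn(*down)) / (2 * eps))
        return tuple(grads)
    return wrapped


def toy_solution(params, alpha, beta, lr):
    """Stand-in for a solver output that depends on every argument."""
    return params + 2.0 * alpha + 3.0 * beta + 5.0 * lr


# Zero-based indices: 0 = params, 1 = alpha, 2 = beta, 3 = lr
d_alpha, d_lr = grad_wrt(toy_solution, argnums=(1, 3))(1.0, 0.5, 0.25, 0.1)
print(d_alpha, d_lr)
```

Arguments not listed (here `beta`) are treated as constants, so no gradient is computed for them.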
Check out our documentation Implicit Gradient Differentiation and API reference for more details.
Thank you!
I also just checked the docs and saw an example similar to the one you gave earlier:
    def solve(self, batch, labels):
        parameters = tuple(self.parameters())
        optimizer = torch.optim.Adam(parameters, lr=1e-3)
        with torch.enable_grad():
            for _ in range(100):
                loss = self.objective(batch, labels)
                optimizer.zero_grad()
                # Only update the `.grad` attribute for parameters
                # and leave the meta-parameters unchanged
                loss.backward(inputs=parameters)
                optimizer.step()
        return self
However, in the colab notebook, the example is:
    def inner_solver(params, meta_params, data):
        # Initial functional optimizer based on TorchOpt
        x, y, fmodel = data
        optimizer = torchopt.sgd(lr=2e-2)
        opt_state = optimizer.init(params)
        with torch.enable_grad():
            # Temporarily enable gradient computation for conducting the optimization
            for i in range(100):
                pred = fmodel(params, x)
                loss = F.mse_loss(pred, y)  # compute loss
                # Compute regularization loss
                regularization_loss = 0.0
                for p1, p2 in zip(params, meta_params):
                    regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))
                final_loss = loss + regularization_loss
                grads = torch.autograd.grad(final_loss, params)  # compute gradients
                updates, opt_state = optimizer.update(grads, opt_state, inplace=True)  # get updates
                params = torchopt.apply_updates(params, updates, inplace=True)
        optimal_params = params
        return optimal_params
I wonder what are the differences between these two?
> I wonder what are the differences between these two?
@XindiWu The first example comes from the docstring for the `torchopt.nn.ImplicitMetaGradientModule.solve` method. It's for the OOP API of the Implicit Gradient. The second example uses the functional API of the Implicit Gradient.
- OOP API: `torchopt.nn.ImplicitMetaGradientModule`
- functional API: `torchopt.diff.implicit.custom_root`
Awesome thank you again for the reply!!
Closing now. Feel free to ask more questions or ask for a reopening.