Lightweight, flexible, research-oriented package to compute hypergradients in PyTorch.
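Given the following bi-level problem:

$$\min_{\lambda} f(\lambda) := E(w(\lambda), \lambda) \quad \text{subject to} \quad w(\lambda) = \Phi(w(\lambda), \lambda),$$

we call hypergradient the following quantity:

$$\nabla f(\lambda) = \nabla_2 E(w(\lambda), \lambda) + \partial_2 \Phi(w(\lambda), \lambda)^\top \left(I - \partial_1 \Phi(w(\lambda), \lambda)^\top\right)^{-1} \nabla_1 E(w(\lambda), \lambda),$$

where the subscripts 1 and 2 denote derivatives with respect to the first (inner) and second (outer) argument, and:
- $E$ is called the outer objective (e.g. the validation loss);
- $\Phi$ is called the fixed point map (e.g. a gradient descent step or the state update function in a recurrent model);
- finding the solution $w(\lambda)$ of the fixed point equation $w(\lambda) = \Phi(w(\lambda), \lambda)$ is referred to as the inner problem. This can be solved by repeatedly applying the fixed point map or by using a different inner algorithm.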
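As a sanity check of the hypergradient formula, the sketch below evaluates it with dense Jacobians on a small hypothetical problem (ridge regression with one L2 penalty per feature as the inner problem, squared validation error as the outer objective, one gradient descent step as the fixed point map) and compares the result with plain autograd through the closed-form inner solution. It uses only standard PyTorch, not hypertorch; the data, sizes and names are made up for illustration, and dense Jacobians are affordable only because the problem is tiny.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)  # double precision keeps the comparison tight

# Hypothetical toy problem: ridge regression (inner) + validation error (outer)
n, d, eta = 30, 4, 0.01
A, y = torch.randn(n, d), torch.randn(n)   # training data
B, z = torch.randn(n, d), torch.randn(n)   # validation data
lam = torch.rand(d) + 0.5                  # outer variables: one penalty per feature

def Phi(w, lam):
    # fixed point map: one gradient descent step on 0.5*||Aw - y||^2 + 0.5*sum(lam * w^2)
    return w - eta * (A.T @ (A @ w - y) + lam * w)

def E(w, lam):
    # outer objective: validation loss (it happens not to depend on lam directly)
    return 0.5 * ((B @ w - z) ** 2).sum()

def inner_solution(lam):
    # exact fixed point of Phi: (A^T A + diag(lam))^{-1} A^T y
    return torch.linalg.solve(A.T @ A + torch.diag(lam), A.T @ y)

w = inner_solution(lam)

# Pieces of the hypergradient formula as dense vectors / matrices
d1E = jacobian(lambda v: E(v, lam), w)       # nabla_1 E, shape (d,)
d2E = jacobian(lambda l: E(w, l), lam)       # nabla_2 E, shape (d,), zeros here
d1Phi = jacobian(lambda v: Phi(v, lam), w)   # partial_1 Phi, shape (d, d)
d2Phi = jacobian(lambda l: Phi(w, l), lam)   # partial_2 Phi, shape (d, d)

x = torch.linalg.solve(torch.eye(d) - d1Phi.T, d1E)
hypergrad = d2E + d2Phi.T @ x

# Reference: backpropagate the validation loss through the closed-form solution
lam_ref = lam.clone().requires_grad_()
E(inner_solution(lam_ref), lam_ref).backward()
print(torch.allclose(hypergrad, lam_ref.grad))
```

The two results agree because the closed-form solution is an exact fixed point of the map; with an approximate inner solution the formula only gives an approximate hypergradient.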
See this notebook, where we show how to compute the hypergradient to optimize the regularization parameters of a simple logistic regression model.
examples/iMAML.py shows an implementation of the method described in the paper Meta-learning with implicit gradients. The code uses higher to get stateless versions of torch nn.Modules and torchmeta for meta-dataset loading and minibatching.
This notebook shows how to train a simple equilibrium network with "RNN-style" dynamics.
MORE EXAMPLES COMING SOON
Hypergradients are useful for:
- gradient-based hyperparameter optimization;
- meta-learning;
- training models that use an internal state (some types of RNNs and GNNs, Deep Equilibrium Networks, ...).
Requires Python 3 and PyTorch version >= 1.4.
git clone git@github.com:prolearner/hypertorch.git
cd hypertorch
pip install -e .
Alternatively, python setup.py install would also work.
The main methods for computing hypergradients are in the module hypergrad/hypergradients.py.
All methods require as input:
- a list of tensors representing the inner variables (models' weights);
- another list of tensors for the outer variables (hyperparameters/meta-learner parameters);
- a callable differentiable outer objective;
- a callable that represents the differentiable update mapping (except reverse_unroll). For example, this can be an SGD step.
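A minimal sketch of these four inputs in plain PyTorch, for a hypothetical ridge-regression task whose L2 strength is the hyperparameter. The (params, hparams) argument order of the two callables and the names inner_loss, outer_loss and fp_map are assumptions made for illustration; the exact signatures expected by hypertorch should be checked in the docstrings of hypergrad/hypergradients.py. Note create_graph=True, which keeps the SGD step differentiable with respect to both lists.

```python
import torch

torch.manual_seed(0)

# Hypothetical regression data, split into training and validation sets
X_tr, y_tr = torch.randn(50, 3), torch.randn(50)
X_val, y_val = torch.randn(30, 3), torch.randn(30)

def inner_loss(params, hparams):
    w, b = params
    (log_lam,) = hparams
    residual = X_tr @ w + b - y_tr
    # training loss plus an L2 penalty of strength exp(log_lam)
    return 0.5 * residual.pow(2).mean() + 0.5 * log_lam.exp() * w.pow(2).sum()

def outer_loss(params, hparams):
    # differentiable outer objective: validation loss (hparams unused here)
    w, b = params
    return 0.5 * (X_val @ w + b - y_val).pow(2).mean()

def fp_map(params, hparams, lr=0.1):
    # differentiable update mapping: one SGD step on the inner loss;
    # create_graph=True makes the returned tensors differentiable again
    grads = torch.autograd.grad(inner_loss(params, hparams), params, create_graph=True)
    return [p - lr * g for p, g in zip(params, grads)]

# Inner variables (model weights) and outer variables (hyperparameters)
params = [torch.zeros(3, requires_grad=True), torch.zeros(1, requires_grad=True)]
hparams = [torch.zeros((), requires_grad=True)]  # log of the L2 strength

# Approximately solve the inner problem by iterating the fixed point map;
# detaching between steps keeps memory constant (no unrolled graph is kept)
for _ in range(500):
    params = [p.detach().requires_grad_() for p in fp_map(params, hparams)]

# Near a fixed point, one more step barely changes the weights
step = fp_map(params, hparams)
residual = max((s - p).abs().max().item() for s, p in zip(step, params))
print(residual)
```

Because fp_map builds its step with create_graph=True, the output of one more step is still connected to hparams, which is what the hypergradient methods need to differentiate through.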
Iterative differentiation methods differentiate through the update dynamics used to solve the inner problem. This makes it possible to optimize the parameters of the inner solver, such as the learning rate and momentum.
Methods in this class are:
- reverse_unroll: computes the approximate hypergradient by unrolling the entire computational graph of the update dynamics for solving the inner problem. The method is essentially a wrapper for standard backpropagation. IMPORTANT NOTE: the weights must be non-leaf tensors obtained through the application of "PyTorch differentiable" update dynamics (do not use built-in optimizers!). NOTE 2: this method is memory hungry!
- reverse: computes the hypergradient as above but uses less memory. It uses the trajectory information and recomputes all other necessary intermediate variables in the backward pass. It requires the list of past weights and the list of callable update mappings applied during the inner optimization.
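The difference between the two methods can be sketched in plain PyTorch on a hypothetical ridge-regression problem where both the L2 strength lam and the step size lr are outer variables. The first part mimics reverse_unroll (keep the whole unrolled graph, then backpropagate once); the hypothetical function reverse_sketch mimics reverse by storing only detached iterates and recomputing one update at a time in the backward pass. This is not hypertorch's code: the names, the single update map shared by all steps and the data are assumptions. The two hypergradients should coincide up to round-off.

```python
import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

# Hypothetical data: training set for the inner problem, validation set for the outer one
n, d, T = 40, 4, 50
A, y = torch.randn(n, d), torch.randn(n)
B, z = torch.randn(n, d), torch.randn(n)

def update(w, lam, lr):
    # one differentiable gradient descent step on the L2-regularized training loss
    return w - lr * (A.T @ (A @ w - y) / n + lam * w)

def outer_loss(w):
    return 0.5 * ((B @ w - z) ** 2).mean()

lam = torch.tensor(0.1, requires_grad=True)
lr = torch.tensor(0.3, requires_grad=True)  # the inner learning rate is an outer variable too
hparams = [lam, lr]

# reverse_unroll-style: the graph of all T steps stays in memory, then one backward pass
w = torch.zeros(d)
history = [w]  # detached iterates w_0, ..., w_T for the memory-light variant
for _ in range(T):
    w = update(w, lam, lr)
    history.append(w.detach())
hg_unroll = torch.autograd.grad(outer_loss(w), hparams)

def reverse_sketch(history, hparams):
    # backward pass that recomputes each step from the stored iterate w_t
    w_T = history[-1].clone().requires_grad_()
    v = torch.autograd.grad(outer_loss(w_T), w_T)[0]  # dE/dw_T
    hg = [torch.zeros_like(h) for h in hparams]       # E has no direct dependence on hparams
    for w_t in reversed(history[:-1]):
        w_t = w_t.clone().requires_grad_()
        w_next = update(w_t, *hparams)                 # local graph of a single step
        g_w, *g_h = torch.autograd.grad(w_next, [w_t, *hparams], grad_outputs=v)
        hg = [acc + g for acc, g in zip(hg, g_h)]      # accumulate (d update/d hparams)^T v
        v = g_w                                        # propagate v through (d update/d w)^T
    return hg

hg_reverse = reverse_sketch(history, hparams)
print([float(g) for g in hg_unroll], [float(g) for g in hg_reverse])
```

The unrolled version keeps T steps of graph alive at once, while reverse_sketch only ever holds the graph of a single step plus the T + 1 stored iterates, trading extra recomputation for memory.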
Approximate implicit differentiation methods approximate the hypergradient equation directly by:
- using an approximate solution to the inner problem instead of the true one;
- computing an approximate solution to the linear system $(I - J)x^* = b$, where $J$ and $b$ are, respectively, the transpose of the Jacobian of the fixed point map and the gradient of the outer objective, both with respect to the inner variables and evaluated at the approximate solution to the inner problem.
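Writing $E$ for the outer objective, $\Phi$ for the fixed point map, $w$ for the approximate inner solution, $\lambda$ for the outer variables, and using subscripts 1 and 2 for derivatives with respect to the inner and outer arguments, the quantities above and the resulting hypergradient estimate read as follows. The last line is the exact hypergradient formula with the inverse matrix applied to $b$ replaced by the approximate solution $x$ of the linear system, so it costs one more vector-Jacobian product, this time with respect to $\lambda$:

```latex
\begin{aligned}
J &= \partial_1 \Phi(w, \lambda)^\top, \qquad b = \nabla_1 E(w, \lambda), \qquad (I - J)\,x \approx b, \\
\widehat{\nabla f}(\lambda) &= \nabla_2 E(w, \lambda) + \partial_2 \Phi(w, \lambda)^\top x .
\end{aligned}
```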
Since computing and storing $J$ is usually infeasible, these methods exploit torch.autograd to compute the product $Jx$ (a vector-Jacobian product of the fixed point map) efficiently, without ever forming $J$. Additionally, they do not require storing the trajectory of the inner solver, thus providing a potentially large memory advantage over iterative differentiation. However, these methods are not suited to optimizing the parameters of the inner solver, such as the learning rate.
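For instance, the product $Jx$ can be obtained with a single call to torch.autograd.grad, passing x as grad_outputs, so that $J$ is never materialized. The sketch below checks this on a hypothetical tanh fixed point map with "RNN-style" dynamics; the map, the sizes and the helper name transposed_jac_prod are illustrative assumptions. The dense Jacobian used for the comparison is affordable only at this tiny size.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

d = 6
M = 0.5 * torch.randn(d, d) / d ** 0.5  # hypothetical recurrent weight matrix
lam = torch.randn(d)                    # hypothetical outer variables (an input bias)

def Phi(w, lam):
    # hypothetical fixed point map: an RNN-style state update
    return torch.tanh(M @ w + lam)

def transposed_jac_prod(fp_map, w, lam, x):
    # J x with J = (d Phi / d w)^T, i.e. a vector-Jacobian product:
    # one backward pass, and no d x d matrix is ever built
    w = w.detach().requires_grad_()
    return torch.autograd.grad(fp_map(w, lam), w, grad_outputs=x)[0]

w = torch.randn(d)
x = torch.randn(d)
Jx = transposed_jac_prod(Phi, w, lam, x)

# Dense check: jacobian(...) has entries d Phi_i / d w_j, so J is its transpose
J_dense = jacobian(lambda v: Phi(v, lam), w).T
print(torch.allclose(Jx, J_dense @ x))
```

The cost of transposed_jac_prod is roughly that of one extra backward pass through the fixed point map, independently of how many inner variables there are.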
Methods in this class are:
- fixed_point: approximately solves the linear system by repeatedly applying the map $T(x) = Jx + b$. NOTE: this method is guaranteed to converge when the fixed point map, and consequently the map $T$, is a contraction.
- CG: approximately solves the linear system with the conjugate gradient method. IMPORTANT NOTE: $I - J$ must be symmetric and positive definite for this to work!
- CG_normal_eq: as above, but uses conjugate gradient on the normal equations (i.e. solves $(I - J)^\top (I - J)\,x = (I - J)^\top b$ instead), which also works when $I - J$ is not symmetric and positive definite (as long as it is invertible). NOTE: the cost per iteration can be much higher than for the other methods.
- stoch_AID: general method that can use any torch.optim.Optimizer to solve the linear system. See Algorithm 1 in Grazzi et al. 2022 for more details.
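A minimal sketch of the first two solvers, written from scratch in plain PyTorch rather than calling hypertorch: both touch $J$ only through vector-Jacobian products. The inner problem (L2-regularized logistic regression with one penalty per feature, solved by gradient descent) and all names are hypothetical. Because the fixed point map is a gradient step on a strictly convex loss, $I - J$ is symmetric positive definite, so CG applies, and with this step size $T$ is a contraction, so fixed_point applies too. Both results are compared with a dense solve.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

# Hypothetical data: logistic regression on the training set, squared error on validation
n, d, eta = 100, 5, 1.0
X, y = torch.randn(n, d), torch.randint(0, 2, (n,)).double()
Xv, yv = torch.randn(n, d), torch.randn(n)
lam = torch.full((d,), 0.5, requires_grad=True)  # outer variables: per-feature L2 strength

def Phi(w, lam):
    # gradient descent step on the L2-regularized logistic loss
    return w - eta * (X.T @ (torch.sigmoid(X @ w) - y) / n + lam * w)

def E(w):
    return 0.5 * ((Xv @ w - yv) ** 2).mean()

# Approximate inner solution: iterate the fixed point map without building a graph
with torch.no_grad():
    w = torch.zeros(d)
    for _ in range(300):
        w = Phi(w, lam)

w = w.requires_grad_()
w_mapped = Phi(w, lam)               # built once, reused by every product below
b = torch.autograd.grad(E(w), w)[0]  # gradient of E w.r.t. the inner variables

def J_prod(v):
    # J v = (d Phi / d w)^T v as a vector-Jacobian product
    return torch.autograd.grad(w_mapped, w, grad_outputs=v, retain_graph=True)[0]

def fixed_point_solve(b, iters=200):
    x = torch.zeros_like(b)
    for _ in range(iters):
        x = J_prod(x) + b            # x <- T(x) = J x + b
    return x

def cg_solve(b, iters=100, tol=1e-12):
    # conjugate gradient on (I - J) x = b; assumes I - J is symmetric positive definite
    x, r = torch.zeros_like(b), b.clone()
    p, rs = r.clone(), r @ r
    for _ in range(iters):
        Ap = p - J_prod(p)
        alpha = rs / (p @ Ap)
        x, r = x + alpha * p, r - alpha * Ap
        rs_new = r @ r
        if rs_new.sqrt() < tol:
            break
        p, rs = r + (rs_new / rs) * p, rs_new
    return x

x_fp, x_cg = fixed_point_solve(b), cg_solve(b)

# Hypergradient estimate: partial_2 Phi^T x (E has no direct dependence on lam here)
hg_cg = torch.autograd.grad(w_mapped, lam, grad_outputs=x_cg, retain_graph=True)[0]

# Dense reference for the linear system (affordable only because d is tiny)
J_dense = jacobian(lambda v: Phi(v, lam), w.detach()).T
x_ref = torch.linalg.solve(torch.eye(d) - J_dense, b)
print(torch.allclose(x_fp, x_ref), torch.allclose(x_cg, x_ref))
```

Building w_mapped once and calling torch.autograd.grad with retain_graph=True lets every solver iteration reuse the same graph, so each product costs one backward pass through a single application of the fixed point map.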
If you use this code, please cite our paper
@inproceedings{grazzi2020iteration,
title={On the Iteration Complexity of Hypergradient Computation},
author={Grazzi, Riccardo and Franceschi, Luca and Pontil, Massimiliano and Salzo, Saverio},
booktitle={Thirty-seventh International Conference on Machine Learning (ICML)},
year={2020}
}