aleximmer/Laplace

`torch.func` jacobian OOM

wiseodd opened this issue · 2 comments

Minimal example that reproduces this error:
https://gist.github.com/wiseodd/b29973cd96f96f3af620ca131571eaa4

Fix:

Use the `AsdlGGN` backend, and then in `_glm_predictive_distribution` (in `baselaplace.py`) use:

`Js, f_mu = self.backend.jacobians(X)`

instead of:

`Js, f_mu = self.backend.functorch_jacobians(X)`

I think this is caused by pytorch/functorch#1058.
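Separately from the backend question, it may help to see how large the Jacobian itself gets. The sketch below is only a back-of-the-envelope estimate, assuming `Js` is a dense tensor of shape `(batch, outputs, parameters)` stored in float32. The helper names and the example sizes are hypothetical and not part of the library, and the estimate ignores any intermediate buffers a backend allocates on top.

```python
def jacobian_bytes(batch_size: int, n_outputs: int, n_params: int,
                   bytes_per_elem: int = 4) -> int:
    """Bytes needed to store a dense (batch, outputs, params) Jacobian.

    bytes_per_elem defaults to 4, i.e. float32 (an assumption).
    """
    return batch_size * n_outputs * n_params * bytes_per_elem


def max_batch_size(budget_bytes: int, n_outputs: int, n_params: int,
                   bytes_per_elem: int = 4) -> int:
    """Largest batch whose Jacobian alone fits into budget_bytes."""
    per_sample = n_outputs * n_params * bytes_per_elem
    return budget_bytes // per_sample


if __name__ == "__main__":
    # Hypothetical sizes: 10 outputs, 1M parameters, batch of 128.
    size = jacobian_bytes(batch_size=128, n_outputs=10, n_params=1_000_000)
    print(f"Jacobian size: {size / 1024**3:.2f} GiB")

    # How many samples fit if 2 GiB is reserved for the Jacobian?
    fit = max_batch_size(2 * 1024**3, n_outputs=10, n_params=1_000_000)
    print(f"Max batch size within 2 GiB: {fit}")
```

If the estimate is already close to the available memory, predicting in smaller batches is worth trying regardless of which backend computes the Jacobians.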

For now, users will be warned via #202. Ultimately, this should be solved by #203.