google/jaxopt

Library tracer leak

NeilGirdhar opened this issue

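JaxOpt seems to leak one of its own tracers into a value that has to be static. In this case the value is the Broyden solver's history_size, which becomes the leading dimension of an array shape in init_history.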

Here's the traceback:

Traceback (most recent call last):
  File "/home/neil/src/efax/a.py", line 10, in <module>
    print(y.to_nat())
          ^^^^^^^^^^
  File "/home/neil/src/efax/efax/_src/mixins/exp_to_nat/exp_to_nat.py", line 45, in to_nat
    return self.search_to_natural(self.minimizer.solve(self))
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/efax/_src/mixins/exp_to_nat/jaxopt.py", line 23, in solve
    results = solver.run(exp_to_nat.initial_search_parameters())
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/base.py", line 359, in run
    return run(init_params, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/implicit_diff.py", line 251, in wrapped_solver_fun
    return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/implicit_diff.py", line 207, in solver_fun_flat
    return solver_fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/base.py", line 301, in _run
    state = self.init_state(init_params, *args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/broyden.py", line 250, in init_state
    d_history=init_history(init_params, self.history_size),
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/broyden.py", line 102, in init_history
    return tree_map(fun, pytree)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/broyden.py", line 101, in <lambda>
    fun = lambda leaf: jnp.zeros((history_size,) + leaf.shape, dtype=leaf.dtype)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 3846, in zeros
    shape = canonicalize_shape(shape)
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 100, in canonicalize_shape
    return core.canonicalize_shape(shape, context)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 1, 2).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function to_nat at /home/neil/src/efax/efax/_src/mixins/exp_to_nat/exp_to_nat.py:41 for jit. 
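The failure itself is generic JAX behavior. Under jit, a Python int argument becomes a tracer, and jnp.zeros cannot use a tracer as a dimension. The following minimal sketch reproduces the same TypeError without JaxOpt or efax. make_history is a hypothetical stand-in for the init_history helper in the traceback. static_argnums is the workaround the error message itself suggests.

```python
import jax
import jax.numpy as jnp


def make_history(params: jax.Array, history_size: int) -> jax.Array:
    # Hypothetical stand-in for the library's init_history: the leading
    # dimension comes from history_size, so it must be a concrete int.
    return jnp.zeros((history_size,) + params.shape, dtype=params.dtype)


params = jnp.ones((1, 2))

# Without static_argnums, jit traces history_size and the shape check fails.
traced_failed = False
try:
    jax.jit(make_history)(params, 3)
except TypeError as err:  # ConcretizationTypeError is also a TypeError
    traced_failed = True
    print("TypeError:", str(err).splitlines()[0])

# With static_argnums=1, history_size stays a Python int at trace time.
history = jax.jit(make_history, static_argnums=1)(params, 3)
print(history.shape)  # (3, 1, 2)
```

In the issue, the traced value appears inside the library rather than as a direct jit argument. For that reason, static_argnums at the call site may not be enough on its own.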

This can be reproduced by running the tests in this branch, or with this short program:

from typing import Any

import jax.numpy as jnp
from jax import jacrev
from tjax import JaxRealArray

from efax import BetaNP, Flattener


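def tn(flattener: Flattener[Any], flattened: JaxRealArray) -> Any:
    # Rebuild the expectation parameters from the flat array, then map them
    # to natural parameters (this is where the JaxOpt solver runs).
    y = flattener.unflatten(flattened)
    return y.to_nat()


x = BetaNP(jnp.asarray([[1.2, 2.3]]))
y = x.to_exp()
flattener, flattened = Flattener[Any].flatten(y)
# Differentiate twice with respect to the flattened parameters (a Hessian).
hessian = jacrev(jacrev(tn, argnums=1), argnums=1)(flattener, flattened)
print(hessian)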
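The error message notes that the failure happened while tracing to_nat for jit. One hypothesis, which this report does not confirm, is that the object carrying history_size crosses that jit boundary as a pytree, with the integer stored as a child. If so, the integer would arrive as a tracer.

The following sketch shows that mechanism in plain JAX. LeakyConfig, StaticConfig, and make_history are hypothetical names. Keeping the integer in the pytree's aux_data is one general way to keep a hyperparameter static. It is not JaxOpt's or efax's actual design.

```python
from typing import Any

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class LeakyConfig:
    """Hypothetical config that stores history_size as a pytree child."""

    def __init__(self, history_size: Any) -> None:
        self.history_size = history_size

    def tree_flatten(self):
        # history_size is a child (leaf), so jit replaces it with a tracer.
        return (self.history_size,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@register_pytree_node_class
class StaticConfig:
    """Hypothetical config that keeps history_size in the pytree aux_data."""

    def __init__(self, history_size: int) -> None:
        self.history_size = history_size

    def tree_flatten(self):
        # No children: history_size travels as static, hashable aux_data.
        return (), self.history_size

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


@jax.jit
def make_history(config: Any, params: jax.Array) -> jax.Array:
    # Same shape construction as in the traceback.
    return jnp.zeros((config.history_size,) + params.shape, dtype=params.dtype)


params = jnp.ones((1, 2))

leaky_failed = False
try:
    make_history(LeakyConfig(3), params)
except TypeError:
    leaky_failed = True  # the traced history_size cannot be a dimension

history = make_history(StaticConfig(3), params)
print(history.shape)  # (3, 1, 2)
```

Dataclass-based pytree libraries usually offer a per-field way to mark a value static, which amounts to the same aux_data treatment.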

I was able to get this working fine with Optimistix and my own Tjax, but I was hoping to see how JaxOpt fares.