Library tracer leak
NeilGirdhar opened this issue
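JaxOpt seems to leak one of its own tracers into a value that must be static. When the Broyden solver builds its history buffer in `init_history`, `history_size` (the leading dimension passed to `jnp.zeros`) arrives as a tracer instead of a concrete integer.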
Here's the traceback:
```
Traceback (most recent call last):
  File "/home/neil/src/efax/a.py", line 10, in <module>
    print(y.to_nat())
          ^^^^^^^^^^
  File "/home/neil/src/efax/efax/_src/mixins/exp_to_nat/exp_to_nat.py", line 45, in to_nat
    return self.search_to_natural(self.minimizer.solve(self))
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/efax/_src/mixins/exp_to_nat/jaxopt.py", line 23, in solve
    results = solver.run(exp_to_nat.initial_search_parameters())
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/base.py", line 359, in run
    return run(init_params, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/implicit_diff.py", line 251, in wrapped_solver_fun
    return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/implicit_diff.py", line 207, in solver_fun_flat
    return solver_fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/base.py", line 301, in _run
    state = self.init_state(init_params, *args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/broyden.py", line 250, in init_state
    d_history=init_history(init_params, self.history_size),
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/broyden.py", line 102, in init_history
    return tree_map(fun, pytree)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jaxopt/_src/broyden.py", line 101, in <lambda>
    fun = lambda leaf: jnp.zeros((history_size,) + leaf.shape, dtype=leaf.dtype)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 3846, in zeros
    shape = canonicalize_shape(shape)
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/efax/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 100, in canonicalize_shape
    return core.canonicalize_shape(shape, context)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 1, 2).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function to_nat at /home/neil/src/efax/efax/_src/mixins/exp_to_nat/exp_to_nat.py:41 for jit.
```
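The `static_argnums` hint in the message reflects the underlying rule: any value that ends up in an array shape must be concrete at trace time. Here is a minimal sketch of the same failure outside JaxOpt and efax. `init_history` is a hypothetical stand-in modeled on the lambda in the traceback, not JaxOpt's actual function.

```python
import jax
import jax.numpy as jnp


def init_history(x, history_size):
    # Hypothetical stand-in: history_size becomes part of a shape,
    # so it must be a concrete Python int while tracing.
    return jnp.zeros((history_size,) + x.shape, dtype=x.dtype)


x = jnp.ones((1, 2))

# By default every argument to a jitted function is traced, so history_size
# arrives as a traced int32 scalar and jnp.zeros rejects the shape.
leaky = jax.jit(init_history)
try:
    leaky(x, 3)
    leaked_error = None
except TypeError as err:  # JAX's shape errors are TypeError (or a subclass)
    leaked_error = err
    print("rejected:", type(err).__name__)

# Marking the argument static keeps it a Python int during tracing.
fixed = jax.jit(init_history, static_argnums=1)
print(fixed(x, 3).shape)  # (3, 1, 2)
```

The unmarked version fails inside `jnp.zeros` with a shape `TypeError`, which matches the bottom of the traceback. The static version returns a `(3, 1, 2)` array.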
This is reproducible by running the tests in this branch, or this simple program:
```python
from typing import Any

import jax.numpy as jnp
from jax import jacrev
from tjax import JaxRealArray

from efax import BetaNP, Flattener


def tn(flattener: Flattener[Any], flattened: JaxRealArray) -> Any:
    y = flattener.unflatten(flattened)
    return y.to_nat()


x = BetaNP(jnp.asarray([[1.2, 2.3]]))
y = x.to_exp()
flattener, flattened = Flattener[Any].flatten(y)
hessian = jacrev(jacrev(tn, argnums=1), argnums=1)(flattener, flattened)
print(hessian)
```
I was able to get this working fine with Optimistix and my own Tjax, but I was hoping to see how JaxOpt fares.
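One plausible route for `self.history_size` to become a tracer is the following, though it is an assumption that has not been checked against the JaxOpt or efax sources. The solver object may pass through the jitted `to_nat` as part of a pytree that registers its integer field as a leaf rather than as static auxiliary data. The sketch below uses two hypothetical classes, `LeakySolver` and `StaticSolver`, with `jax.tree_util.register_pytree_node` to show the difference.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class LeakySolver:
    """Hypothetical solver whose history_size is flattened as a pytree leaf."""

    def __init__(self, history_size):
        self.history_size = history_size

    def init_history(self, x):
        return jnp.zeros((self.history_size,) + x.shape, dtype=x.dtype)


class StaticSolver(LeakySolver):
    """Same hypothetical solver, but history_size is kept in static aux_data."""


# history_size as a child (leaf): jit replaces it with a tracer.
tree_util.register_pytree_node(
    LeakySolver,
    lambda s: ((s.history_size,), None),
    lambda aux, children: LeakySolver(*children),
)
# history_size as aux_data: jit keeps it as a hashable Python int,
# and it becomes part of the compilation cache key.
tree_util.register_pytree_node(
    StaticSolver,
    lambda s: ((), s.history_size),
    lambda aux, children: StaticSolver(aux),
)


@jax.jit
def run(solver, x):
    return solver.init_history(x)


x = jnp.ones((1, 2))
print(run(StaticSolver(3), x).shape)  # (3, 1, 2)

try:
    run(LeakySolver(3), x)
    leak_error = None
except TypeError as err:  # same shape error as in the traceback above
    leak_error = err
```

Under this hypothesis, the fix would be to treat configuration fields such as `history_size` as static wherever the solver is flattened.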