Blackjax has an implicit dependence on "jax>=0.4.25" and "jaxlib>=0.4.25"
ColCarroll opened this issue · 2 comments
ColCarroll commented
### Describe the issue as clearly as possible:
Noticed that `jax.tree.<whatever>` is now being used here, which was introduced in jax[lib] 0.4.25.
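For illustration, here is a minimal standard-library sketch of checking that 0.4.25 threshold at runtime. The names `MIN_JAX_FOR_TREE`, `release_tuple`, `supports_jax_tree`, and `installed_jax_supports_tree` are hypothetical. The parser only reads the leading numeric release segment, so it is a rough check rather than full PEP 440 ordering.

```python
import re
from importlib import metadata

# jax.tree.* first appears in jax/jaxlib 0.4.25 (per this issue).
MIN_JAX_FOR_TREE = (0, 4, 25)


def release_tuple(version: str) -> tuple:
    """Return the leading numeric release segment, e.g. '0.4.24' -> (0, 4, 24).

    Suffixes such as 'rc1' or '.dev20240301' are ignored, so this is a rough
    comparison, not full PEP 440 ordering.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def supports_jax_tree(version: str) -> bool:
    """True if a JAX release with this version string should provide jax.tree.*."""
    # Compare integer tuples, not strings: '0.4.3' < '0.4.25' only holds numerically.
    return release_tuple(version) >= MIN_JAX_FOR_TREE


def installed_jax_supports_tree() -> bool:
    """Check the installed 'jax' distribution's metadata without importing jax."""
    try:
        return supports_jax_tree(metadata.version("jax"))
    except metadata.PackageNotFoundError:
        return False
```

With the version from this report, `supports_jax_tree("0.4.24")` is `False`.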
### Steps/code to reproduce the bug:
Install an older version of JAX, as in https://colab.research.google.com/drive/1K5-nH6NXJY-KggtSzbt0_kudUfR0d_2V?usp=sharing; then any code that hits a `jax.tree...` call will throw. For example:
```python
import jax
import blackjax

blackjax.optimizers.lbfgs.minimize_lbfgs(lambda x: x*x, -1.)
```
### Expected result:
No exception.
### Error message:
```
AttributeError Traceback (most recent call last)
<ipython-input-3-6ef54887e135> in <cell line: 4>()
2 import blackjax
3
----> 4 blackjax.optimizers.lbfgs.minimize_lbfgs(lambda x: x*x, -1.)

/usr/local/lib/python3.10/dist-packages/blackjax/optimizers/lbfgs.py in minimize_lbfgs(fun, x0, maxiter, maxcor, gtol, ftol, maxls, **lbfgs_kwargs)
106
107 # Run LBFGS optimizer on flat input.
--> 108 last_step_raveled, history_raveled = _minimize_lbfgs(
109 lambda x: fun(unravel_fn(x)),
110 x0_raveled,

/usr/local/lib/python3.10/dist-packages/blackjax/optimizers/lbfgs.py in _minimize_lbfgs(fun, x0, maxiter, maxcor, gtol, ftol, maxls, **lbfgs_kwargs)
231 )
232 # Append initial state to history.
--> 233 history = jax.tree.map(
234 lambda x, y: jnp.concatenate([x[None, ...], y], axis=0),
235 initial_history,

/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py in getattr(name)
51 warnings.warn(message, DeprecationWarning, stacklevel=2)
52 return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")
54
55 return getattr

AttributeError: module 'jax' has no attribute 'tree'
```
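The failing frame is the `jax.tree.map` call in `lbfgs.py`. Besides raising the minimum version, one version-agnostic alternative is to fall back to `jax.tree_util.tree_map`, which older releases such as 0.4.24 already provide. This is a minimal sketch, assuming that fallback is acceptable; `tree_map` and `prepend_initial` are hypothetical names, and `prepend_initial` only mirrors the call shown in the traceback rather than reproducing Blackjax's code.

```python
import jax
import jax.numpy as jnp

# jax.tree.map only exists in jax/jaxlib >= 0.4.25, while jax.tree_util.tree_map
# also exists in older releases. On old JAX, accessing `jax.tree` raises
# AttributeError (see the traceback), which getattr's default turns into None.
_tree_module = getattr(jax, "tree", None)
tree_map = _tree_module.map if _tree_module is not None else jax.tree_util.tree_map


def prepend_initial(initial_history, history):
    """Prepend each leaf of `initial_history` as a new leading row of `history`.

    Hypothetical helper mirroring the jax.tree.map call at lbfgs.py line 233
    in the traceback; it is not Blackjax's implementation.
    """
    return tree_map(
        lambda x, y: jnp.concatenate([x[None, ...], y], axis=0),
        initial_history,
        history,
    )
```

The trade-off is carrying two spellings of the same operation; pinning `jax>=0.4.25` instead lets the code use `jax.tree.map` directly.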
### Blackjax/JAX/jaxlib/Python version information:
```
blackjax nightly
jax, jaxlib 0.4.24
```
### Context for the issue:
No response
junpenglao commented
I guess we can be explicit and update the dependency requirement?
ColCarroll commented
Thanks -- I'd been trying to figure out when to update bayeux to `jax.tree.foo`, and then realized I probably already have to!
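For a downstream package making that switch, the common renames can be sketched as a regex rewrite. This is a hypothetical helper, not a tool from either project. It assumes calls are written fully qualified as `jax.tree_util.tree_<name>` (so `from jax.tree_util import ...` imports are left alone), and it covers only a few functions (`map`, `leaves`, `flatten`, `unflatten`, `structure`) assumed to have `jax.tree` counterparts in 0.4.25. Migrating this way also raises the package's minimum JAX version to 0.4.25.

```python
import re

# Hypothetical rename table: jax.tree_util.tree_<name> -> jax.tree.<name>.
RENAMES = ("map", "leaves", "flatten", "unflatten", "structure")

# The trailing \b keeps longer names such as tree_map_with_path untouched.
_PATTERN = re.compile(r"\bjax\.tree_util\.tree_(" + "|".join(RENAMES) + r")\b")


def migrate_tree_calls(source: str) -> str:
    """Rewrite fully qualified jax.tree_util.tree_* calls to jax.tree.*."""
    return _PATTERN.sub(r"jax.tree.\1", source)
```

Because it is a plain text substitution, the result should still be reviewed and tested rather than applied blindly.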