Incompatibility with JAX 0.4.14
RodrigoAVargasHdz opened this issue · 5 comments
Hi,
I had to downgrade jax to 0.4.13 because the current version (0.4.14) is incompatible with the current release of distrax.
I get the following error:
File "//lib/python3.10/site-packages/distrax/_src/utils/jittable.py", line 36, in tree_flatten
switch = list(map(_is_jax_data, leaves))
File "//python3.10/site-packages/distrax/_src/utils/jittable.py", line 66, in _is_jax_data
jax.xla.abstractify(x)
File "//python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'xla'
Cheers,
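The last frame shows the `AttributeError` being raised from a `getattr` hook in `jax/_src/deprecations.py` when distrax touches `jax.xla`. Below is a minimal, self-contained sketch of that failure mode and of feature-detecting an attribute path before using it. It does not import the real JAX: `fake_jax` is a hypothetical stand-in module, and `has_attr_path` is a hypothetical helper. It is not the fix merged in #243.

```python
import types

# Hypothetical stand-in module named "jax". It is NOT the real JAX package;
# it only mimics the failure mode in the traceback, assuming the hook in
# deprecations.py acts like a module-level __getattr__ (PEP 562).
fake_jax = types.ModuleType("jax")


def _removed_attr_getattr(name):
    # Called only when normal attribute lookup on the module fails.
    raise AttributeError(f"module {fake_jax.__name__!r} has no attribute {name!r}")


fake_jax.__getattr__ = _removed_attr_getattr
# An attribute that still exists, for contrast (hypothetical).
fake_jax.tree_util = types.SimpleNamespace(tree_flatten=lambda tree: ([tree], None))


def has_attr_path(module, dotted_path):
    """Return True if every name in `dotted_path` resolves on `module`.

    hasattr() swallows the AttributeError raised by a module-level
    __getattr__, so a removed attribute is reported as missing instead of
    crashing the caller.
    """
    obj = module
    for name in dotted_path.split("."):
        if not hasattr(obj, name):
            return False
        obj = getattr(obj, name)
    return True


print(has_attr_path(fake_jax, "xla.abstractify"))         # False
print(has_attr_path(fake_jax, "tree_util.tree_flatten"))  # True
```

Accessing `fake_jax.xla` directly raises `AttributeError: module 'jax' has no attribute 'xla'`, the same message as in the traceback; probing with `hasattr` lets calling code choose a fallback instead.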
This is already fixed in #243 but not released yet.
I don't have the error when I use the master branch.
Hello!
Do you have any idea when the next version of distrax will be released? Several fixes related to changes in the JAX API are not included in 0.1.4.
Thank you very much in advance.
Bumping this thread: is there a release date for this fix?
I think we can close this, as distrax 0.1.5 includes the fix for these incompatibilities!
Note: the release has been published on GitHub but not yet on PyPI.
I've just released a new version on PyPI (https://pypi.org/project/distrax/0.1.5/).
Thanks for bumping this thread!
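With 0.1.5 on PyPI, a quick way to tell whether an environment is still in the broken combination discussed here is a version check. This is a hypothetical sketch: the helper names are made up, and the bounds (jax 0.4.14 or newer together with distrax older than 0.1.5) are an assumption drawn from this thread, not an official compatibility matrix. It uses only the standard library (`importlib.metadata.version`).

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    # Keep the leading digits of the first three dot-separated fields, so
    # "0.4.14rc1" -> (0, 4, 14). Not a full PEP 440 parser: pre-releases
    # compare equal to the final release.
    parts = []
    for field in text.split(".")[:3]:
        match = re.match(r"\d+", field)
        parts.append(int(match.group()) if match else 0)
    return tuple(parts)


def affected_by_xla_removal(jax_version: str, distrax_version: str) -> bool:
    # Assumption from this thread: jax 0.4.14 broke distrax 0.1.4 and
    # distrax 0.1.5 contains the fix. We also assume newer jax releases and
    # older distrax releases behave the same way.
    return (parse_version(jax_version) >= (0, 4, 14)
            and parse_version(distrax_version) < (0, 1, 5))


def installed_version(dist_name: str) -> Optional[str]:
    # Return None instead of raising when the distribution is not installed.
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    jax_v, distrax_v = installed_version("jax"), installed_version("distrax")
    if jax_v and distrax_v and affected_by_xla_removal(jax_v, distrax_v):
        print("Upgrade distrax to 0.1.5 or later, or pin jax==0.4.13")
```

If the check fires, the two workarounds from this thread apply: upgrade distrax (for example `pip install -U distrax`) or pin jax to 0.4.13.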