google-deepmind/distrax

Incompatibility with JAX 0.4.14

RodrigoAVargasHdz opened this issue · 5 comments

Hi,

I had to downgrade jax to 0.4.13 because the current version (0.4.14) is incompatible with the current release of distrax.

I get the following error:


  File "//lib/python3.10/site-packages/distrax/_src/utils/jittable.py", line 36, in tree_flatten
    switch = list(map(_is_jax_data, leaves))
  File "//python3.10/site-packages/distrax/_src/utils/jittable.py", line 66, in _is_jax_data
    jax.xla.abstractify(x)
  File "//python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'xla'

Cheers,
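The last two traceback frames show the AttributeError coming from a `getattr` hook in `jax/_src/deprecations.py` instead of from ordinary attribute lookup. The following is a minimal stand-in for that kind of hook, assuming a PEP 562 module-level `__getattr__`. It is not JAX's actual code. `make_module` and the `old_helper` name are hypothetical and exist only for illustration. The idea is that names still listed as deprecated get a warning, and anything else, such as a fully removed `xla`, raises the same message seen above.

```python
import types
import warnings
from typing import Any


def make_module(name: str, deprecated: dict[str, tuple[str, Any]]) -> types.ModuleType:
    """Hypothetical stand-in for a module with a PEP 562 deprecation hook."""
    module = types.ModuleType(name)

    def __getattr__(attr: str) -> Any:
        # Python calls this only after normal attribute lookup has failed.
        if attr in deprecated:
            message, value = deprecated[attr]
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return value
        # Names that were removed outright end up here as a plain AttributeError.
        raise AttributeError(f"module {name!r} has no attribute {attr!r}")

    # Storing the hook in the module's __dict__ is what PEP 562 looks for.
    module.__getattr__ = __getattr__
    return module


if __name__ == "__main__":
    fake_jax = make_module("jax", {"old_helper": ("old_helper is deprecated", len)})
    # Feature detection sees the removal instead of crashing.
    print(hasattr(fake_jax, "xla"))
    try:
        fake_jax.xla.abstractify  # mirrors the failing access in the traceback
    except AttributeError as err:
        print(err)
```

Calling code that probes with `hasattr` or `getattr(module, name, None)` can fall back gracefully. Code that accesses `jax.xla` directly, as in the traceback, fails as soon as the attribute is gone.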

rdaems commented

This is already fixed in #243 but not released yet.
I don't get the error when using the master branch.

Hello!

Do you know when the next version of distrax will be released?
Several fixes for the recent changes in the JAX API are not included in 0.1.4.

Thank you very much in advance.

Bumping this thread: any idea of a release date for this fix?

I think we can close this, as distrax 0.1.5 includes the fixes for these incompatibilities!

Note: the release has been published on GitHub but not yet on PyPI.

I've just released a new version on PyPI (https://pypi.org/project/distrax/0.1.5/).
Thanks for bumping this thread!
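The thread boils down to a version rule: jax 0.4.13 works with distrax 0.1.4, jax 0.4.14 does not, and distrax 0.1.5 carries the fix. Below is a small sketch of a check for that rule. It assumes the breakage applies to every jax release from 0.4.14 onward, and it treats pre-release tags such as `0.4.14rc1` as the plain release. The helpers `parse_version` and `distrax_is_compatible` are hypothetical and are not part of either library. Installed versions are read with the standard `importlib.metadata.version`.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Version boundaries taken from this thread; assumed to hold for later jax releases too.
JAX_FIRST_BREAKING = (0, 4, 14)
DISTRAX_FIRST_FIXED = (0, 1, 5)


def parse_version(text: str) -> tuple[int, ...]:
    """Turn '0.4.14' (or '0.4.14rc1') into (0, 4, 14); suffixes are ignored."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def distrax_is_compatible(jax_version: str, distrax_version: str) -> bool:
    """Return False for the jax >= 0.4.14 / distrax < 0.1.5 combination."""
    if parse_version(jax_version) < JAX_FIRST_BREAKING:
        return True
    return parse_version(distrax_version) >= DISTRAX_FIRST_FIXED


if __name__ == "__main__":
    try:
        jax_v, distrax_v = version("jax"), version("distrax")
    except PackageNotFoundError:
        print("jax and/or distrax is not installed")
    else:
        ok = distrax_is_compatible(jax_v, distrax_v)
        print(f"jax {jax_v} + distrax {distrax_v}: {'OK' if ok else 'upgrade distrax'}")
```

Upgrading distrax to 0.1.5 is generally preferable to pinning jax at 0.4.13. The latter was only a stopgap while the fix was unreleased.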