Change to jax.interpreters.xla for JAX==0.4.14
kmheckel opened this issue · 3 comments
Hi,
Just opening this issue to raise awareness of changes that will be needed for JAX 0.4.14. I would be happy to try to make some of the updates, but this would be the first time I've ever submitted a pull request. Thanks!
In JAX 0.4.14, the following APIs have been removed after a deprecation period:
jax.ad: use jax.interpreters.ad.
jax.curry: use curry = lambda f: partial(partial, f).
jax.partial_eval: use jax.interpreters.partial_eval.
jax.pxla: use jax.interpreters.pxla.
jax.xla: use jax.interpreters.xla.
jax.ShapedArray: use jax.core.ShapedArray.
jax.interpreters.pxla.device_put: use jax.device_put().
jax.interpreters.pxla.make_sharded_device_array: use jax.make_array_from_single_device_arrays().
jax.interpreters.pxla.ShardedDeviceArray: use jax.Array.
jax.numpy.DeviceArray: use jax.Array.
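The `jax.curry` replacement above uses only `functools.partial`, so it can be checked without JAX installed. Here is a minimal sketch that writes the suggested lambda as an equivalent named function. `scale_and_shift` is a hypothetical function used only to illustrate the calling pattern:

```python
from functools import partial


def curry(f):
    """Stand-in for the removed jax.curry; equivalent to partial(partial, f).

    curry(f)(*args) returns partial(f, *args): the first call binds the
    leading arguments, and the second call supplies the rest.
    """
    return partial(partial, f)


def scale_and_shift(scale, shift, x):
    """Hypothetical example function: returns scale * x + shift."""
    return scale * x + shift


# Bind scale=2 and shift=1 now; supply x later.
double_plus_one = curry(scale_and_shift)(2, 1)
print(double_plus_one(5))  # 2 * 5 + 1 = 11
```

Existing call sites such as `jax.curry(f)(a, b)` can then become `curry(f)(a, b)` without any other changes.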
Hi @kmheckel , thank you for letting us know. I believe that at HEAD we have made all of these required changes.
In general, since Haiku is developed inside Google's monorepo (and mirrored live to GitHub), these sorts of renames are typically handled automatically for us.
Are you having issues with Haiku + JAX 0.4.14? Can you paste the error message if so?
@tomhennigan Sorry about the delay; dm-haiku 0.0.10 fixed my issue. When I tried updating Haiku right before hitting the error, PyPI only had 0.0.9, which didn't include the fix.
Thanks!
Thanks for letting us know 😄