google-deepmind/dm-haiku

DeprecationWarning in `transform_with_state` because of `jax.xla` vs. `jax.interpreters.xla`


Hi, when calling `hk.transform_with_state`, Haiku internally accesses `jax.xla`, which is deprecated in favor of `jax.interpreters.xla`. The access happens in the check that the function `f` passed to Haiku has not already been transformed by JAX.

The culprit is `check_not_jax_transformed` (at least, that's as far as I've looked into the code; there may be other references).
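For context, a deprecated module attribute like this typically warns through a module-level `__getattr__` (PEP 562) that fires only when the old name is looked up. The sketch below mimics that pattern with a hypothetical stand-in module called `fakejax`. It is not JAX's actual implementation, and `old_style_check` / `new_style_check` are illustrative names, not Haiku functions. It shows that reading the old path emits a `DeprecationWarning` while reading the new path is silent:

```python
import types
import warnings

# Hypothetical stand-in module; this is NOT jax's code, only the PEP 562 pattern.
fakejax = types.ModuleType("fakejax")
fakejax.interpreters = types.SimpleNamespace(xla="xla module (new location)")


def _module_getattr(name):
    # Called only when normal attribute lookup on the module fails.
    if name == "xla":
        warnings.warn(
            "fakejax.xla is deprecated; use fakejax.interpreters.xla instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return fakejax.interpreters.xla
    raise AttributeError(f"module 'fakejax' has no attribute {name!r}")


# Storing __getattr__ in the module's namespace enables PEP 562 lookup.
fakejax.__getattr__ = _module_getattr


def old_style_check():
    return fakejax.xla  # deprecated path: emits a DeprecationWarning


def new_style_check():
    return fakejax.interpreters.xla  # new path: no warning


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning, even repeats
    old_style_check()
    new_style_check()

print(len(caught), caught[0].category.__name__)  # 1 DeprecationWarning
```

Under this assumption, switching the internal reference to the new path is enough to stop the warning.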

Could this be updated? That would help silence my test output :).
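Until an upstream fix lands, one possible workaround is to suppress only this specific warning around the affected calls while leaving other deprecations visible. This is a minimal sketch: `call_without_jax_xla_warning` is a hypothetical helper, and the message regex is an assumption, since the exact warning text is not quoted above. Adjust it to match what your test output actually prints.

```python
import warnings

# Assumption: the real warning text mentions "jax.xla"; adjust the regex
# to match the message actually printed in your test output.
_JAX_XLA_PATTERN = r".*jax\.xla"


def call_without_jax_xla_warning(fn, *args, **kwargs):
    """Hypothetical helper: run fn, ignoring only the jax.xla DeprecationWarning."""
    with warnings.catch_warnings():  # filters are restored on exit
        warnings.filterwarnings(
            "ignore", message=_JAX_XLA_PATTERN, category=DeprecationWarning
        )
        return fn(*args, **kwargs)


def noisy():
    # Stand-in for code that triggers both the targeted and an unrelated warning.
    warnings.warn("jax.xla is deprecated", DeprecationWarning)
    warnings.warn("some other API is deprecated", DeprecationWarning)
    return 42
```

Filtering by message rather than by category alone keeps unrelated deprecation warnings visible in the test output.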