DeprecationWarning `transform_with_state` because of `jax.xla` vs `jax.interpreters.xla`
joeryjoery opened this issue · 0 comments
joeryjoery commented
Hi, when calling `hk.transform_with_state`, there is an internal access to `jax.xla`, which is marked as deprecated in favor of `jax.interpreters.xla`. The problem is in the check that the function `f` provided to the Haiku function is not JAX-transformed.
I.e., the culprit is `check_not_jax_transformed` (at least, that's how far I've looked into the code; there may be more references).
Could this be updated? That would help silence my test output :).
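
In the meantime, a possible workaround on the test side is to filter this one warning with the standard `warnings` module. The snippet below is only a sketch: `ASSUMED_PATTERN` is a guess at the warning text (check it against the actual message in your test output), and `emit_simulated_warning` and `call_quietly` are hypothetical names used here so the example runs without JAX or Haiku installed.

```python
import warnings

# ASSUMPTION: the real warning text mentions "jax.xla". Adjust this regex to
# match the message actually printed in your test output.
ASSUMED_PATTERN = r".*jax\.xla.*"


def emit_simulated_warning():
    """Hypothetical stand-in for a call that triggers the deprecation."""
    warnings.warn("jax.xla is deprecated", DeprecationWarning, stacklevel=2)
    return "ok"


def call_quietly(fn, *args, **kwargs):
    """Call fn while ignoring only DeprecationWarnings matching ASSUMED_PATTERN."""
    with warnings.catch_warnings():
        # The filter is local to this block; previous filters are restored on exit.
        warnings.filterwarnings(
            "ignore", message=ASSUMED_PATTERN, category=DeprecationWarning
        )
        return fn(*args, **kwargs)
```

With the real libraries this would be used as something like `call_quietly(hk.transform_with_state, f)`. Matching on the message keeps other deprecation warnings visible, so the filter does not hide unrelated problems.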