dfm/extending-jax

Updates to support jax>0.4.14


@dfm Firstly, thank you for the write-up, which has been very helpful to read through!

I've noticed that the code no longer runs on JAX versions > 0.4.14 (the latest release is currently 0.4.19). Having worked through the issues (which are all fairly minor), I would be happy to submit a PR to update the repo.

Briefly, these consist of:

  • the deprecated `register_cpu_custom_call_target` function
  • changes to the API and return values of the JAX helper `custom_call` function
  • `ShapedArray` must now be imported from `jax.core`
  • I would also like to pin the installed JAX version, since the Colab notebook does not currently run (and the JAX API seems to be continuing to evolve)
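For the first point, something along these lines should keep the registration working across jaxlib versions. This is a minimal sketch, assuming `xla_client` is `jax.lib.xla_client` and that newer jaxlib exposes `register_custom_call_target(name, fn, platform="cpu")` in place of the CPU-specific helper. `register_cpu_targets` is a hypothetical name, and `registrations` stands in for the name-to-capsule mapping the compiled extension exports.

```python
from typing import Any, Mapping


def register_cpu_targets(xla_client: Any, registrations: Mapping[str, Any]) -> str:
    """Register each CPU custom-call target with XLA (hypothetical helper).

    Assumes `xla_client` is `jax.lib.xla_client` and `registrations` maps
    op names to the PyCapsules exported by the compiled extension.
    Returns the name of the registration function that was used.
    """
    if hasattr(xla_client, "register_custom_call_target"):
        # Newer jaxlib: one generic entry point, with the backend
        # selected by the `platform` argument.
        for name, target in registrations.items():
            xla_client.register_custom_call_target(name, target, platform="cpu")
        return "register_custom_call_target"
    # Older jaxlib: fall back to the deprecated CPU-only helper.
    for name, target in registrations.items():
        xla_client.register_cpu_custom_call_target(name, target)
    return "register_cpu_custom_call_target"
```

This would be called once when the extension is loaded, e.g. `register_cpu_targets(xla_client, cpu_ops.registrations())`, assuming `cpu_ops` is the compiled extension module. It detects the available function with `hasattr` rather than comparing version strings, so it does not depend on knowing exactly which release made the change.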