IvanYashchuk/jax-fenics-adjoint

JIT is breaking because of numpy operations

aravindshaj opened this issue · 2 comments

Hi @IvanYashchuk. Thank you for creating this amazing interface between JAX and FEniCS! I am getting the following error and was hoping you could help me with it.

Error occurs when I run my code with @jit
[Screenshot of the error, 2021-07-15]

Please let me know if you need any more information.

Hello, @aravindshaj! Unfortunately, jax.jit is not supported yet. Previously, making it work would have required an XLA CustomCall. Now it seems it wouldn't be difficult to make arbitrary Python functions work under jax.jit using jax.experimental.host_callback, but I haven't implemented it.
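A minimal sketch of the idea, assuming a recent JAX release. There, jax.experimental.host_callback is deprecated in favour of jax.pure_callback, so the sketch uses pure_callback instead. `numpy_solve` is a hypothetical stand-in for a FEniCS solve (any NumPy-only code JAX cannot trace). This is not the jax-fenics API. The callback runs on the host with concrete arrays, and JAX only needs to know the output's shape and dtype to compile around it.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_solve(x):
    # Hypothetical stand-in for a FEniCS solve: plain NumPy code that
    # JAX cannot trace. Here it just computes x**2 + 1 elementwise.
    x = np.asarray(x)
    return (x**2 + 1.0).astype(x.dtype)


def solve(x):
    # Tell JAX the shape/dtype of the result so it can compile around the call.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # pure_callback calls numpy_solve on the host with concrete arrays.
    return jax.pure_callback(numpy_solve, out_spec, x)


@jax.jit
def f(x):
    return jnp.sum(solve(x))


x = jnp.array([1.0, 2.0, 3.0])
print(f(x))  # sum of [2, 5, 10] = 17.0
```

This covers only the forward pass. Differentiating through the callback needs a user-supplied derivative rule.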

Hi @IvanYashchuk . Thank you for the quick reply. Would jax.jit be supported on the jax-fenics package?

Thank you for the recommendation. I am not very familiar with JAX and couldn't work out how to use jax.experimental.host_callback. Could you show a very simple code sample of how to integrate jax.jit?
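A minimal sketch of one possible approach, assuming a recent JAX release (using jax.pure_callback, the replacement for the deprecated host_callback). Here an external NumPy "solver" is made both jit-compatible and differentiable by pairing it with jax.custom_vjp. `numpy_forward` and `numpy_vjp` are hypothetical placeholders. In a real FEniCS setting the backward function would have to compute the adjoint (vector-Jacobian product) of the solve. This is not how jax-fenics itself is implemented.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_forward(x):
    # Hypothetical NumPy-only "solver": y = x**2 + 1, elementwise.
    x = np.asarray(x)
    return (x**2 + 1.0).astype(x.dtype)


def numpy_vjp(x, g):
    # Hand-written vector-Jacobian product of numpy_forward: dy/dx = 2x.
    x, g = np.asarray(x), np.asarray(g)
    return (2.0 * x * g).astype(x.dtype)


@jax.custom_vjp
def solve(x):
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(numpy_forward, spec, x)


def solve_fwd(x):
    # Return the output plus the residual (x) needed by the backward pass.
    return solve(x), x


def solve_bwd(x, g):
    # Run the hand-written VJP on the host as well.
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return (jax.pure_callback(numpy_vjp, spec, x, g),)


solve.defvjp(solve_fwd, solve_bwd)


@jax.jit
def loss(x):
    return jnp.sum(solve(x))


x = jnp.array([1.0, 2.0, 3.0])
print(loss(x))            # 17.0
print(jax.grad(loss)(x))  # [2. 4. 6.]
```

The custom_vjp wrapper is needed because JAX cannot see inside the host callback. The derivative must be supplied explicitly, which is the role an adjoint solve plays for FEniCS.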