JIT is breaking because of numpy operations
aravindshaj opened this issue · 2 comments
Hi @IvanYashchuk. Thank you for creating this amazing interface between JAX and FEniCS! I am receiving the following error and was hoping you could help me out with it.
The error occurs when I run my code with @jit.
Please let me know if you require any more information.
Hello, @aravindshaj! Unfortunately, `jax.jit` is not supported yet. Previously, making it work required an XLA CustomCall. Now it seems it would not be difficult to add support for arbitrary Python functions under `jax.jit` using `jax.experimental.host_callback`, but I haven't implemented it.
Hi @IvanYashchuk. Thank you for the quick reply. Will jax.jit be supported in the jax-fenics package?
Thank you for the recommendation. I am not very familiar with JAX and wasn't able to work out how to use jax.experimental.host_callback. Would you be able to share a very basic code sample showing how I can integrate jax.jit?
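
For reference, here is a minimal sketch of the general idea: calling an arbitrary Python/NumPy function from inside a jitted function through a host callback. It is not the jax-fenics implementation. It assumes a recent JAX release and uses `jax.pure_callback`, since `jax.experimental.host_callback` has been deprecated in newer versions. The function `solve_on_host` is a hypothetical stand-in for the FEniCS solve, and `jitted_fn` is an illustrative name.

```python
import jax
import jax.numpy as jnp
import numpy as np


def solve_on_host(x):
    # Hypothetical stand-in for a FEniCS solve: ordinary NumPy code that
    # JAX cannot trace. It runs on the host with concrete array values.
    x = np.asarray(x)
    return (np.sin(x) * 2.0).astype(x.dtype)


@jax.jit
def jitted_fn(x):
    # JAX cannot infer the output of opaque Python code, so the shape and
    # dtype of the callback result are declared up front.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = jax.pure_callback(solve_on_host, result_shape, x)
    # The rest of the function is regular, traceable JAX code.
    return jnp.sum(y)


x = jnp.arange(3, dtype=jnp.float32)
result = jitted_fn(x)
print(result)
```

This only covers the forward evaluation. Differentiating through the callback would additionally need a custom derivative rule (for example via `jax.custom_vjp`) wrapped around it.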