qutip/qutip-jax

Use `jax.lax`?


@quantshah
I have been using the `jax.numpy` interface to add specialisations quickly, but some of our functions are not optimal with this interface. `.dag()` calls `transpose` and then `conj`, looping over the data twice. The `scale` in `add` is applied in an extra loop, and so on. In the unary operations, I jitted functions whose loops could be fused. Should I try to use `jax.lax` instead, or should I be more aggressive with jitting functions? `add` and `matmul` could also benefit from jitting because of the `scale` argument.
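A minimal sketch of the two options, assuming dense 2-D JAX arrays. `dag`, `dag_lax`, `add` and `matmul` here are hypothetical standalone functions, not qutip-jax's actual API. Under `jax.jit`, the whole body is traced into a single XLA computation, so XLA can fuse the transpose and conjugate (or the scaling and addition) instead of running separate passes. The `jax.lax` variant of `dag` expresses the same computation with primitives directly.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def dag(data):
    # jnp version: transpose + conj, traced together so XLA can fuse them
    return data.T.conj()


@jax.jit
def dag_lax(data):
    # Same computation written with jax.lax primitives (meant for complex arrays)
    return lax.conj(lax.transpose(data, (1, 0)))


@jax.jit
def add(left, right, scale):
    # Scaling and addition are traced together, so they can be fused
    return left + scale * right


@jax.jit
def matmul(left, right, scale):
    # XLA sees the product and the scaling together; whether it fuses them
    # depends on the backend
    return scale * (left @ right)


x = jnp.array([[1 + 2j, 3 - 1j], [0.5j, 4 + 0j]], dtype=jnp.complex64)
print(dag(x))
```

To check what was actually fused, `jax.jit(dag).lower(x).compile().as_text()` prints the optimised HLO for a given input.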

I think we can be a bit more aggressive with JIT at this point and return to `jax.lax` in a later iteration, while keeping track of the places where we could use `jax.lax` directly. Later, for JIT-compiling conditionals or for loops, we can use `lax` directly, but for now it's probably better to keep everything within `jnp`.
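A sketch of where `lax` becomes necessary later: inside a jitted function, a Python `if` or `range` on a traced value fails, while `lax.cond` and `lax.fori_loop` work. `unit` and `matrix_power` are hypothetical helpers for illustration, not functions from qutip-jax.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def unit(data):
    # A Python `if norm > 0:` would raise under jit because `norm` is a tracer;
    # lax.cond selects the branch at run time instead.
    norm = jnp.linalg.norm(data)
    return lax.cond(norm > 0, lambda d: d / norm, lambda d: d, data)


@jax.jit
def matrix_power(data, n):
    # `n` is traced, so `for _ in range(n)` cannot be used; fori_loop can.
    def body(_, acc):
        return acc @ data

    identity = jnp.eye(data.shape[0], dtype=data.dtype)
    return lax.fori_loop(0, n, body, identity)
```

Both branches of `lax.cond` must return arrays of the same shape and dtype, so the zero-norm branch returns `data` unchanged rather than a new array.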