Use `jax.lax`?
@quantshah
I have been using the `jax.numpy` interface to quickly add specialisations, but some of our functions are not optimal with this interface. For example, `.dag()` calls `transpose` and `conj`, looping over the data twice; the `scale` in `add` is applied in an extra loop; and so on. In the unary operations I jitted functions that had loops that could be fused. Should I try to use `jax.lax` instead, or should I be more aggressive with jitting functions? `add` and `matmul` could also benefit from jitting because of the `scale` entry.
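As a minimal sketch of the "more aggressive jitting" option: the functions below are hypothetical stand-ins (`dag` and `add_scaled` are illustrative names, not the project's actual implementations). Written with plain `jnp` ops and wrapped in `jax.jit`, XLA is free to fuse the transpose and conjugate, or the multiply and add, into a single pass. Whether fusion actually happens is a compiler decision; the compiled HLO can be inspected with `jax.jit(dag).lower(x).compile().as_text()`.

```python
import jax
import jax.numpy as jnp


def dag(x):
    # Conjugate transpose written as two separate jnp operations.
    return jnp.conj(jnp.transpose(x))


# Under jit, XLA sees both operations at once and may fuse them,
# avoiding a second pass over the data.
dag_jit = jax.jit(dag)


def add_scaled(a, b, scale=1.0):
    # Hypothetical "add with scale": a + scale * b.
    # Jitted, the multiply and add can be fused into one kernel.
    return a + scale * b


add_scaled_jit = jax.jit(add_scaled)

x = jnp.array([[1 + 2j, 3 - 1j], [0 + 1j, 2 + 0j]])
y = jnp.array([[1 + 0j, 0 + 0j], [0 + 0j, 1 + 0j]])

print(dag_jit(x))
print(add_scaled_jit(x, y, 2.0))
```

Keeping everything in `jnp` this way leaves the code readable while still giving XLA the whole expression to optimise.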
I think we can be a bit more aggressive with JIT at this point and revisit `jax.lax` in a later iteration, while keeping track of the places where we could use `jax.lax` directly. For JIT-ing conditionals or for loops later on, we can use `lax` directly, but at this point it's probably better to keep everything within `jnp`.
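For reference when that later iteration comes, here is a sketch of what "using `lax` for conditionals and loops" typically looks like under `jit`. The functions `repeated_matmul` and `maybe_dag` are hypothetical examples, not existing project code. A Python `for` over a traced bound or an `if` on a traced boolean cannot be traced by `jax.jit`, whereas `lax.fori_loop` and `lax.cond` can.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def repeated_matmul(a, n):
    # Compute a^n with lax.fori_loop; n may be a traced value under jit.
    def body(i, acc):
        return acc @ a

    identity = jnp.eye(a.shape[0], dtype=a.dtype)
    return lax.fori_loop(0, n, body, identity)


@jax.jit
def maybe_dag(x, conjugate):
    # lax.cond picks a branch from a traced boolean.
    # Both branches must return the same shape and dtype, so x is assumed square.
    return lax.cond(conjugate, lambda y: jnp.conj(y.T), lambda y: y, x)


a = jnp.array([[1.0, 1.0], [0.0, 1.0]])
print(repeated_matmul(a, 3))  # upper-triangular matrix with 3 in the corner

x = jnp.array([[1 + 2j, 3 - 1j], [0 + 1j, 2 + 0j]])
print(maybe_dag(x, True))
```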