JIT compilation using Qobj
The benefit of using JAX is the ability to JIT compile. With the current setup, it is not clear what the best way is to make JAX recognize QuTiP objects as valid inputs, since jax.jit only accepts array-like values (or pytrees of them) as inputs and outputs. There are workarounds, e.g., https://github.com/google/jax/blob/cc13fd1e5892a08f5360db933d4dfd64c0fc66eb/jax/experimental/lapax.py#L164. The alternative is to pass around data._jxa instead of quantum objects:
```python
import jax
import qutip_jax  # makes the "jax" data layer available to .to("jax")
from qutip import sigmax, sigmay


def test_jit_from_jxa():
    """Test JIT of add using the _jxa array"""
    @jax.jit
    def func():
        return sigmax().to("jax").data._jxa + sigmay().to("jax").data._jxa
    # jax.Array replaces jax.interpreters.xla.DeviceArray in recent JAX versions
    assert isinstance(func(), jax.Array)


def test_jit_from_qobj():
    """Test JIT of add directly using Qobj"""
    @jax.jit
    def func():
        return sigmax() + sigmay()  # returns a Qobj, not a JAX array
    assert isinstance(func(), jax.Array)
```
Qobj are designed to work with multiple, mixed data types, and it's not clear that JAX can trace the mathematical operation inside a dispatched function. The dispatcher itself is compiled with Cython.
We could try to register Qobj as a pytree node when importing qutip-jax.
Giving Qobj a way to extract the specializations and skip the dispatcher will probably also be required.
We should be able to make it work with .data.
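A minimal sketch of what pytree registration could look like, assuming a Qobj can be split into a JAX array (the traced leaf) and hashable metadata such as dims (static auxiliary data). MiniQobj is a hypothetical stand-in, not QuTiP's class; a real Qobj carries more metadata, which would also have to go into the auxiliary data.

```python
import jax
import jax.numpy as jnp


class MiniQobj:
    """Hypothetical stand-in for a Qobj: an array plus static metadata."""

    def __init__(self, data, dims):
        self.data = data  # JAX array: the pytree leaf that JAX traces
        self.dims = dims  # hashable metadata: stored as auxiliary data

    def __add__(self, other):
        return MiniQobj(self.data + other.data, self.dims)


def _flatten(q):
    # Return (children to trace, static auxiliary data).
    return (q.data,), q.dims


def _unflatten(dims, children):
    return MiniQobj(children[0], dims)


jax.tree_util.register_pytree_node(MiniQobj, _flatten, _unflatten)


@jax.jit
def add(a, b):
    return a + b  # MiniQobj in, MiniQobj out


sx = MiniQobj(jnp.array([[0.0, 1.0], [1.0, 0.0]]), ((2,), (2,)))
sz = MiniQobj(jnp.array([[1.0, 0.0], [0.0, -1.0]]), ((2,), (2,)))
out = add(sx, sz)
```

This works here only because MiniQobj.__add__ calls jax.numpy directly; with the real Qobj the operation would still go through the data-layer dispatcher, which is the open question in this thread.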
Thank you @Ericgig and @quantshah for making this repo and pushing this forward!! This is amazing!
I just want to add a small bit of my understanding of JAX JIT. It seems to me that JIT (and jax.grad etc.) works as long as the function to be jitted only takes inputs and returns outputs that JAX supports, e.g. jnp.array, and the input-dependent part only uses JAX-supported operations. Any other complexity in building this function doesn't matter, as long as JAX can trace the Python function and translate it to XLA code at run time. So in principle, the current implementation should be sufficient as long as we are a little careful with the functions we jit.
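The same principle can be illustrated without QuTiP. In this sketch the matrices are built with plain NumPy inside the function, and only the part that depends on the input uses jax.numpy, so both jax.jit and jax.grad go through. The function name and the Hamiltonian H = a·σx + σz are illustrative choices.

```python
import numpy as np
import jax
import jax.numpy as jnp


def ground_energy(a):
    # Plain NumPy constants: they do not depend on `a`, so during tracing
    # they are simply embedded in the compiled program.
    sx = np.array([[0.0, 1.0], [1.0, 0.0]])
    sz = np.array([[1.0, 0.0], [0.0, -1.0]])
    # The input-dependent part uses only jax.numpy operations.
    H = a * jnp.asarray(sx) + jnp.asarray(sz)
    return jnp.linalg.eigvalsh(H)[0]  # eigenvalues come in ascending order


energy = jax.jit(ground_energy)
denergy = jax.jit(jax.grad(ground_energy))

e = energy(0.75)   # exact: -sqrt(a**2 + 1) = -1.25
g = denergy(0.75)  # exact: -a / sqrt(a**2 + 1) = -0.6
```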
The following JIT example works fine for me with the current master branch.
```python
import qutip
import qutip_jax
import jax


@jax.jit
def fun(a):
    M = qutip.sigmay().to("jax")
    N = qutip.sigmaz().to("jax")
    # Return the underlying JAX array: a Qobj is not a valid jit output.
    return (a * M.conj() * N + N).data._jxa


fun(3.)
```
This looks sufficient to me. The input is just numbers and the output is a JAX array.
Maybe instead of making the whole Qobj compatible with JAX JIT, we just need an additional wrapper that converts the final JAX array to a Qobj?
We need to define what is sufficient.
Are Qobj as output only enough?
In the example, it fails if we don't manually convert to jax. Do we want this to be automatic?
Do we want only operations to be supported, or should it work with ptrace or eigenenergies also?
> Are Qobj as output only enough?
No. This is just an example showing that even if you cannot directly jit a function that returns a Qobj, you can still define a wrapper to make it compatible fairly easily. I'm not sure how difficult it is to represent the Qobj class as a pytree. If it is too much work, maybe we can define a qutip_jax.jit that automatically converts all the Qobj inputs to JAX arrays and converts the output (if it is a JAX array or a pytree of them) back to Qobj?
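A rough sketch of that wrapper idea, assuming the Qobj-like class is not registered as a pytree, so the conversion has to happen outside the compiled function. MiniQobj and qjit are hypothetical names for illustration only; qutip_jax.jit does not exist in the thread, and QuTiP's real conversion functions are not used here.

```python
import jax
import jax.numpy as jnp


class MiniQobj:
    """Hypothetical stand-in for a Qobj (2-D only), NOT registered as a pytree."""

    def __init__(self, data):
        self.data = jnp.asarray(data)
        rows, cols = self.data.shape
        self.dims = ((rows,), (cols,))


def qjit(fun):
    """Hypothetical qutip_jax.jit-style wrapper (not an existing API)."""
    jitted = jax.jit(fun)

    def wrapper(*args):
        # Unregistered objects are pytree leaves, so tree_map reaches every
        # MiniQobj in (possibly nested) arguments and swaps in its array.
        arrays = jax.tree_util.tree_map(
            lambda x: x.data if isinstance(x, MiniQobj) else x, args)
        out = jitted(*arrays)
        # Wrap every 2-D array in the output pytree back into a MiniQobj.
        return jax.tree_util.tree_map(
            lambda x: MiniQobj(x) if jnp.ndim(x) == 2 else x, out)

    return wrapper


@qjit
def fun(a, M, N):
    # Inside the compiled function only plain arrays are seen, so the
    # Qobj product M.conj() * N becomes an explicit matrix product.
    return a * (M.conj() @ N) + N


sy = MiniQobj(jnp.array([[0, -1j], [1j, 0]]))
sz = MiniQobj(jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64))
res = fun(3.0, sy, sz)
```

One drawback of converting outside jit is that input metadata such as dims is not carried through to the output; here it is simply rebuilt from the array shape.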
Given how JAX is implemented, you cannot jit arbitrary functions, only pure ones. It is very likely that we can never make jax.jit work with qutip.mesolve. We can only jit a customized integrator and return the result as a Qobj.
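A minimal illustration of why purity matters: Python side effects inside a jitted function run only while JAX traces it, not on every call, so code that relies on mutating external state would silently stop behaving as written. The names here are illustrative.

```python
import jax
import jax.numpy as jnp

calls = []  # external state mutated by the function: a side effect


@jax.jit
def impure(x):
    calls.append("traced")  # plain Python: runs only while JAX traces `impure`
    return 2.0 * x


impure(jnp.ones(3))  # first call: traced and compiled, side effect runs
impure(jnp.ones(3))  # same shape and dtype: cached program reused, no append
impure(jnp.ones(4))  # new shape: retraced, so the side effect runs again
```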
> In the example, it fails if we don't manually convert to jax. Do we want this to be automatic?
Yes, that would be great. It should be feasible with some global setting, e.g. a default dtype.
> Do we want only operations to be supported, or should it work with ptrace or eigenenergies also?
Yes, it should also work with other Qobj operations like tensor and ptrace. But that should be just the same as adding specialisations for addition and multiplication, no? In principle, most of them should work by replacing np with jnp. The only caveat is that JAX arrays do not support NumPy-style in-place assignment; one has to use JAX's own syntax instead, e.g. x.at[idx].set(value).
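The in-place caveat in concrete terms: a NumPy-style specialisation that writes into a preallocated array has to be rewritten with JAX's functional .at[...] update syntax. This sketch builds the projector |1><1| on a 3-level space both ways.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: write directly into a preallocated array.
p_np = np.zeros((3, 3))
p_np[1, 1] = 1.0

# JAX: arrays are immutable; .at[...].set(...) returns an updated copy
# (under jit, XLA can often perform the update in place).
p_jx = jnp.zeros((3, 3)).at[1, 1].set(1.0)

# The NumPy idiom raises a TypeError on a JAX array.
x = jnp.zeros(3)
try:
    x[0] = 1.0
    mutated = True
except TypeError:
    mutated = False
```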
tensor will just be a new specialisation, but ptrace will not. You cannot branch on input values under jit by default, so ptrace's sel argument will cause issues. eigenstates returns a pair of eigenvalues and a list of Qobj, so I don't see how that could work...
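One possible way around the sel problem, sketched in pure jax.numpy for the two-subsystem case only: mark the selection (and the dimensions) as static arguments, so Python can branch on them at trace time, at the cost of one compilation per distinct value. ptrace2 and eig_pair are hypothetical helpers, not QuTiP functions. The second one shows that the array-level part of eigenstates jits fine because a tuple containing a list of arrays is a valid pytree; the open problem is wrapping the eigenvectors as Qobj.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1, 2))
def ptrace2(rho, dims, keep):
    """Hypothetical partial trace of a bipartite density matrix.

    dims = (d0, d1) and keep (0 or 1) are static: they must be hashable,
    and each new value triggers a recompilation.
    """
    d0, d1 = dims
    r = rho.reshape(d0, d1, d0, d1)  # r[i, j, k, l] = rho[i*d1 + j, k*d1 + l]
    if keep == 0:  # Python branching is allowed on a static argument
        return jnp.einsum("ijkj->ik", r)  # trace out subsystem 1
    return jnp.einsum("ijil->jl", r)      # trace out subsystem 0


@jax.jit
def eig_pair(H):
    # (array, list of arrays) is a valid pytree output; a list of Qobj
    # would only work if Qobj itself were registered as a pytree.
    evals, evecs = jnp.linalg.eigh(H)
    return evals, [evecs[:, n] for n in range(H.shape[0])]


A = jnp.array([[0.25, 0.0], [0.0, 0.75]])
B = jnp.array([[0.5, 0.5], [0.5, 0.5]])
rho = jnp.kron(A, B)
rho_A = ptrace2(rho, (2, 2), 0)  # recovers A, since tr(B) = 1
rho_B = ptrace2(rho, (2, 2), 1)  # recovers B, since tr(A) = 1
evals, states = eig_pair(jnp.diag(jnp.array([1.0, -1.0])))
```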
Solvers and integrators cannot be jitted, nor does it make any sense to try. But we need to think about getting grad working with solvers.
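A sketch of what differentiating through time evolution could look like if the integrator is written entirely in jax.numpy: a fixed-step RK4 for the Schrödinger equation of one qubit with H = a·σx, differentiated with respect to a using jax.grad. The step count, Hamiltonian, and helper names are illustrative assumptions; this is not how QuTiP's solvers are structured. For ψ(0) = |0> the exact population of |1> is sin^2(a t), with derivative t·sin(2 a t), which gives a check on both results.

```python
import jax
import jax.numpy as jnp


def rk4_step(psi, H, dt):
    """One classical Runge-Kutta step for d(psi)/dt = -i H psi."""
    def f(y):
        return -1j * (H @ y)

    k1 = f(psi)
    k2 = f(psi + 0.5 * dt * k1)
    k3 = f(psi + 0.5 * dt * k2)
    k4 = f(psi + dt * k3)
    return psi + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def excited_population(a, t=0.5, n_steps=100):
    """Population of |1> at time t for H = a * sigma_x, starting from |0>."""
    sx = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
    H = a * sx
    psi0 = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    dt = t / n_steps
    # fori_loop with Python-int bounds lowers to lax.scan, so reverse-mode
    # differentiation through the time stepping is supported.
    psi = jax.lax.fori_loop(0, n_steps, lambda i, y: rk4_step(y, H, dt), psi0)
    return jnp.abs(psi[1]) ** 2


pop = excited_population(1.0)             # exact: sin(0.5)**2 ~ 0.229849
dpop = jax.grad(excited_population)(1.0)  # exact: 0.5 * sin(1.0) ~ 0.420735
```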