qutip/qutip-jax

JIT compilation using Qobj

Opened this issue · 7 comments

The benefit of using JAX is the ability to JIT compile. With the current setup, it is not clear what the best way is to make JAX recognize QuTiP objects as valid inputs, since jax.jit only accepts JAX-compatible types (arrays, scalars, and pytrees of them) as arguments and return values. There are workarounds, e.g., https://github.com/google/jax/blob/cc13fd1e5892a08f5360db933d4dfd64c0fc66eb/jax/experimental/lapax.py#L164. The alternative is to use data._jxa instead of passing quantum objects around, as in:

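```python
import jax
import qutip_jax  # noqa: F401  (registers the "jax" data layer used by .to("jax"))
from qutip import sigmax, sigmay


def test_jit_from_jxa():
    """Test JIT of add using the _jxa array"""

    @jax.jit
    def func():
        # Only the underlying JAX arrays leave the jitted function.
        return sigmax().to("jax").data._jxa + sigmay().to("jax").data._jxa

    # jax.Array replaces the older jax.interpreters.xla.DeviceArray.
    assert isinstance(func(), jax.Array)


def test_jit_from_qobj():
    """Test JIT of add directly using Qobj"""

    @jax.jit
    def func():
        # Currently fails: the returned Qobj is not a valid JAX type.
        return sigmax() + sigmay()

    assert isinstance(func(), jax.Array)
```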

Qobj is designed to work with a mix of data types, and it is not clear that JAX can trace the mathematical operations through a dispatched function. The dispatcher itself is compiled with Cython.

We could try registering Qobj as a pytree node when qutip-jax is imported.
Giving Qobj a way to extract the specialisations and skip the dispatcher will probably also be required.

We should be able to get it working with .data.
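One way to picture the pytree idea is the sketch below. It uses a hypothetical stand-in class, MiniQobj (not part of QuTiP or qutip-jax), holding a JAX array plus hashable dims; for a real Qobj the traced child would presumably be data._jxa and the static data the dims. Registering the class with jax.tree_util.register_pytree_node_class lets jax.jit accept and return it. It does not address the dispatcher question: here __add__ simply calls jnp directly.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class MiniQobj:
    """Hypothetical stand-in for qutip.Qobj: a JAX array plus static dims."""

    def __init__(self, array, dims):
        self.array = array
        self.dims = dims  # tuple of tuples, so it is hashable static data

    def tree_flatten(self):
        # Children (the array) are traced by JAX; aux data (dims) is static.
        return (self.array,), self.dims

    @classmethod
    def tree_unflatten(cls, dims, children):
        return cls(children[0], dims)

    def __add__(self, other):
        # Assumption: operate on the raw arrays, bypassing any dispatcher.
        return MiniQobj(self.array + other.array, self.dims)


sx = MiniQobj(jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64), ((2,), (2,)))
sy = MiniQobj(jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex64), ((2,), (2,)))


@jax.jit
def add(a, b):
    # Inputs and output are MiniQobj, which jit now treats as pytrees.
    return a + b


result = add(sx, sy)
```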

Thank you @Ericgig and @quantshah for making this repo and pushing this forward!! This is amazing!

I just want to add a bit of my understanding of JAX JIT. It seems to me that JIT (and jax.grad etc.) works as long as the function to be jitted only takes inputs and returns outputs that JAX supports, e.g. jnp arrays, and the input-dependent part only uses JAX-supported operations. Any other complexity in building this function doesn't matter, as long as JAX can trace through the Python code and translate it to XLA at run time. So in principle, the current implementation should be sufficient as long as we are a little careful with the function we jit.
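A minimal sketch of that rule, using only JAX and NumPy (the function names are illustrative): the first function touches its input only through jnp operations, so it can be jitted and differentiated; the second pulls the traced value into NumPy and fails at trace time.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def traceable(x):
    # Every operation that depends on x is a jnp operation, so JAX can trace it.
    m = jnp.array([[0.0, 1.0], [1.0, 0.0]])
    return m @ (x * jnp.ones(2))


@jax.jit
def not_traceable(x):
    # np.asarray needs a concrete value, which a tracer cannot provide.
    return jnp.asarray(np.asarray(x) * 2.0)


values = traceable(3.0)
slope = jax.grad(lambda x: traceable(x).sum())(3.0)
```

Calling not_traceable(3.0) raises jax.errors.TracerArrayConversionError, which is the kind of failure to expect whenever an input-dependent step leaves JAX.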

The following JIT example works fine for me with the current master branch.

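```python
import qutip
import qutip_jax  # registers the "jax" data layer used by .to("jax")
import jax


@jax.jit
def fun(a):
    M = qutip.sigmay().to("jax")
    N = qutip.sigmaz().to("jax")
    # Only `a` is traced; return the raw JAX array rather than the Qobj.
    return (a * M.conj() * N + N).data._jxa


fun(3.)
```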

This looks sufficient to me. The input is just numbers and the output is a JAX array.

Maybe instead of making the whole Qobj compatible with JAX JIT, we just need to add a wrapper that converts the final JAX array to a Qobj?
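A minimal sketch of such an output wrapper, assuming a hypothetical decorator name jit_and_wrap and a dataclass WrappedOperator standing in for Qobj so the example runs without QuTiP. With QuTiP, wrap could instead be something like lambda arr: qutip.Qobj(numpy.asarray(arr)). The jitted function works on raw arrays; the conversion happens after the compiled call returns, outside the trace.

```python
import functools
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class WrappedOperator:
    """Hypothetical stand-in for qutip.Qobj so the sketch runs without QuTiP."""
    array: jax.Array
    dims: tuple


def jit_and_wrap(wrap):
    """Hypothetical decorator: jit a function that returns a JAX array,
    then apply `wrap` to the concrete result outside the trace."""
    def decorator(fn):
        jitted = jax.jit(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return wrap(jitted(*args, **kwargs))
        return wrapper
    return decorator


SIGMAY = jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex64)
SIGMAZ = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)


@jit_and_wrap(lambda arr: WrappedOperator(arr, ((2,), (2,))))
def fun(a):
    # Same arithmetic as the Qobj example, on raw arrays: a * conj(Y) Z + Z
    return a * SIGMAY.conj() @ SIGMAZ + SIGMAZ


op = fun(3.0)
```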

We need to define what is sufficient.

Is supporting Qobj only as output enough?
In the example, it fails if we don't manually convert to jax. Do we want this to be automatic?
Do we want only basic operations to be supported, or should it also work with ptrace or eigenenergies?

> Is supporting Qobj only as output enough?

No. This is just an example showing that even if you cannot directly jit a function that returns a Qobj, you can still define a wrapper to make it compatible fairly easily. I'm not sure how difficult it would be to represent the Qobj class as a pytree. If it is too much work, maybe we can define a qutip_jax.jit that automatically converts all Qobj inputs to JAX arrays and converts the output (if it is a JAX array or a pytree of them) back to Qobj?
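The proposed qutip_jax.jit could look roughly like the sketch below. Everything here is hypothetical: auto_jit and Operator (a plain, non-pytree stand-in for Qobj) are illustration names, and the output dims are passed in by the caller rather than inferred, which a real implementation would need to handle. jax.tree_util.tree_map unwraps Operator inputs to arrays and rewraps every array in the output pytree.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map


class Operator:
    """Hypothetical stand-in for qutip.Qobj (array + dims); not a pytree."""

    def __init__(self, array, dims):
        self.array = array
        self.dims = dims


def auto_jit(fn, dims):
    """Hypothetical sketch of the proposed qutip_jax.jit: unwrap Operator
    inputs to arrays, jit `fn` on plain arrays, rewrap array outputs."""
    jitted = jax.jit(fn)

    def unwrap(x):
        return x.array if isinstance(x, Operator) else x

    def wrapper(*args):
        # Unregistered classes are pytree leaves, so tree_map visits each Operator.
        raw = tree_map(unwrap, args)
        out = jitted(*raw)
        # Rewrap every array leaf of the output pytree (dims supplied by the caller).
        return tree_map(lambda arr: Operator(arr, dims), out)

    return wrapper


def products(a, b):
    # Commutator and anticommutator, returned as a pytree (tuple) of arrays.
    return a @ b - b @ a, a @ b + b @ a


sx = Operator(jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64), ((2,), (2,)))
sz = Operator(jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64), ((2,), (2,)))
comm, anticomm = auto_jit(products, dims=((2,), (2,)))(sx, sz)
```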

Given how JAX is implemented, you cannot jit arbitrary functions, only pure ones. It is very likely that we can never make jax.jit work with qutip.mesolve. We can only jit a custom integrator and return the result as a Qobj.
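As an illustration of a custom integrator that can be jitted, here is a fixed-step RK4 solver for the Schrödinger equation d|psi>/dt = -iH|psi>. It is a hypothetical sketch, not QuTiP's sesolve or mesolve: it is a pure function of arrays, and the time loop uses jax.lax.fori_loop so it compiles to a single XLA loop. The returned array could then be wrapped as a Qobj outside jit.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="n_steps")
def sesolve_rk4(H, psi0, t_final, n_steps):
    """Hypothetical fixed-step RK4 integrator for d psi/dt = -i H psi."""
    dt = t_final / n_steps

    def rhs(psi):
        return -1j * (H @ psi)

    def step(_, psi):
        k1 = rhs(psi)
        k2 = rhs(psi + 0.5 * dt * k1)
        k3 = rhs(psi + 0.5 * dt * k2)
        k4 = rhs(psi + dt * k3)
        return psi + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

    return jax.lax.fori_loop(0, n_steps, step, psi0)


H = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)  # sigma_z
psi0 = jnp.array([1, 1], dtype=jnp.complex64) / jnp.sqrt(2.0)
psi_t = sesolve_rk4(H, psi0, 1.0, n_steps=100)
```

For H = sigma_z the exact state at t = 1 is (e^{-i}, e^{i}) / sqrt(2), which makes the sketch easy to check.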

> In the example, it fails if we don't manually convert to jax. Do we want this to be automatic?

Yes, that would be great. It should be feasible with a global setting, e.g. a default dtype.

> Do we want only basic operations to be supported, or should it also work with ptrace or eigenenergies?

Yes, it should also work with other Qobj operations like tensor and ptrace. But that should be the same as adding the specialisations for addition and multiplication, no? In principle, most of them should work by replacing np with jnp. The only caveat is that JAX arrays do not support NumPy-style in-place assignment; one has to use JAX's own update syntax instead.
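For reference, the in-place assignment caveat looks like this in practice: JAX arrays are immutable, and the .at[...] property provides functional updates that return a new array while leaving the original untouched.

```python
import jax.numpy as jnp
import numpy as np

# NumPy: item assignment mutates the array in place.
a_np = np.zeros(3)
a_np[1] = 5.0

# JAX arrays are immutable; `a_jax[1] = 5.0` raises a TypeError.
# The functional equivalent returns a new, updated array.
a_jax = jnp.zeros(3)
b_jax = a_jax.at[1].set(5.0)
c_jax = b_jax.at[0].add(2.0)  # other update ops include multiply, min, max
```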

tensor will just be a matter of adding a new specialisation, but ptrace will not. By default you cannot branch on input values under jit, so ptrace's sel argument will cause issues. eigenstates returns a pair of eigenvalues and a list of Qobj, so I don't see how that could work...
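A sketch of how a static-argument route might look for ptrace, assuming a dense density matrix and a tuple of subsystem dims; tensor and ptrace here are hypothetical helpers, not qutip-jax functions. Marking dims and sel as static with static_argnames makes them compile-time constants, so there is no branching on traced values, at the cost of a recompilation for each new (dims, sel) pair. tensor, by contrast, is just jnp.kron. The sketch assumes sel is given in ascending order and does not permute the kept subsystems.

```python
import math
from functools import partial

import jax
import jax.numpy as jnp


def tensor(*ops):
    """Hypothetical tensor specialisation: a plain Kronecker product."""
    out = ops[0]
    for op in ops[1:]:
        out = jnp.kron(out, op)
    return out


@partial(jax.jit, static_argnames=("dims", "sel"))
def ptrace(rho, dims, sel):
    """Hypothetical ptrace on a dense density matrix; dims and sel are static."""
    n = len(dims)
    t = rho.reshape(dims + dims)  # axes: row indices, then column indices
    # Trace out unselected subsystems, highest index first, so the axis
    # numbers of the remaining lower subsystems stay valid.
    for k in sorted(set(range(n)) - set(sel), reverse=True):
        half = t.ndim // 2
        t = jnp.trace(t, axis1=k, axis2=k + half)
    d = math.prod(dims[i] for i in sel)
    return t.reshape(d, d)


rho_a = jnp.array([[0.7, 0.2], [0.2, 0.3]])
rho_b = jnp.array([[0.5, 0.1], [0.1, 0.5]])
rho = tensor(rho_a, rho_b)
keep_a = ptrace(rho, dims=(2, 2), sel=(0,))
keep_b = ptrace(rho, dims=(2, 2), sel=(1,))
```

Since both factors have unit trace, keeping either subsystem should return that factor unchanged.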

Solvers and integrators cannot be jitted, nor does it make any sense to try. But we need to think about getting grad working with solvers.
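On the grad side, a minimal sketch of differentiating through time evolution, assuming a closed two-level system with H = theta * sigma_x propagated with jax.scipy.linalg.expm rather than by a QuTiP solver (the name final_sz is illustrative). Starting from |0>, the expectation <sigma_z>(t) = cos(2 theta t), so at t = 1 the gradient with respect to theta should be -2 sin(2 theta).

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

SX = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
SZ = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)
PSI0 = jnp.array([1, 0], dtype=jnp.complex64)


def final_sz(theta, t=1.0):
    """<sigma_z> after evolving |0> under H = theta * sigma_x for time t."""
    psi = expm(-1j * theta * t * SX) @ PSI0
    # vdot conjugates its first argument, giving <psi|SZ|psi>; keep the real part.
    return jnp.real(jnp.vdot(psi, SZ @ psi))


dsz_dtheta = jax.jit(jax.grad(final_sz))
g = dsz_dtheta(0.3)
```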