qutip/qutip-jax

kwargs requirement for jit (coefficient) functions


Minor thing, but if it's not too much work, I think it would be cool if one could define multiple jitted JAX functions (with different parameters) without the need for `**kwargs`, as is already possible with regular (non-jit) Python functions.
The following example does not work if we drop `**kwargs`:

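```python
import jax
import jax.numpy as jnp
import qutip as qt
import qutip_jax  # noqa: F401  (needed for options={'method': 'diffrax'})

def sin(t, p):
    return p[0] * jnp.sin(p[1] * t + p[2])

# Each coefficient uses only its own parameter, but **kwargs is currently
# required; removing it makes this example fail.
@jax.jit
def sin_x(t, p, **kwargs): return sin(t, p)

@jax.jit
def sin_y(t, q, **kwargs): return sin(t, q)

H = [[qt.sigmax(), sin_x],
     [qt.sigmay(), sin_y]]

evo = qt.mesolve(H, qt.basis(2, 0), tlist=[0, 1],
                 args={'p': [1, 1, 0], 'q': [1, 1, 0]},
                 options={'method': 'diffrax'})
```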
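To make the request concrete, here is a minimal sketch of the behaviour being asked for: inspect the coefficient's declared parameter names and pass only the matching entries of `args`, falling back to passing everything when the function accepts `**kwargs`. `call_with_matching_args` is a hypothetical helper written for illustration, not a QuTiP or qutip-jax API, and the sketch uses plain `math.sin` so it runs without JAX. Applying it to jitted functions assumes their original parameter names can be recovered, which is not verified here.

```python
import inspect
import math


def call_with_matching_args(func, t, args):
    """Call func(t, **subset), keeping only the entries of `args`
    whose keys match func's declared parameters (after `t`).

    Hypothetical helper for illustration; not a QuTiP or qutip-jax API.
    """
    params = inspect.signature(func).parameters
    # If the function already accepts **kwargs, pass everything through.
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return func(t, **args)
    # Otherwise keep only the names declared after the first (time) argument.
    names = set(list(params)[1:])
    subset = {k: v for k, v in args.items() if k in names}
    return func(t, **subset)


def sin(t, p):
    return p[0] * math.sin(p[1] * t + p[2])


# No **kwargs needed: each function declares only the argument it uses.
def sin_x(t, p):
    return sin(t, p)


def sin_y(t, q):
    return sin(t, q)


args = {"p": [1, 1, 0], "q": [2, 1, 0]}
print(call_with_matching_args(sin_x, 0.5, args))
print(call_with_matching_args(sin_y, 0.5, args))
```

Calling `sin_x(0.5, **args)` directly would raise a `TypeError` for the unexpected keyword `q`; filtering by signature is what lets each function ignore the other's parameter.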