kwargs requirement for jit (coefficient) functions
flowerthrower commented
Minor thing, but if it's not too much work, I think it would be cool if one could define multiple jitted JAX coefficient functions (with different parameter names) without needing `**kwargs`, as is already possible with regular (non-jit) Python functions.

The following example does not work if we drop `**kwargs`:
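```python
import jax
import jax.numpy as jnp
import qutip as qt
import qutip_jax  # makes the 'diffrax' ODE method available to QuTiP

def sin(t, p):
    # p = [amplitude, angular frequency, phase]
    return p[0] * jnp.sin(p[1] * t + p[2])

@jax.jit
def sin_x(t, p, **kwargs):  # breaks if **kwargs is removed
    return sin(t, p)

@jax.jit
def sin_y(t, q, **kwargs):  # breaks if **kwargs is removed
    return sin(t, q)

H = [[qt.sigmax(), sin_x],
     [qt.sigmay(), sin_y]]
evo = qt.mesolve(H, qt.basis(2, 0), tlist=[0, 1],
                 args={'p': [1, 1, 0], 'q': [1, 1, 0]},
                 options={'method': 'diffrax'})
```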
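To illustrate the requested behaviour, here is a minimal sketch of dispatching only the matching entries of `args` to each coefficient. It assumes (not confirmed here) that the failure comes from every entry of `args` being forwarded as a keyword argument, so `sin_x` would receive an unexpected `q`. `call_with_matching_args` is a hypothetical helper, not QuTiP's implementation; it uses `math.sin` instead of JAX so it runs standalone, and it assumes the coefficient's signature is recoverable with `inspect.signature` (a jitted function would have to expose its wrapped signature for this to apply).

```python
import inspect
import math


def call_with_matching_args(f, t, args):
    """Hypothetical helper: call f(t, ...) with only the args it accepts."""
    params = list(inspect.signature(f).parameters.values())
    rest = params[1:]  # assume the first parameter is the time argument
    # A function declaring **kwargs can absorb every entry, so pass them all.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in rest):
        return f(t, **args)
    # Otherwise keep only the keys that match one of f's parameter names.
    names = {p.name for p in rest}
    return f(t, **{k: v for k, v in args.items() if k in names})


def sin(t, p):
    # p = [amplitude, angular frequency, phase]
    return p[0] * math.sin(p[1] * t + p[2])


def sin_x(t, p):  # no **kwargs needed
    return sin(t, p)


def sin_y(t, q):  # no **kwargs needed
    return sin(t, q)


args = {"p": [1, 1, 0], "q": [2, 1, 0]}
print(call_with_matching_args(sin_x, 0.5, args))
print(call_with_matching_args(sin_y, 0.5, args))
```

Calling `sin_x(0.5, **args)` directly raises a `TypeError` for the unexpected keyword `q`, which is the situation `**kwargs` currently works around; filtering by parameter name sidesteps it while still supporting functions that do declare `**kwargs`.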