Array slice indices are not compatible with jit
Opened this issue · 0 comments
rochisha0 commented
Issue
def isherm_jaxdia(matrix, tol=None):
    """
    Return whether a diagonal-storage matrix is Hermitian within ``tol``.

    The matrix is inspected one stored diagonal at a time.  For the diagonal
    with offset ``k``:

    * if no diagonal with offset ``-k`` is stored, every valid entry of
      diagonal ``k`` must be (numerically) zero, as checked by ``_is_zero``;
    * otherwise diagonal ``k`` is compared with diagonal ``-k`` by
      ``_is_conj`` (presumably an element-wise ``a == conj(b)`` test within
      ``tol`` -- confirm against its definition), and ``-k`` is recorded so
      the pair is compared only once.

    The slice bounds assume each row of ``matrix.data`` is indexed by column.
    Under that assumption, the valid entries of the diagonal with offset ``k``
    occupy positions ``max(0, k)`` up to ``min(n_cols, n_rows + k)``, and any
    padding outside that window is ignored.

    Parameters
    ----------
    matrix
        Diagonal-format matrix exposing ``shape``, ``offsets`` (a sequence
        supporting ``in`` and ``.index``) and a 2-D ``data`` array with one
        row per entry of ``offsets``.
    tol : float, optional
        Absolute tolerance forwarded to ``_is_zero`` / ``_is_conj``.  When
        ``None`` (the default), ``qutip.settings.core["atol"]`` is used.  An
        explicit ``0`` is honoured instead of being replaced by the default.

    Returns
    -------
    bool
        ``False`` for a non-square matrix or as soon as one diagonal (pair)
        fails its check; ``True`` otherwise.

    Notes
    -----
    The early ``return False`` branches make Python-level decisions on the
    helpers' results.  If those results are traced JAX values, this function
    cannot run under ``jax.jit`` as written.
    """
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    if n_rows != n_cols:
        return False
    # ``tol or default`` would silently discard an explicit ``tol=0``.
    if tol is None:
        tol = qutip.settings.core["atol"]
    # Offsets already compared as the mirror partner of an earlier diagonal.
    done = []
    for offset, data in zip(matrix.offsets, matrix.data):
        if offset in done:
            continue
        start = max(0, offset)
        end = min(n_cols, n_rows + offset)
        if -offset not in matrix.offsets:
            # No mirror diagonal is stored, so the transposed entries are
            # implicitly zero; this diagonal must therefore be zero as well.
            if not _is_zero(data[start:end], tol):
                return False
        else:
            idx = matrix.offsets.index(-offset)
            done.append(-offset)
            # Column window of the mirror diagonal.  For a square matrix with
            # column-indexed storage, position p of both slices refers to the
            # transposed pair A[j - k, j] and A[j, j - k] with j = start + p.
            other_start = max(0, -offset)
            other_end = min(n_cols, n_rows - offset)
            mirror = matrix.data[idx, other_start:other_end]
            if not _is_conj(data[start:end], mirror, tol):
                return False
    return True
`isherm_jaxdia` is not compatible with `jax.jit` because it slices arrays with indices (`start`, `end`, `st`, `et`) computed by Python-level `max`/`min` calls.
Solution
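Use `jax.lax.dynamic_slice` to perform the slicing inside the JIT-compiled function. `dynamic_slice` accepts traced start indices and is designed to work with JAX's JIT compilation. Note that `dynamic_slice` still requires a static slice size. Also, the `if ...: return False` branches that test array values would have to be rewritten as traced operations (for example with `jnp.logical_and` or `jax.lax.cond`) before the whole function can be jitted.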