qutip/qutip-jax

Array slice indices are not compatible with jit


Issue

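import qutip


def isherm_jaxdia(matrix, tol=None):
    # A non-square matrix cannot be Hermitian.
    if matrix.shape[0] != matrix.shape[1]:
        return False
    tol = tol or qutip.settings.core["atol"]
    done = []  # offsets already checked as the mirror of an earlier diagonal
    for offset, data in zip(matrix.offsets, matrix.data):
        if offset in done:
            continue
        # Column range [start, end) holding the valid entries of this diagonal.
        start = max(0, offset)
        end = min(matrix.shape[1], matrix.shape[0] + offset)
        if -offset not in matrix.offsets:
            # No mirrored diagonal is stored, so this one must be (numerically) zero.
            # _is_zero and _is_conj are helpers from the same module (not shown).
            if not _is_zero(data[start:end], tol):
                return False
        else:
            # Compare against the conjugate of the mirrored diagonal -offset.
            idx = matrix.offsets.index(-offset)
            done.append(-offset)
            st = max(0, -offset)
            et = min(matrix.shape[1], matrix.shape[0] - offset)
            # Python slices such as data[start:end] are what fails under jit.
            if not _is_conj(data[start:end], matrix.data[idx, st:et], tol):
                return False
    return True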

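isherm_jaxdia is not compatible with jit because it slices the diagonal data with Python slice syntax (data[start:end] and matrix.data[idx, st:et]), and under jit JAX only accepts static start/stop/step values in such slices.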
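A minimal reproduction of the error outside qutip-jax, assuming only jax is installed. Here take is a hypothetical stand-in for the data[start:end] slices in isherm_jaxdia. Passing the bounds as ordinary jit arguments makes them traced, and the slice is rejected; marking them static with static_argnums compiles fine, at the cost of one recompilation per distinct pair of bounds.

```python
import jax
import jax.numpy as jnp


def take(x, start, end):
    # Plain NumPy-style slicing with bounds that are function arguments.
    return x[start:end]


x = jnp.arange(5.0)

# Under jit, start and end are traced values, so the slice bounds are not static.
try:
    jax.jit(take)(x, 1, 3)
    traced_failed = False
except Exception as err:  # JAX rejects the traced slice bounds (an IndexError in current releases)
    traced_failed = True
    print(err)

# Marking the bounds as static arguments turns them back into plain Python ints.
static_result = jax.jit(take, static_argnums=(1, 2))(x, 1, 3)
print(static_result)  # [1. 2.]
```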

Solution
Use jax.lax.dynamic_slice to perform the slicing inside the JIT-compiled function. dynamic_slice accepts traced start indices, so it works under jit, but its slice sizes must still be static Python ints.
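A minimal sketch of how this could look, assuming only jax; take_window and window_is_zero are hypothetical helpers, not part of qutip-jax. lax.dynamic_slice handles a traced start index but needs a static size, and it clamps the start so the window stays in bounds instead of raising. In isherm_jaxdia the window length (end - start) changes with the offset, so where that length cannot be made static, a masked comparison over the full diagonal row is one alternative technique (distinct from dynamic_slice) that keeps every shape fixed.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnames=("size",))
def take_window(row, start, size):
    # start may be traced; size must be a static Python int.
    # dynamic_slice clamps start so the window stays inside row.
    return lax.dynamic_slice(row, (start,), (size,))


@jax.jit
def window_is_zero(row, start, end, tol):
    # Hypothetical jit-friendly stand-in for _is_zero(row[start:end], tol):
    # keep the full (static) shape and ignore entries outside [start, end).
    idx = jnp.arange(row.shape[0])
    in_window = (idx >= start) & (idx < end)
    # Return a traced boolean rather than branching on it in Python.
    return jnp.all(jnp.where(in_window, jnp.abs(row) <= tol, True))


row = jnp.arange(5.0)
print(take_window(row, 1, 3))  # [1. 2. 3.]
print(take_window(row, 4, 2))  # start clamped to 3 -> [3. 4.]

diag = jnp.array([1.0, 0.0, 0.0, 1.0])
print(bool(window_is_zero(diag, 1, 3, 1e-12)))  # True
print(bool(window_is_zero(diag, 0, 3, 1e-12)))  # False
```

Because of the clamping, dynamic_slice silently returns a shifted window for an out-of-range start, so the caller has to guarantee that start + size never exceeds the row length.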