PennyLaneAI/catalyst

Support for `qml.StatePrep` on `lightning.qubit` with Catalyst

joeybarreto opened this issue · 7 comments

In the code below, I create a test circuit and then compute its gradient in two ways. The first applies jax.jit(jax.grad(...)) to the QNode after the initial state has been supplied via partial. The second uses the Catalyst @qjit decorator and grad function instead. jgrad succeeds, but jgrad_qjit fails with the error DifferentiableCompileError: StatePrep is non-differentiable on 'lightning.qubit' device. I'm not sure whether this is a bug or expected behavior, but since qml.StatePrep works on the lightning backend without Catalyst, I don't see why it would fail here. How hard would it be to add support for arbitrary state preparation when using Catalyst?

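from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml
from catalyst import grad, qjit

Nq = 2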
init_state = np.array([1,0,0,0])

def test(angles, init_state):
    qml.StatePrep(init_state, wires=range(Nq))
    qml.RY(angles[0], wires=0)
    qml.RY(angles[1], wires=1)
    return qml.expval(qml.PauliZ(0))

qnode_test = qml.QNode(test, 
          qml.device('lightning.qubit', wires=Nq), 
          interface='jax',
          diff_method='best')
qnode_test = partial(qnode_test, init_state=init_state)

@qjit
def jgrad_qjit(angles):
    g = grad(qnode_test)
    return g(angles)

jgrad = jax.jit(jax.grad(qnode_test))

angles = jnp.array([0.1,0.2])

qnode_test(angles, init_state=init_state) #<-- succeeds
jgrad(angles) #<-- succeeds
jgrad_qjit(angles) #<-- fails

The full error message is

---------------------------------------------------------------------------
DifferentiableCompileError                Traceback (most recent call last)
Cell In[548], line 1
----> 1 jgrad_qjit(angles)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py:457, in QJIT.__call__(self, *args, **kwargs)
    453         kwargs = {"static_argnums": self.compile_options.static_argnums, **kwargs}
    455     return self.user_function(*args, **kwargs)
--> 457 requires_promotion = self.jit_compile(args, **kwargs)
    459 # If we receive tracers as input, dispatch to the JAX integration.
    460 if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]):

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py:528, in QJIT.jit_compile(self, args, **kwargs)
    524 # Capture with the patched conversion rules
    525 with Patcher(
    526     (ag_primitives, "module_allowlist", self.patched_module_allowlist),
    527 ):
--> 528     self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
    529         args, **kwargs
    530     )
    532 self.mlir_module, self.mlir = self.generate_ir()
    533 self.compiled_function, self.qir = self.compile()

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/debug/instruments.py:143, in instrument.<locals>.wrapper(*args, **kwargs)
    140 @functools.wraps(fn)
    141 def wrapper(*args, **kwargs):
    142     if not InstrumentSession.active:
--> 143         return fn(*args, **kwargs)
    145     with ResultReporter(stage_name, has_finegrained) as reporter:
    146         fn_results, wall_time, cpu_time = time_function(fn, args, kwargs)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py:610, in QJIT.capture(self, args, **kwargs)
    607         _inject_transform_named_sequence()
    608         return self.user_function(*args, **kwargs)
--> 610     jaxpr, out_type, treedef = trace_to_jaxpr(
    611         fn_with_transform_named_sequence, static_argnums, abstracted_axes, full_sig, kwargs
    612     )
    614 return jaxpr, out_type, treedef, dynamic_sig

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py:536, in trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs)
    531     with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
    532         make_jaxpr_kwargs = {
    533             "static_argnums": static_argnums,
    534             "abstracted_axes": abstracted_axes,
    535         }
--> 536         jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
    538 return jaxpr, out_type, out_treedef

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_extras/tracing.py:542, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs)
    540     f, out_tree_promise = flatten_fun(f, in_tree)
    541     f = annotate(f, in_type)
--> 542     jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
    543 closed_jaxpr = ClosedJaxpr(jaxpr, consts)
    544 return closed_jaxpr, out_type, out_tree_promise()

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/jax/_src/profiler.py:335, in annotate_function.<locals>.wrapper(*args, **kwargs)
    332 @wraps(func)
    333 def wrapper(*args, **kwargs):
    334   with TraceAnnotation(name, **decorator_kwargs):
--> 335     return func(*args, **kwargs)
    336   return wrapper

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py:2362, in trace_to_jaxpr_dynamic2(fun, debug_info)
   2360 with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
   2361   main.jaxpr_stack = ()  # type: ignore
-> 2362   jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   2363   del main, fun
   2364 return jaxpr, out_type, consts

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py:2377, in trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   2375 in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
   2376 in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 2377 ans = fun.call_wrapped(*in_tracers_)
   2378 out_tracers = map(trace.full_raise, ans)
   2379 jaxpr, out_type, consts = frame.to_jaxpr2(out_tracers)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/jax/_src/linear_util.py:192, in WrappedFun.call_wrapped(self, *args, **kwargs)
    189 gen = gen_static_args = out_store = None
    191 try:
--> 192   ans = self.f(*args, **dict(self.params, **kwargs))
    193 except:
    194   # Some transformations yield from inside context managers, so we have to
    195   # interrupt them before reraising the exception. Otherwise they will only
    196   # get garbage-collected at some later time, running their cleanup tasks
    197   # only after this exception is handled, which can corrupt the global
    198   # state.
    199   while stack:

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py:608, in QJIT.capture.<locals>.fn_with_transform_named_sequence(*args, **kwargs)
    596 """
    597 This function behaves exactly like the user function being jitted,
    598 taking in the same arguments and producing the same results, except
   (...)
    605 jaxpr. It is never executed or used anywhere, except being traced here.
    606 """
    607 _inject_transform_named_sequence()
--> 608 return self.user_function(*args, **kwargs)

Cell In[545], line 19
     16 @qjit
     17 def jgrad_qjit(angles):
     18     g = grad(qnode_test)
---> 19     return g(angles)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/api_extensions/differentiation.py:663, in Grad.__call__(self, *args, **kwargs)
    653 grad_params = _check_grad_params(
    654     self.grad_params.method,
    655     self.grad_params.scalar_out,
   (...)
    660     self.grad_params.with_value,
    661 )
    662 input_data_flat, _ = tree_flatten((args, kwargs))
--> 663 jaxpr, out_tree = _make_jaxpr_check_differentiable(fn, grad_params, *args, **kwargs)
    664 if self.grad_params.with_value:  # use value_and_grad
    665     args_argnum = tuple(args[i] for i in grad_params.argnums)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/api_extensions/differentiation.py:802, in _make_jaxpr_check_differentiable(f, grad_params, *args, **kwargs)
    800 method = grad_params.method
    801 with mark_gradient_tracing(method):
--> 802     jaxpr, shape = jax.make_jaxpr(f, return_shape=True)(*args, **kwargs)
    803 _, out_tree = tree_flatten(shape)
    805 for pos, arg in enumerate(jaxpr.in_avals):

    [... skipping hidden 6 frame]

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py:172, in Function.__call__(self, *args, **kwargs)
    170 @debug_logger
    171 def __call__(self, *args, **kwargs):
--> 172     jaxpr, _, out_tree = make_jaxpr2(self.fn)(*args, **kwargs)
    174     def _eval_jaxpr(*args, **kwargs):
    175         return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, **kwargs)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_extras/tracing.py:542, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs)
    540     f, out_tree_promise = flatten_fun(f, in_tree)
    541     f = annotate(f, in_type)
--> 542     jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
    543 closed_jaxpr = ClosedJaxpr(jaxpr, consts)
    544 return closed_jaxpr, out_type, out_tree_promise()

    [... skipping hidden 4 frame]

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py:588, in QJIT.capture.<locals>.closure(qnode, *args, **kwargs)
    586 params["static_argnums"] = kwargs.pop("static_argnums", static_argnums)
    587 params["_out_tree_expected"] = []
--> 588 return QFunc.__call__(qnode, *args, **dict(params, **kwargs))

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/qfunc.py:163, in QFunc.__call__(self, *args, **kwargs)
    161 dynamic_args = filter_static_args(args, static_argnums)
    162 args_flat = tree_flatten((dynamic_args, kwargs))[0]
--> 163 res_flat = func_p.bind(flattened_fun, *args_flat, fn=self)
    164 return tree_unflatten(out_tree_promise(), res_flat)[0]

    [... skipping hidden 4 frame]

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/qfunc.py:141, in QFunc.__call__.<locals>._eval_quantum(*args, **kwargs)
    140 def _eval_quantum(*args, **kwargs):
--> 141     closed_jaxpr, out_type, out_tree, out_tree_exp = trace_quantum_function(
    142         self.func,
    143         qjit_device,
    144         args,
    145         kwargs,
    146         self,
    147         static_argnums,
    148     )
    150     out_tree_expected.append(out_tree_exp)
    151     dynamic_args = filter_static_args(args, static_argnums)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py:1178, in trace_quantum_function(f, device, args, kwargs, qnode, static_argnums)
   1174         device_modify_measurements = False  # this is only for the new API transform program
   1176     qnode_program = qnode.transform_program if qnode else TransformProgram()
-> 1178     tapes, post_processing = apply_transform(
   1179         qnode_program,
   1180         device_program,
   1181         device_modify_measurements,
   1182         quantum_tape,
   1183         return_values_flat,
   1184     )
   1186 # (2) - Quantum tracing
   1187 transformed_results = []

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py:991, in apply_transform(qnode_program, device_program, device_modify_measurements, tape, flat_results)
    988     # Apply the identity transform in order to keep generalization
    989     total_program = device_program
--> 991 tapes, post_processing = total_program([tape])
    992 if not is_valid_for_batch and len(tapes) > 1:
    993     msg = "Multiple tapes are generated, but each run might produce different results."

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/transforms/core/transform_program.py:515, in TransformProgram.__call__(self, tapes)
    513 if self._argnums is not None and self._argnums[i] is not None:
    514     tape.trainable_params = self._argnums[i][j]
--> 515 new_tapes, fn = transform(tape, *targs, **tkwargs)
    516 execution_tapes.extend(new_tapes)
    518 fns.append(fn)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/device/verification.py:242, in verify_operations(tape, grad_method, qjit_device)
    238             _paramshift_op_checker(op)
    240     return (in_inverse, in_control)
--> 242 _verify_nested(tape, (False, False), _op_checker)
    244 return (tape,), lambda x: x[0]

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/device/verification.py:60, in _verify_nested(tape, state, op_checker_fn)
     58 ctx = EvaluationContext.get_main_tracing_context()
     59 for op in tape.operations:
---> 60     inner_state = op_checker_fn(op, state)
     61     if has_nested_tapes(op):
     62         for region in nested_quantum_regions(op):

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/device/verification.py:236, in verify_operations.<locals>._op_checker(op, state)
    234 _mcm_op_checker(op)
    235 if grad_method == "adjoint":
--> 236     _adj_diff_op_checker(op)
    237 elif grad_method == "parameter-shift":
    238     _paramshift_op_checker(op)

File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/device/verification.py:154, in verify_operations.<locals>._adj_diff_op_checker(op)
    150     op_name = op.name
    151 if not qjit_device.qjit_capabilities.native_ops.get(
    152     op_name, EMPTY_PROPERTIES
    153 ).differentiable:
--> 154     raise DifferentiableCompileError(
    155         f"{op.name} is non-differentiable on '{qjit_device.original_device.name}' device"
    156     )

DifferentiableCompileError: StatePrep is non-differentiable on 'lightning.qubit' device

qml.about():

Name: PennyLane
Version: 0.38.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/joey/miniconda3/envs/pennylane/lib/python3.12/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info:           macOS-14.4.1-arm64-arm-64bit
Python version:          3.12.4
Numpy version:           1.26.4
Scipy version:           1.12.0
Installed devices:
- nvidia.custatevec (PennyLane-Catalyst-0.8.1)
- nvidia.cutensornet (PennyLane-Catalyst-0.8.1)
- oqc.cloud (PennyLane-Catalyst-0.8.1)
- softwareq.qpp (PennyLane-Catalyst-0.8.1)
- lightning.qubit (PennyLane_Lightning-0.38.0)
- default.clifford (PennyLane-0.38.0)
- default.gaussian (PennyLane-0.38.0)
- default.mixed (PennyLane-0.38.0)
- default.qubit (PennyLane-0.38.0)
- default.qubit.autograd (PennyLane-0.38.0)
- default.qubit.jax (PennyLane-0.38.0)
- default.qubit.legacy (PennyLane-0.38.0)
- default.qubit.tf (PennyLane-0.38.0)
- default.qubit.torch (PennyLane-0.38.0)
- default.qutrit (PennyLane-0.38.0)
- default.qutrit.mixed (PennyLane-0.38.0)
- default.tensor (PennyLane-0.38.0)
- null.qubit (PennyLane-0.38.0)

Hi @joeybarreto, unfortunately this is a known bug in Catalyst at the moment; see #1065 for more details.

In short, without @qjit, PennyLane correctly identifies which operations need to be decomposed for gradient computation, but it does not do the same when @qjit is enabled.

Instead, with @qjit enabled, Catalyst verifies that every operator on the tape is differentiable, even operations that do not need to be differentiated (here, StatePrep, whose input state is not a trainable parameter).
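To make the distinction concrete, here is a toy model in plain Python of two verification policies. It is not Catalyst's code: the traceback above shows the adjoint checker looking up whether each operation on the tape is marked differentiable in the device capabilities, and strict_check mimics that. The hypothetical trainable_only_check instead skips operations that carry no trainable parameters. The capability table and tape are illustrative assumptions.

```python
# Toy model of an operation-verification pass; NOT Catalyst's implementation.
# Hypothetical capability table: can the device differentiate this gate?
DIFFERENTIABLE = {"RY": True, "CNOT": True, "StatePrep": False}

# Hypothetical tape for the issue's circuit: (gate name, has trainable params)
TAPE = [("StatePrep", False), ("RY", True), ("RY", True)]


def strict_check(ops):
    """Reject every non-differentiable gate, whether or not it is trainable."""
    for name, _trainable in ops:
        if not DIFFERENTIABLE.get(name, False):
            raise ValueError(f"{name} is non-differentiable")


def trainable_only_check(ops):
    """Reject a non-differentiable gate only if it has trainable parameters."""
    for name, trainable in ops:
        if trainable and not DIFFERENTIABLE.get(name, False):
            raise ValueError(f"{name} is non-differentiable")


trainable_only_check(TAPE)  # passes: StatePrep has nothing to differentiate
try:
    strict_check(TAPE)
except ValueError as err:
    print(err)  # StatePrep is non-differentiable
```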

We are working on a fix. For now, here is a workaround:

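import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml
from catalyst import grad, qjit

Nq = 2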
init_state = np.array([1, 0, 0, 0])

dev = qml.device('lightning.qubit', wires=Nq)

@qml.qnode(dev)
def qnode_test(angles, init_state):
    qml.BasisState.compute_decomposition(init_state, wires=range(Nq))
    qml.RY(angles[0], wires=0)
    qml.RY(angles[1], wires=1)
    return qml.expval(qml.PauliZ(0))

@qjit
def jgrad_qjit(angles):
    g = grad(qnode_test, argnums=0)
    return g(angles, init_state)

>>> angles = jnp.array([0.1,0.2])
>>> jax.grad(qnode_test)(angles, init_state=init_state)
Array([0.09983342, 0.        ], dtype=float64)
>>> jgrad_qjit(angles)
Array([ 9.98334166e-02, -5.55111512e-17], dtype=float64)

Note that I have swapped StatePrep for BasisState, which has a more efficient decomposition for basis states.
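For a state vector that is a single computational basis state, the bit pattern can be read off the index of its one nonzero amplitude, and preparing it from |0...0> only needs a bit flip on each wire whose bit is 1; an arbitrary amplitude vector has no such shortcut. Note that qml.BasisState takes one bit per wire (length Nq) rather than a length-2**Nq amplitude vector. The sketch below assumes PennyLane's convention that wire 0 is the most significant bit; both helper names are hypothetical.

```python
import math


def basis_bits_from_state_vector(state):
    """Bits (one per wire, wire 0 most significant) of a one-hot state vector."""
    n_wires = int(math.log2(len(state)))
    if 2 ** n_wires != len(state):
        raise ValueError("state vector length must be a power of two")
    nonzero = [i for i, amp in enumerate(state) if amp != 0]
    if len(nonzero) != 1 or abs(state[nonzero[0]]) != 1:
        raise ValueError("not a computational basis state")
    index = nonzero[0]
    # Extract bits from most significant (wire 0) to least significant.
    return [(index >> (n_wires - 1 - w)) & 1 for w in range(n_wires)]


def x_gate_wires(bits):
    """Wires that need a bit flip to map |0...0> to the given basis state."""
    return [w for w, b in enumerate(bits) if b == 1]


print(basis_bits_from_state_vector([1, 0, 0, 0]))  # [0, 0]
print(x_gate_wires(basis_bits_from_state_vector([0, 0, 1, 0])))  # [0]
```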

Thanks, yeah I had seen #1065 but wasn't 100% sure if this was the same issue. In reality my initial states are pretty complicated and not just something like [1 0 0 0] (which I just put as a dummy value in this example), and I expect that trying to decompose them numerically will be way too costly since I'm looking at circuits with 14-20+ qubits. I am working on finding explicit, analytic gate decompositions but until then, I am stuck with qml.StatePrep.

I do have a follow-up question though, regarding the execution time of the two compiled functions above (whether or not qml.StatePrep is used). I modified the circuit to be a bit more substantial, like below:

import jax
import numpy as np
import pennylane as qml
from catalyst import grad, qjit

Nq = 18
def test(angles):
    for kk in range(5):
        for ii in range(Nq):
            qml.RY(angles[ii], wires=ii)
        for ii in range(0,Nq,2):
            qml.CNOT(wires=[ii % Nq, (ii+1) % Nq])
    return qml.expval(qml.PauliZ(0))

qnode_test = qml.QNode(test, 
          qml.device('lightning.qubit', wires=Nq), 
          interface='jax',
          diff_method='best')

@qjit
def jgrad_qjit(angles):
    g = grad(qnode_test)
    return g(angles)

jgrad = jax.jit(jax.grad(qnode_test))

angles = np.array([0.1]*Nq)

After compilation, one call to jgrad takes 89 ms, while jgrad_qjit takes 2.5 s per call (at 18 qubits). I tested this over different numbers of qubits (with constant circuit depth), and from Nq=14 onward jgrad_qjit is consistently more than 10x slower than vanilla JAX. Is this expected? I can open a separate issue if the answer is not trivial. See below for the scaling comparison (times in seconds per call):

Nqs = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22]
# seconds per call, jgrad (jax.jit + jax.grad)
ts1 = [894e-6, 1.5e-3, 2.15e-3, 2.9e-3, 3.65e-3, 5.06e-3, 8.8e-3, 23.9e-3, 89.3e-3, 446e-3, 1.86]
# seconds per call, jgrad_qjit (@qjit + catalyst grad)
ts2 = [3.02e-3, 3.51e-3, 4.07e-3, 5.13e-3, 8.47e-3, 23.3e-3, 104e-3, 499e-3, 2.51, 13.2, 73]
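Dividing the two timing lists element-wise makes the "more than 10x" observation easy to check: with the numbers reported above, the ratio stays below 5x up to 12 qubits, then climbs from roughly 12x at 14 qubits to roughly 39x at 22 qubits. The block simply re-declares the reported data so it runs on its own.

```python
# Per-call timings (seconds) copied from the lists above.
Nqs = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22]
ts1 = [894e-6, 1.5e-3, 2.15e-3, 2.9e-3, 3.65e-3, 5.06e-3,
       8.8e-3, 23.9e-3, 89.3e-3, 446e-3, 1.86]       # jgrad
ts2 = [3.02e-3, 3.51e-3, 4.07e-3, 5.13e-3, 8.47e-3, 23.3e-3,
       104e-3, 499e-3, 2.51, 13.2, 73]               # jgrad_qjit

# Slowdown of jgrad_qjit relative to jgrad for each qubit count.
slowdown = {nq: t_qjit / t_jax for nq, t_jax, t_qjit in zip(Nqs, ts1, ts2)}

for nq, ratio in slowdown.items():
    print(f"Nq={nq:2d}: {ratio:5.1f}x")
```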


> In reality my initial states are pretty complicated and not just something like [1 0 0 0] (which I just put as a dummy value in this example), and I expect that trying to decompose them numerically will be way too costly since I'm looking at circuits with 14-20+ qubits. I am working on finding explicit, analytic gate decompositions but until then, I am stuck with qml.StatePrep.

Ah I see! You could try this approach:

import numpy as np
import pennylane as qml
from catalyst import grad, qjit

Nq = 2
init_state = np.array([1, 0, 0, 0])

dev = qml.device('lightning.qubit', wires=Nq)

@qml.qnode(dev)
def qnode_test(angles):
    qml.StatePrep.compute_decomposition(init_state, wires=range(Nq))
    qml.RY(angles[0], wires=0)
    qml.RY(angles[1], wires=1)
    return qml.expval(qml.PauliZ(0))

@qjit
def jgrad_qjit(angles):
    g = grad(qnode_test, argnums=0)
    return g(angles)

Note a couple of things here:

  • I am using np.array to set the initial state, not jnp.array.
  • I am not passing the initial state as a function argument.

Both of these matter: they mean that the state decomposition happens in Python, so JAX does not try to compile it, which we have noticed is significantly costly.
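The difference between closing over a NumPy array and passing an array as an argument can be observed directly under jax.jit: during tracing, the closed-over array is still a concrete NumPy value that plain Python and NumPy code can operate on, whereas the argument has been replaced by an abstract tracer. This sketch only illustrates JAX's tracing behaviour and does not involve PennyLane or Catalyst; circuit_like and seen are hypothetical names.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical initial state held as a concrete NumPy array (not jnp.array).
init_state = np.array([1.0, 0.0, 0.0, 0.0])

# Records, at trace time, whether each value is still a concrete NumPy array.
seen = {}


@jax.jit
def circuit_like(x, state_arg):
    # Closed-over NumPy array: still concrete while JAX traces this function,
    # so plain Python/NumPy code (e.g. a state decomposition) can inspect it.
    seen["closure_is_concrete"] = isinstance(init_state, np.ndarray)
    # Array passed as an argument: replaced by an abstract tracer under jit,
    # so its values are not available to Python code at trace time.
    seen["argument_is_concrete"] = isinstance(state_arg, np.ndarray)
    return x * jnp.sum(state_arg)


result = circuit_like(2.0, init_state)
print(seen)  # {'closure_is_concrete': True, 'argument_is_concrete': False}
```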

Thank you for the suggestion. I just tried that on my minimal 14-qubit example (which involves a lot of project code not shown above), and after ~5 minutes the gradient function has yet to finish compiling. (Also, I make sure to only pass in initial states as NumPy arrays and to supply them via partial application before creating my QNodes, so I don't expect that JAX is tracing them.)

However, the benchmark I shared above suggests a deeper issue. Note that in that example I do not use any initial state preparation; I am just comparing two different ways of jitting the gradient of the QNode. I find that Catalyst (via qjit) is much slower than using JAX directly, which makes getting qml.StatePrep to work in Catalyst moot if Catalyst is substantially slower than the alternative even before any state preparation is involved. I don't mind using JAX directly, but it seems Catalyst is meant to be better suited to PennyLane and/or more convenient thanks to AutoGraph. Yet Catalyst gradients settle at around 1.5 orders of magnitude slower than direct JAX gradients, so I'm wondering whether I'm misunderstanding its usage or whether it is indeed better to use JAX directly on circuits with many qubits.

Hi @joeybarreto, thanks for sharing that benchmark! I think the difference comes from an unfortunate default for the diff_method="best" parameter, which in Catalyst resolves to the parameter-shift method. Using diff_method="adjoint" should give comparable results; on my machine I'm now seeing 116 ms for Catalyst.
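A rough count suggests why parameter-shift scales poorly on this benchmark. Assuming each RY gate occurrence is shifted as its own gate parameter (the angles are reused across the 5 layers, but each gate is a separate parameter on the tape) and that RY uses the standard two-term shift rule, the number of circuit executions grows linearly with the number of qubits. Adjoint differentiation, by contrast, needs roughly one forward and one backward sweep over the state vector regardless of the parameter count. The helper below is hypothetical and ignores any additional forward execution.

```python
def parameter_shift_executions(n_qubits, n_layers=5, evals_per_param=2):
    """Estimate circuit executions for a two-term parameter-shift gradient.

    Assumes one RY gate per qubit per layer (as in the benchmark circuit),
    each treated as its own gate parameter and needing two shifted runs.
    """
    n_gate_params = n_layers * n_qubits
    return evals_per_param * n_gate_params


for nq in (14, 18, 22):
    # 10 * nq executions with the defaults: 140, 180, 220
    print(nq, parameter_shift_executions(nq))
```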

Thanks for pointing that out @dime10! Where is that default value specified? It seems like grad has a default value of method=None, but if that were resolved automatically, I would expect it to use adjoint, since calling print(qnode_test.best_method_str(dev, qnode_test.interface)) as described here prints adjoint when diff_method='best' and I create my QNode via

qnode_test = qml.QNode(test, 
          qml.device('lightning.qubit', wires=Nq), 
          interface='jax',
          diff_method='best')

Sorry, "default" may not have been the best term; what I meant is the value that is used when the user chooses "best". In PennyLane, this will be adjoint for the lightning device, but not in Catalyst. The value is chosen here in Catalyst, and for compatibility reasons it was set to parameter-shift early on (that is, no smart, device-aware choice has been implemented yet):

diff_method = "parameter-shift" if fn.diff_method == "best" else str(fn.diff_method)

We'll be sure to update this soon for better default performance.
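The quoted Catalyst line can be contrasted with a device-aware resolution in a few lines of plain Python. catalyst_resolve mirrors the quoted expression; device_aware_resolve and the ADJOINT_CAPABLE set are hypothetical illustrations, not Catalyst's API. Until the behaviour changes, passing diff_method="adjoint" explicitly to the QNode, as suggested earlier in the thread, avoids the issue.

```python
def catalyst_resolve(diff_method):
    """Mirrors the quoted Catalyst expression: "best" always means parameter-shift."""
    return "parameter-shift" if diff_method == "best" else str(diff_method)


# Hypothetical, illustrative set of devices assumed to support adjoint gradients.
ADJOINT_CAPABLE = {"lightning.qubit", "lightning.kokkos", "lightning.gpu"}


def device_aware_resolve(diff_method, device_name):
    """Hypothetical device-aware choice for "best"; explicit methods pass through."""
    if diff_method != "best":
        return str(diff_method)
    return "adjoint" if device_name in ADJOINT_CAPABLE else "parameter-shift"


print(catalyst_resolve("best"))                         # parameter-shift
print(device_aware_resolve("best", "lightning.qubit"))  # adjoint
```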