Support for `qml.StatePrep` on `lightning.qubit` with Catalyst
joeybarreto opened this issue · 7 comments
In the code below, I create a test circuit and then compute its gradient in two ways. The first way just applies `jax.jit(jax.grad(...))` to a partial completion of the qnode after an initial state has been supplied. The second uses the Catalyst `@qjit` decorator and `grad` function instead. `jgrad` succeeds, but `jgrad_qjit` fails with the error `DifferentiableCompileError: StatePrep is non-differentiable on 'lightning.qubit' device`. I'm not sure whether this is a bug or expected behavior, but if `qml.StatePrep` works on the lightning backend without using Catalyst, I'm not sure why it would fail here. How hard would it be to add support for arbitrary state preparation when using Catalyst?
```python
# Imports added for completeness (not shown in the original snippet).
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit, grad

Nq = 2
init_state = np.array([1, 0, 0, 0])

def test(angles, init_state):
    qml.StatePrep(init_state, wires=range(Nq))
    qml.RY(angles[0], wires=0)
    qml.RY(angles[1], wires=1)
    return qml.expval(qml.PauliZ(0))

qnode_test = qml.QNode(test,
                       qml.device('lightning.qubit', wires=Nq),
                       interface='jax',
                       diff_method='best')
qnode_test = partial(qnode_test, init_state=init_state)

@qjit
def jgrad_qjit(angles):
    g = grad(qnode_test)
    return g(angles)

jgrad = jax.jit(jax.grad(qnode_test))

angles = jnp.array([0.1, 0.2])

qnode_test(angles, init_state=init_state)  # <-- succeeds
jgrad(angles)                              # <-- succeeds
jgrad_qjit(angles)                         # <-- fails
```
The full error message is:

```text
---------------------------------------------------------------------------
DifferentiableCompileError Traceback (most recent call last)
Cell In[548], line 1
----> 1 jgrad_qjit(angles)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py:457, in QJIT.__call__(self, *args, **kwargs)
453 kwargs = {"static_argnums": self.compile_options.static_argnums, **kwargs}
455 return self.user_function(*args, **kwargs)
--> 457 requires_promotion = self.jit_compile(args, **kwargs)
459 # If we receive tracers as input, dispatch to the JAX integration.
460 if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]):
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py:528, in QJIT.jit_compile(self, args, **kwargs)
524 # Capture with the patched conversion rules
525 with Patcher(
526 (ag_primitives, "module_allowlist", self.patched_module_allowlist),
527 ):
--> 528 self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
529 args, **kwargs
530 )
532 self.mlir_module, self.mlir = self.generate_ir()
533 self.compiled_function, self.qir = self.compile()
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/debug/instruments.py:143, in instrument.<locals>.wrapper(*args, **kwargs)
140 @functools.wraps(fn)
141 def wrapper(*args, **kwargs):
142 if not InstrumentSession.active:
--> 143 return fn(*args, **kwargs)
145 with ResultReporter(stage_name, has_finegrained) as reporter:
146 fn_results, wall_time, cpu_time = time_function(fn, args, kwargs)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py:610, in QJIT.capture(self, args, **kwargs)
607 _inject_transform_named_sequence()
608 return self.user_function(*args, **kwargs)
--> 610 jaxpr, out_type, treedef = trace_to_jaxpr(
611 fn_with_transform_named_sequence, static_argnums, abstracted_axes, full_sig, kwargs
612 )
614 return jaxpr, out_type, treedef, dynamic_sig
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py:536, in trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs)
531 with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
532 make_jaxpr_kwargs = {
533 "static_argnums": static_argnums,
534 "abstracted_axes": abstracted_axes,
535 }
--> 536 jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
538 return jaxpr, out_type, out_treedef
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_extras/tracing.py:542, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs)
540 f, out_tree_promise = flatten_fun(f, in_tree)
541 f = annotate(f, in_type)
--> 542 jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
543 closed_jaxpr = ClosedJaxpr(jaxpr, consts)
544 return closed_jaxpr, out_type, out_tree_promise()
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/jax/_src/profiler.py:335, in annotate_function.<locals>.wrapper(*args, **kwargs)
332 @wraps(func)
333 def wrapper(*args, **kwargs):
334 with TraceAnnotation(name, **decorator_kwargs):
--> 335 return func(*args, **kwargs)
336 return wrapper
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py:2362, in trace_to_jaxpr_dynamic2(fun, debug_info)
2360 with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
2361 main.jaxpr_stack = () # type: ignore
-> 2362 jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
2363 del main, fun
2364 return jaxpr, out_type, consts
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py:2377, in trace_to_subjaxpr_dynamic2(fun, main, debug_info)
2375 in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
2376 in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 2377 ans = fun.call_wrapped(*in_tracers_)
2378 out_tracers = map(trace.full_raise, ans)
2379 jaxpr, out_type, consts = frame.to_jaxpr2(out_tracers)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/jax/_src/linear_util.py:192, in WrappedFun.call_wrapped(self, *args, **kwargs)
189 gen = gen_static_args = out_store = None
191 try:
--> 192 ans = self.f(*args, **dict(self.params, **kwargs))
193 except:
194 # Some transformations yield from inside context managers, so we have to
195 # interrupt them before reraising the exception. Otherwise they will only
196 # get garbage-collected at some later time, running their cleanup tasks
197 # only after this exception is handled, which can corrupt the global
198 # state.
199 while stack:
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py:608, in QJIT.capture.<locals>.fn_with_transform_named_sequence(*args, **kwargs)
596 """
597 This function behaves exactly like the user function being jitted,
598 taking in the same arguments and producing the same results, except
(...)
605 jaxpr. It is never executed or used anywhere, except being traced here.
606 """
607 _inject_transform_named_sequence()
--> 608 return self.user_function(*args, **kwargs)
Cell In[545], line 19
16 @qjit
17 def jgrad_qjit(angles):
18 g = grad(qnode_test)
---> 19 return g(angles)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/api_extensions/differentiation.py:663, in Grad.__call__(self, *args, **kwargs)
653 grad_params = _check_grad_params(
654 self.grad_params.method,
655 self.grad_params.scalar_out,
(...)
660 self.grad_params.with_value,
661 )
662 input_data_flat, _ = tree_flatten((args, kwargs))
--> 663 jaxpr, out_tree = _make_jaxpr_check_differentiable(fn, grad_params, *args, **kwargs)
664 if self.grad_params.with_value: # use value_and_grad
665 args_argnum = tuple(args[i] for i in grad_params.argnums)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/api_extensions/differentiation.py:802, in _make_jaxpr_check_differentiable(f, grad_params, *args, **kwargs)
800 method = grad_params.method
801 with mark_gradient_tracing(method):
--> 802 jaxpr, shape = jax.make_jaxpr(f, return_shape=True)(*args, **kwargs)
803 _, out_tree = tree_flatten(shape)
805 for pos, arg in enumerate(jaxpr.in_avals):
[... skipping hidden 6 frame]
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py:172, in Function.__call__(self, *args, **kwargs)
170 @debug_logger
171 def __call__(self, *args, **kwargs):
--> 172 jaxpr, _, out_tree = make_jaxpr2(self.fn)(*args, **kwargs)
174 def _eval_jaxpr(*args, **kwargs):
175 return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, **kwargs)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_extras/tracing.py:542, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs)
540 f, out_tree_promise = flatten_fun(f, in_tree)
541 f = annotate(f, in_type)
--> 542 jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
543 closed_jaxpr = ClosedJaxpr(jaxpr, consts)
544 return closed_jaxpr, out_type, out_tree_promise()
[... skipping hidden 4 frame]
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py:588, in QJIT.capture.<locals>.closure(qnode, *args, **kwargs)
586 params["static_argnums"] = kwargs.pop("static_argnums", static_argnums)
587 params["_out_tree_expected"] = []
--> 588 return QFunc.__call__(qnode, *args, **dict(params, **kwargs))
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/qfunc.py:163, in QFunc.__call__(self, *args, **kwargs)
161 dynamic_args = filter_static_args(args, static_argnums)
162 args_flat = tree_flatten((dynamic_args, kwargs))[0]
--> 163 res_flat = func_p.bind(flattened_fun, *args_flat, fn=self)
164 return tree_unflatten(out_tree_promise(), res_flat)[0]
[... skipping hidden 4 frame]
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/qfunc.py:141, in QFunc.__call__.<locals>._eval_quantum(*args, **kwargs)
140 def _eval_quantum(*args, **kwargs):
--> 141 closed_jaxpr, out_type, out_tree, out_tree_exp = trace_quantum_function(
142 self.func,
143 qjit_device,
144 args,
145 kwargs,
146 self,
147 static_argnums,
148 )
150 out_tree_expected.append(out_tree_exp)
151 dynamic_args = filter_static_args(args, static_argnums)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py:1178, in trace_quantum_function(f, device, args, kwargs, qnode, static_argnums)
1174 device_modify_measurements = False # this is only for the new API transform program
1176 qnode_program = qnode.transform_program if qnode else TransformProgram()
-> 1178 tapes, post_processing = apply_transform(
1179 qnode_program,
1180 device_program,
1181 device_modify_measurements,
1182 quantum_tape,
1183 return_values_flat,
1184 )
1186 # (2) - Quantum tracing
1187 transformed_results = []
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py:991, in apply_transform(qnode_program, device_program, device_modify_measurements, tape, flat_results)
988 # Apply the identity transform in order to keep generalization
989 total_program = device_program
--> 991 tapes, post_processing = total_program([tape])
992 if not is_valid_for_batch and len(tapes) > 1:
993 msg = "Multiple tapes are generated, but each run might produce different results."
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/pennylane/transforms/core/transform_program.py:515, in TransformProgram.__call__(self, tapes)
513 if self._argnums is not None and self._argnums[i] is not None:
514 tape.trainable_params = self._argnums[i][j]
--> 515 new_tapes, fn = transform(tape, *targs, **tkwargs)
516 execution_tapes.extend(new_tapes)
518 fns.append(fn)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/device/verification.py:242, in verify_operations(tape, grad_method, qjit_device)
238 _paramshift_op_checker(op)
240 return (in_inverse, in_control)
--> 242 _verify_nested(tape, (False, False), _op_checker)
244 return (tape,), lambda x: x[0]
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/device/verification.py:60, in _verify_nested(tape, state, op_checker_fn)
58 ctx = EvaluationContext.get_main_tracing_context()
59 for op in tape.operations:
---> 60 inner_state = op_checker_fn(op, state)
61 if has_nested_tapes(op):
62 for region in nested_quantum_regions(op):
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/device/verification.py:236, in verify_operations.<locals>._op_checker(op, state)
234 _mcm_op_checker(op)
235 if grad_method == "adjoint":
--> 236 _adj_diff_op_checker(op)
237 elif grad_method == "parameter-shift":
238 _paramshift_op_checker(op)
File ~/miniconda3/envs/pennylane/lib/python3.12/site-packages/catalyst/device/verification.py:154, in verify_operations.<locals>._adj_diff_op_checker(op)
150 op_name = op.name
151 if not qjit_device.qjit_capabilities.native_ops.get(
152 op_name, EMPTY_PROPERTIES
153 ).differentiable:
--> 154 raise DifferentiableCompileError(
155 f"{op.name} is non-differentiable on '{qjit_device.original_device.name}' device"
156 )
DifferentiableCompileError: StatePrep is non-differentiable on 'lightning.qubit' device
```
`qml.about()`:

```text
Name: PennyLane
Version: 0.38.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /Users/joey/miniconda3/envs/pennylane/lib/python3.12/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning
Platform info: macOS-14.4.1-arm64-arm-64bit
Python version: 3.12.4
Numpy version: 1.26.4
Scipy version: 1.12.0
Installed devices:
- nvidia.custatevec (PennyLane-Catalyst-0.8.1)
- nvidia.cutensornet (PennyLane-Catalyst-0.8.1)
- oqc.cloud (PennyLane-Catalyst-0.8.1)
- softwareq.qpp (PennyLane-Catalyst-0.8.1)
- lightning.qubit (PennyLane_Lightning-0.38.0)
- default.clifford (PennyLane-0.38.0)
- default.gaussian (PennyLane-0.38.0)
- default.mixed (PennyLane-0.38.0)
- default.qubit (PennyLane-0.38.0)
- default.qubit.autograd (PennyLane-0.38.0)
- default.qubit.jax (PennyLane-0.38.0)
- default.qubit.legacy (PennyLane-0.38.0)
- default.qubit.tf (PennyLane-0.38.0)
- default.qubit.torch (PennyLane-0.38.0)
- default.qutrit (PennyLane-0.38.0)
- default.qutrit.mixed (PennyLane-0.38.0)
- default.tensor (PennyLane-0.38.0)
- null.qubit (PennyLane-0.38.0)
```
Hi @joeybarreto, unfortunately this is a known bug in Catalyst at the moment; see #1065 for more details.

In short, without `@qjit`, PennyLane correctly identifies operations that need to be decomposed for gradient computation, but it does not do the same when `@qjit` is enabled. Instead, with `@qjit` enabled, Catalyst verifies that all operators are differentiable, even when certain operations do not need to be differentiated.

We are working to fix this; for now, here is a workaround:
```python
Nq = 2
init_state = np.array([1, 0, 0, 0])

dev = qml.device('lightning.qubit', wires=Nq)

@qml.qnode(dev)
def qnode_test(angles, init_state):
    qml.BasisState.compute_decomposition(init_state, wires=range(Nq))
    qml.RY(angles[0], wires=0)
    qml.RY(angles[1], wires=1)
    return qml.expval(qml.PauliZ(0))

@qjit
def jgrad_qjit(angles):
    g = grad(qnode_test, argnums=0)
    return g(angles, init_state)
```
```python
>>> angles = jnp.array([0.1, 0.2])
>>> jax.grad(qnode_test)(angles, init_state=init_state)
Array([0.09983342, 0.        ], dtype=float64)
>>> jgrad_qjit(angles)
Array([ 9.98334166e-02, -5.55111512e-17], dtype=float64)
```
Note that I have swapped `StatePrep` for `BasisState`, which has a more efficient decomposition for basis states.
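For reference, here is a small sketch (not part of the original reply) to inspect what that decomposition produces; note that `qml.BasisState` is designed to take a bit string with one entry per wire, and the exact ops returned may vary with the PennyLane version:

```python
import numpy as np
import pennylane as qml

# Decomposition for the basis state |01> on two wires (bit string [0, 1]).
ops = qml.BasisState.compute_decomposition(np.array([0, 1]), wires=range(2))
print([op.name for op in ops])
```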
Thanks, yeah I had seen #1065 but wasn't 100% sure if this was the same issue. In reality my initial states are pretty complicated and not just something like `[1, 0, 0, 0]` (which I just put as a dummy value in this example), and I expect that trying to decompose them numerically will be way too costly since I'm looking at circuits with 14-20+ qubits. I am working on finding explicit, analytic gate decompositions, but until then I am stuck with `qml.StatePrep`.

I do have a follow-up question though, regarding the execution time of the two compiled functions above (whether or not `qml.StatePrep` is used). I modified the circuit to be a bit more substantial, like below:
```python
Nq = 18

def test(angles):
    for kk in range(5):
        for ii in range(Nq):
            qml.RY(angles[ii], wires=ii)
        for ii in range(0, Nq, 2):
            qml.CNOT(wires=[ii % Nq, (ii + 1) % Nq])
    return qml.expval(qml.PauliZ(0))

qnode_test = qml.QNode(test,
                       qml.device('lightning.qubit', wires=Nq),
                       interface='jax',
                       diff_method='best')

@qjit
def jgrad_qjit(angles):
    g = grad(qnode_test)
    return g(angles)

jgrad = jax.jit(jax.grad(qnode_test))

angles = np.array([0.1] * Nq)
```
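(A hypothetical way to measure the per-call times reported below, after a warm-up call triggers compilation; the exact timing methodology isn't stated in the original post.)

```python
import timeit

jgrad(angles)       # warm-up: triggers jax.jit compilation
jgrad_qjit(angles)  # warm-up: triggers Catalyst compilation

t_jax = timeit.timeit(lambda: jax.block_until_ready(jgrad(angles)), number=10) / 10
t_cat = timeit.timeit(lambda: jgrad_qjit(angles), number=10) / 10
print(f"jax.jit grad: {t_jax:.4f} s/call, qjit grad: {t_cat:.4f} s/call")
```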
After compilation, one call to `jgrad` takes 89 ms, while `jgrad_qjit` takes 2.5 s per call (at 18 qubits). I tested this over different numbers of qubits (with constant circuit depth), and `jgrad_qjit` is consistently slower than just vanilla JAX by more than 10x beyond Nq=14. Is this expected? I can open a separate issue if the answer is not trivial. See below for the scaling comparison (times in seconds):
```python
Nqs = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22]
ts1 = [894e-6, 1.5e-3, 2.15e-3, 2.9e-3, 3.65e-3, 5.06e-3, 8.8e-3, 23.9e-3, 89.3e-3, 446e-3, 1.86]  # jax.jit(jax.grad(...))
ts2 = [3.02e-3, 3.51e-3, 4.07e-3, 5.13e-3, 8.47e-3, 23.3e-3, 104e-3, 499e-3, 2.51, 13.2, 73]       # qjit + grad
```
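A minimal sketch for visualizing this comparison (matplotlib assumed to be available; not part of the original post):

```python
import matplotlib.pyplot as plt

plt.semilogy(Nqs, ts1, "o-", label="jax.jit(jax.grad(...))")
plt.semilogy(Nqs, ts2, "s-", label="qjit + catalyst.grad")
plt.xlabel("Number of qubits")
plt.ylabel("Time per gradient call (s)")
plt.legend()
plt.tight_layout()
plt.show()
```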
> In reality my initial states are pretty complicated and not just something like `[1, 0, 0, 0]` (which I just put as a dummy value in this example), and I expect that trying to decompose them numerically will be way too costly since I'm looking at circuits with 14-20+ qubits. I am working on finding explicit, analytic gate decompositions but until then, I am stuck with `qml.StatePrep`.
Ah I see! You could try this approach:
```python
Nq = 2
init_state = np.array([1, 0, 0, 0])

dev = qml.device('lightning.qubit', wires=Nq)

@qml.qnode(dev)
def qnode_test(angles):
    qml.StatePrep.compute_decomposition(init_state, wires=range(Nq))
    qml.RY(angles[0], wires=0)
    qml.RY(angles[1], wires=1)
    return qml.expval(qml.PauliZ(0))

@qjit
def jgrad_qjit(angles):
    g = grad(qnode_test, argnums=0)
    return g(angles)
```
Note a couple of things here:

- I am using `np.array` to set the initial state, not `jnp.array`.
- I am not passing the initial state as a function argument.

Both of these are important: they mean that the state decomposition will happen in Python, so JAX will not try to compile it, which we have noticed is significantly costly.
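As a quick sanity check (a sketch using the same `Nq` and `init_state` as above, not from the original reply), the decomposition can be inspected eagerly in plain Python:

```python
# init_state is a concrete NumPy array, so this runs eagerly in Python;
# neither JAX nor Catalyst traces it.
ops = qml.StatePrep.compute_decomposition(init_state, wires=range(Nq))
print(len(ops), [op.name for op in ops])
```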
Thank you for the suggestion. I just tried that on my minimal 14-qubit example (which involves a lot of project code not shown above); after ~5 minutes, compiling the gradient function has yet to finish. (Also, I make sure to only pass in initial states as NumPy arrays and to provide them during partial completions before creating my qnodes, so I don't expect that JAX is tracing them.)

However, the benchmarking plot I shared above suggests a deeper issue. Note that in that example I do not use any initial state prep; I am just comparing two different ways of jitting the gradient of the qnode. I find that Catalyst (via `qjit`) is much slower than direct JAX usage, which suggests that getting `qml.StatePrep` to work in Catalyst is a non-issue if Catalyst is substantially slower than the alternative even before any kind of state preparation gets involved. I don't mind directly using JAX, but Catalyst seems to be better suited to PennyLane and/or more convenient due to AutoGraph. Yet Catalyst gradients settle to being around 1.5 orders of magnitude slower than direct JAX gradients, so I'm wondering if I'm misunderstanding its usage or if it is indeed better to use JAX directly on circuits with many qubits.
Hi @joeybarreto, thanks for sharing that benchmark! I think the difference comes from an unfortunate default value for the `diff_method="best"` parameter, which in the case of Catalyst uses the parameter-shift method. Using `diff_method="adjoint"` should show comparable results. On my machine I'm now seeing 116 ms for Catalyst.
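For example (a sketch based on the snippet above with only the diff method changed; not necessarily the exact code behind the 116 ms figure):

```python
qnode_test = qml.QNode(test,
                       qml.device('lightning.qubit', wires=Nq),
                       interface='jax',
                       diff_method='adjoint')

@qjit
def jgrad_qjit(angles):
    return grad(qnode_test)(angles)
```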
Thanks for pointing that out @dime10! Where is the default value specified? It seems like `grad` has a default value of `method=None`, but if it were `auto`, then I would expect it to use `adjoint`, since with the call `print(qnode_test.best_method_str(dev, qnode_test.interface))` described here, I get a printed value of `adjoint` when `diff_method='best'` and I create my qnode via:
```python
qnode_test = qml.QNode(test,
                       qml.device('lightning.qubit', wires=Nq),
                       interface='jax',
                       diff_method='best')
```
Sorry, "default" may not have been the best term; what I meant is the value that actually gets used when `"best"` is chosen by the user. In PennyLane, this will be `adjoint` for the lightning device, but not in Catalyst. The value is chosen here in Catalyst, and for compatibility reasons it was set to `parameter-shift` early on (that is, there is no "smart", device-aware choice implemented yet):

`catalyst/frontend/catalyst/jax_primitives.py`, line 573 at commit `934726f`

We'll be sure to update this soon for better default performance.