Catalyst does not support QJIT-compiling a parameterized circuit with `qml.FlipSign`
We discovered this issue when attempting to QJIT-compile a circuit implementing Grover's algorithm.
Consider the following PennyLane program that applies the qml.FlipSign operator:
```python
import numpy as np
import pennylane as qml

NUM_QUBITS = 2
dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.FlipSign(basis_state, wires=wires)
    return qml.state()

basis_state = np.array([0., 0.])
state = circuit(basis_state)
```
As expected, the circuit flips the sign of the |00⟩ basis state:
```python
>>> print(state)
[-1.-0.j 0.+0.j 0.+0.j 0.+0.j]
```
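For reference, the effect of `qml.FlipSign` on a state vector can be written with no Python branching on the basis-state values: convert the bits to an integer index and negate that one amplitude. The sketch below is plain JAX and is not PennyLane's implementation. `flip_sign_statevector` is a hypothetical helper, and it assumes 0/1-valued bits and PennyLane's wire ordering (wire 0 is the most significant bit). Because it runs under `jax.jit` with a traced `basis_state`, it shows that the target semantics are trace-compatible in principle.

```python
import jax
import jax.numpy as jnp


@jax.jit
def flip_sign_statevector(state, basis_state):
    """Hypothetical helper: negate the amplitude of |basis_state>.

    Assumes 0/1 bits and big-endian wire order (wire 0 = most significant bit).
    """
    n = basis_state.shape[0]  # shapes are static under jit
    weights = 2 ** jnp.arange(n - 1, -1, -1)  # e.g. [2, 1] for two qubits
    index = jnp.sum(basis_state.astype(jnp.int32) * weights)
    # -1 at `index`, +1 elsewhere: a data-dependent value, not a Python `if`.
    signs = jnp.where(jnp.arange(state.shape[0]) == index, -1.0, 1.0)
    return signs * state


# |00> with a traced basis_state argument, mirroring the circuit above.
state = jnp.zeros(4, dtype=jnp.complex64).at[0].set(1.0)
print(flip_sign_statevector(state, jnp.array([0.0, 0.0])))
```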
When we attempt to QJIT-compile and execute this circuit, we get an error:
```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

NUM_QUBITS = 2
dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qjit
@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.FlipSign(basis_state, wires=wires)
    return qml.state()

basis_state = jnp.array([0., 0.])
state = circuit(basis_state)
```
```
Traceback (most recent call last):
  ...
  File ".../venv/lib/python3.12/site-packages/catalyst/device/decomposition.py", line 82, in catalyst_decomposer
    return op.decomposition()
           ^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/pennylane/operation.py", line 1337, in decomposition
    return self.compute_decomposition(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/pennylane/templates/subroutines/flip_sign.py", line 144, in compute_decomposition
    if arr_bin[-1] == 0:
       ^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/jax/_src/core.py", line 712, in __bool__
    return self.aval._bool(self)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/jax/_src/core.py", line 1475, in error
    raise TracerBoolConversionError(arg)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
```
The error occurred in the FlipSign.compute_decomposition() method:
```python
@staticmethod
def compute_decomposition(wires, arr_bin):
    op_list = []

    if arr_bin[-1] == 0:
        op_list.append(qml.X(wires[-1]))

    op_list.append(qml.ctrl(qml.Z(wires[-1]), control=wires[:-1], control_values=arr_bin[:-1]))

    if arr_bin[-1] == 0:
        op_list.append(qml.X(wires[-1]))

    return op_list
```
The problem is in statements like `if arr_bin[-1] == 0`. In the jitted case, `arr_bin` is a traced JAX array, so `arr_bin[-1] == 0` is a traced boolean. Python control flow needs a concrete `bool`, which a tracer cannot provide, so JAX raises `TracerBoolConversionError`.
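The failure can be reproduced with plain JAX, independent of PennyLane and Catalyst. The function names below are illustrative only. The first function branches in Python on a traced value and fails under `jax.jit`. The second expresses the same selection as data with `jnp.where`, which traces fine.

```python
import jax
import jax.numpy as jnp


def pick_python_if(bits):
    # Python `if` on a traced value: fails under jax.jit.
    if bits[-1] == 0:
        return jnp.array(1.0)
    return jnp.array(0.0)


def pick_traceable(bits):
    # Same selection expressed as a data-dependent value, not a Python branch.
    return jnp.where(bits[-1] == 0, 1.0, 0.0)


bits = jnp.array([0.0, 0.0])

try:
    jax.jit(pick_python_if)(bits)
except jax.errors.TracerBoolConversionError as err:
    print("Python `if`:", type(err).__name__)

print("jnp.where:", jax.jit(pick_traceable)(bits))
```

`jnp.where` only selects between values. In `FlipSign`, however, the branch decides whether a gate is applied at all. Inside a qjit-compiled circuit, that decision would need a structured conditional such as `catalyst.cond`, or a decomposition rewritten so that it has no value-dependent branch.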
Compiling the circuit with AutoGraph, `@qjit(autograph=True)`, gives the same error, because AutoGraph is disabled by default for every module in PennyLane. To get around this, we followed the "Adding modules for Autograph conversion" section of the Catalyst documentation and tried the following, which results in a different error:
```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

NUM_QUBITS = 2
dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qjit(autograph=True, autograph_include=["pennylane.templates.subroutines.flip_sign"])
@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.FlipSign(basis_state, wires=wires)
    return qml.state()

basis_state = jnp.array([0.0, 0.0])
state = circuit(basis_state)
```
```
Traceback (most recent call last):
  ...
  File ".../venv/lib/python3.12/site-packages/catalyst/autograph/ag_primitives.py", line 579, in converted_call
    return ag_converted_call(fn, args, kwargs, caller_fn_scope, options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/malt/impl/api.py", line 380, in converted_call
    result = converted_f(*effective_args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file_1_ucoey.py", line 35, in ag____call__
    ag__.if_stmt(ag__.converted_call(ag__.ld(enabled), (), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File ".../venv/lib/python3.12/site-packages/catalyst/autograph/ag_primitives.py", line 132, in if_stmt
    results = functional_cond()
              ^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/control_flow.py", line 736, in __call__
    return self._call_with_quantum_ctx(ctx)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/control_flow.py", line 662, in _call_with_quantum_ctx
    _assert_cond_result_structure([s.out_tree() for s in out_sigs])
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/control_flow.py", line 1319, in _assert_cond_result_structure
    raise TypeError(
TypeError: Conditional requires a consistent return structure across all branches! Got PyTreeDef((*, CustomNode(FlipSign[(Wires([0, 1]), (('n', (Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=3/1)>, Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=3/1)>)),))], []))) and PyTreeDef((*, *)).
```
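The second error comes from the constraint that every branch of a traced conditional must return the same pytree structure. The generated `ag____call__` frame suggests that AutoGraph converted an `if` statement in a `__call__` method into a Catalyst conditional. One branch returns a `FlipSign` operator as a pytree node (`CustomNode(FlipSign[...])`), and the other returns a single leaf (`*`). `jax.lax.cond` enforces the same rule. The minimal plain-JAX analogue below uses illustrative branch functions and is not the code AutoGraph generated.

```python
import jax
import jax.numpy as jnp
from jax import lax


def mismatched(x):
    # True branch returns a 2-tuple, false branch a single array:
    # the output pytree structures differ, so lax.cond raises a TypeError.
    return lax.cond(x > 0, lambda v: (v, v), lambda v: v, x)


def matched(x):
    # Both branches return a 2-tuple of float scalars.
    return lax.cond(x > 0, lambda v: (v, v), lambda v: (v, -v), x)


x = jnp.array(1.0)
try:
    jax.jit(mismatched)(x)
except TypeError as err:
    print("mismatched:", type(err).__name__)

print("matched:", jax.jit(matched)(x))
```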
The appropriate changes to Catalyst and/or PennyLane should be made to support the `qml.FlipSign` operator in QJIT-compiled circuits where the basis-state input to `qml.FlipSign` is an input argument to the parameterized circuit.
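One way to remove the value-dependent branches from the decomposition shown above can be sketched at the matrix level in plain JAX. This is not a PennyLane patch, and PennyLane may resolve the issue differently. For a bit b ∈ {0, 1}, X^(1-b) equals b·I + (1-b)·X. Likewise, the projector for control value c equals c·|1⟩⟨1| + (1-c)·|0⟩⟨0|. Both replace an `if` with arithmetic. `flip_sign_matrix` is a hypothetical helper that assumes 0/1-valued bits and PennyLane's wire ordering. A circuit-level fix would still need an operator-level equivalent of these arithmetic selections.

```python
import jax
import jax.numpy as jnp

I2 = jnp.eye(2)
X = jnp.array([[0.0, 1.0], [1.0, 0.0]])
Z = jnp.diag(jnp.array([1.0, -1.0]))
P0 = jnp.diag(jnp.array([1.0, 0.0]))  # |0><0|
P1 = jnp.diag(jnp.array([0.0, 1.0]))  # |1><1|


@jax.jit
def flip_sign_matrix(arr_bin):
    """Hypothetical helper: FlipSign unitary with no branching on arr_bin.

    Mirrors the X / controlled-Z / X decomposition, assuming 0/1 bits and
    big-endian wire order (wire 0 = most significant qubit).
    """
    n = arr_bin.shape[0]  # static under jit
    # Projector onto the control pattern arr_bin[:-1], built arithmetically.
    proj = jnp.eye(1)
    for k in range(n - 1):  # Python loop over a static length: fine when traced
        c = arr_bin[k]
        proj = jnp.kron(proj, c * P1 + (1 - c) * P0)
    dim_ctrl = 2 ** (n - 1)
    # Multi-controlled Z on the last wire with control values arr_bin[:-1].
    ctrl_z = jnp.kron(proj, Z) + jnp.kron(jnp.eye(dim_ctrl) - proj, I2)
    # X**(1 - b) == b*I + (1 - b)*X replaces `if arr_bin[-1] == 0`.
    b = arr_bin[-1]
    flip = jnp.kron(jnp.eye(dim_ctrl), b * I2 + (1 - b) * X)
    return flip @ ctrl_z @ flip


# Diagonal with -1 at the |00> position when arr_bin = [0, 0].
print(jnp.diag(flip_sign_matrix(jnp.array([0.0, 0.0]))))
```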
Can you re-raise this in the PennyLane (PL) repo?
I suspect there might be an equivalent issue/PR/fix on the PL side now; worth double-checking. I saw it being discussed (FlipSign not being JIT-compatible was interfering with program capture).
@joeycarter I remember tagging you in a slack thread about it recently :)
@josh146 Yes, but it was related to turning on Grover's algo in the benchmarks :)
https://xanaduhq.slack.com/archives/C06S7LJV8SX/p1740595287516859
I don't see any issues open or closed in the PennyLane repo about this, but I can quickly test it again and if it doesn't work I can open up an issue there.
I take it back! There's a draft PR here: PennyLaneAI/pennylane#7127