PennyLaneAI/catalyst

Catalyst does not support QJIT-compiling a parameterized circuit with `qml.FlipSign`


We discovered this issue when attempting to QJIT-compile a circuit implementing Grover's algorithm.

Consider the following PennyLane program that applies the qml.FlipSign operator:

```python
import numpy as np
import pennylane as qml

NUM_QUBITS = 2

dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.FlipSign(basis_state, wires=wires)
    return qml.state()

basis_state = np.array([0., 0.])
state = circuit(basis_state)
```

As expected, the circuit flips the sign of the $|00\rangle$ basis state:

```pycon
>>> print(state)
[-1.-0.j  0.+0.j  0.+0.j  0.+0.j]
```
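For reference, `qml.FlipSign` acting on the basis state $|b\rangle$ is the diagonal unitary $I - 2|b\rangle\langle b|$. The NumPy sketch below builds that matrix directly and applies it to $|00\rangle$, reproducing the output above (as real amplitudes). The helper `flip_sign_matrix` is hypothetical and is not a PennyLane API. It assumes PennyLane's wire ordering, where wire 0 is the most significant bit.

```python
import numpy as np

def flip_sign_matrix(bits):
    """Hypothetical helper: the 2**n x 2**n matrix I - 2|b><b| for the bit list `bits`."""
    n = len(bits)
    # Index of |b> with wire 0 as the most significant bit
    index = int("".join(str(int(b)) for b in bits), 2)
    diag = np.ones(2**n)
    diag[index] = -1.0
    return np.diag(diag)

# Start in |00>, as the QNode above does, and flip the sign of b = 00
state = np.zeros(4)
state[0] = 1.0
print(flip_sign_matrix([0, 0]) @ state)  # [-1.  0.  0.  0.]
```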

When we attempt to QJIT-compile and execute this circuit, we get an error:

```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

NUM_QUBITS = 2

dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qjit
@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.FlipSign(basis_state, wires=wires)
    return qml.state()

basis_state = jnp.array([0., 0.])
state = circuit(basis_state)
```
```
Traceback (most recent call last):
...
  File ".../venv/lib/python3.12/site-packages/catalyst/device/decomposition.py", line 82, in catalyst_decomposer
    return op.decomposition()
           ^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/pennylane/operation.py", line 1337, in decomposition
    return self.compute_decomposition(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/pennylane/templates/subroutines/flip_sign.py", line 144, in compute_decomposition
    if arr_bin[-1] == 0:
       ^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/jax/_src/core.py", line 712, in __bool__
    return self.aval._bool(self)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/jax/_src/core.py", line 1475, in error
    raise TracerBoolConversionError(arg)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
```

The error occurred in the FlipSign.compute_decomposition() method:

```python
@staticmethod
def compute_decomposition(wires, arr_bin):
    op_list = []

    if arr_bin[-1] == 0:
        op_list.append(qml.X(wires[-1]))

    op_list.append(qml.ctrl(qml.Z(wires[-1]), control=wires[:-1], control_values=arr_bin[:-1]))

    if arr_bin[-1] == 0:
        op_list.append(qml.X(wires[-1]))

    return op_list
```
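To make the decomposition concrete, the NumPy sketch below rebuilds it as matrices: an `X` on the last wire when the last bit is 0, a `Z` on the last wire controlled on the remaining wires with `control_values=arr_bin[:-1]`, and a second `X` that undoes the first. The helpers `x_on_last`, `controlled_z`, `decomposed`, and `flip_sign` are hypothetical and only mirror the logic of `compute_decomposition`. The loop checks that the product equals $I - 2|b\rangle\langle b|$ for every 3-bit $b$. Note that the branching here is plain Python on concrete bits, which is exactly what stops working once `arr_bin` is traced.

```python
import itertools
import numpy as np

X = np.array([[0.0, 1.0], [1.0, 0.0]])

def x_on_last(n):
    """X on the last wire of an n-qubit register (wire 0 is the most significant bit)."""
    return np.kron(np.eye(2 ** (n - 1)), X)

def controlled_z(bits):
    """Z on the last wire, controlled on wires[:-1] taking the values bits[:-1]."""
    n = len(bits)
    diag = np.ones(2**n)
    for idx in range(2**n):
        # Bits of basis state `idx`, most significant first
        state_bits = [(idx >> (n - 1 - k)) & 1 for k in range(n)]
        if state_bits[:-1] == list(bits[:-1]) and state_bits[-1] == 1:
            diag[idx] = -1.0
    return np.diag(diag)

def decomposed(bits):
    """Mirror of FlipSign.compute_decomposition, branching in Python on concrete bits."""
    n = len(bits)
    ops = []
    if bits[-1] == 0:
        ops.append(x_on_last(n))
    ops.append(controlled_z(bits))
    if bits[-1] == 0:
        ops.append(x_on_last(n))
    # Operators are applied in list order, so each later one multiplies from the left
    u = np.eye(2**n)
    for op in ops:
        u = op @ u
    return u

def flip_sign(bits):
    """Reference matrix I - 2|b><b|."""
    n = len(bits)
    index = int("".join(map(str, bits)), 2)
    diag = np.ones(2**n)
    diag[index] = -1.0
    return np.diag(diag)

for bits in itertools.product([0, 1], repeat=3):
    assert np.allclose(decomposed(list(bits)), flip_sign(list(bits)))
print("decomposition matches I - 2|b><b| for all 3-bit b")
```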

The problem lies in statements like `if arr_bin[-1] == 0`: when the circuit is QJIT-compiled, `arr_bin` is a traced JAX array, and using it in Python control flow requires converting it to a concrete `bool`, which is not allowed during tracing.
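The same failure can be reproduced with JAX alone, without Catalyst or PennyLane. In the sketch below, `last_bit_is_zero_python` mimics the `if arr_bin[-1] == 0` check and raises `TracerBoolConversionError` under `jax.jit`, while `last_bit_is_zero_traced` expresses the same test with `jnp.where` so that it stays inside the traced computation. Both function names are hypothetical; the sketch only illustrates the tracing constraint and is not the fix adopted by PennyLane or Catalyst.

```python
import jax
import jax.numpy as jnp

def last_bit_is_zero_python(arr_bin):
    # Python `if` forces the traced comparison into a concrete bool
    if arr_bin[-1] == 0:
        return 1
    return 0

def last_bit_is_zero_traced(arr_bin):
    # jnp.where keeps the selection as an array operation, so it can be traced
    return jnp.where(arr_bin[-1] == 0, 1, 0)

arr_bin = jnp.array([0.0, 0.0])

raised = None
try:
    jax.jit(last_bit_is_zero_python)(arr_bin)
except jax.errors.TracerBoolConversionError as err:
    raised = type(err).__name__
print("Python control flow failed with:", raised)

print(int(jax.jit(last_bit_is_zero_traced)(arr_bin)))  # 1
```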

Compiling the circuit with AutoGraph (`@qjit(autograph=True)`) gives the same error, because AutoGraph conversion is disabled by default for all PennyLane modules. To work around this, we followed the "Adding modules for Autograph conversion" docs and tried the following, which results in a different error:

```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

NUM_QUBITS = 2

dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qjit(autograph=True, autograph_include=["pennylane.templates.subroutines.flip_sign"])
@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.FlipSign(basis_state, wires=wires)
    return qml.state()

basis_state = jnp.array([0.0, 0.0])
state = circuit(basis_state)
```
```
Traceback (most recent call last):
...
  File ".../venv/lib/python3.12/site-packages/catalyst/autograph/ag_primitives.py", line 579, in converted_call
    return ag_converted_call(fn, args, kwargs, caller_fn_scope, options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/malt/impl/api.py", line 380, in converted_call
    result = converted_f(*effective_args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file_1_ucoey.py", line 35, in ag____call__
    ag__.if_stmt(ag__.converted_call(ag__.ld(enabled), (), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File ".../venv/lib/python3.12/site-packages/catalyst/autograph/ag_primitives.py", line 132, in if_stmt
    results = functional_cond()
              ^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/control_flow.py", line 736, in __call__
    return self._call_with_quantum_ctx(ctx)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/control_flow.py", line 662, in _call_with_quantum_ctx
    _assert_cond_result_structure([s.out_tree() for s in out_sigs])
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/control_flow.py", line 1319, in _assert_cond_result_structure
    raise TypeError(
TypeError: Conditional requires a consistent return structure across all branches! Got PyTreeDef((*, CustomNode(FlipSign[(Wires([0, 1]), (('n', (Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=3/1)>, Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=3/1)>)),))], []))) and PyTreeDef((*, *)).
```
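The second error is raised by Catalyst's `_assert_cond_result_structure`, which, like `jax.lax.cond`, requires every branch of a traced conditional to return outputs with the same pytree structure. The two `PyTreeDef`s in the message differ: in one branch the second output is a `FlipSign` pytree node, in the other it is a single leaf. The JAX-only sketch below (function names are hypothetical) reproduces that constraint with `lax.cond` and shows a variant whose branches agree; it illustrates the rule rather than Catalyst's internals.

```python
import jax
import jax.numpy as jnp
from jax import lax

def mismatched(flag, x):
    # Branches return different pytree structures: a 1-tuple vs a 2-tuple
    return lax.cond(flag, lambda v: (v,), lambda v: (v, v), x)

def matched(flag, x):
    # Both branches return a single scalar array of the same shape and dtype
    return lax.cond(flag, lambda v: -v, lambda v: v, x)

x = jnp.array(1.0)

error_type = None
try:
    # Under jit, both branches are traced, so the structure check always runs
    jax.jit(mismatched)(jnp.array(True), x)
except TypeError as err:
    error_type = type(err).__name__
print("mismatched branches:", error_type)

print("matched branches:", float(jax.jit(matched)(jnp.array(True), x)))  # -1.0
```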

Catalyst and/or PennyLane should be updated to support the `qml.FlipSign` operator in QJIT-compiled circuits where the basis-state input to `qml.FlipSign` is an argument of the parameterized circuit.

dime10 commented

Can you re-raise this in the PL repo?

I suspect there might be an equivalent issue/PR/fix on the PL side now; worth double-checking. I saw it being discussed (FlipSign not being JIT-compatible was interfering with program capture).

@joeycarter I remember tagging you in a slack thread about it recently :)

@josh146 Yes, but it was related to turning on Grover's algo in the benchmarks :)

https://xanaduhq.slack.com/archives/C06S7LJV8SX/p1740595287516859

I don't see any issues open or closed in the PennyLane repo about this, but I can quickly test it again and if it doesn't work I can open up an issue there.

I take it back! There's a draft PR here: PennyLaneAI/pennylane#7127