[BUG] Verification is not allowing differentiation of non-differentiable gates
josh146 opened this issue · 8 comments
The new gradient verification is overzealous: it refuses to differentiate a circuit even when the operation that does not support differentiation is not itself being differentiated:
import pennylane as qml
import numpy as np
from jax import numpy as jnp

dev = qml.device("lightning.qubit", wires=6)

@qml.qjit
@qml.qnode(dev)
def cost(params):
    qml.BasisState(np.array([1, 1, 0, 0, 0, 0]), wires=range(6))
    qml.DoubleExcitation(params[0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])
    return qml.expval(qml.PauliZ(0))

@qml.qjit
def dcost(params):
    return qml.grad(cost)(params)
>>> x = jnp.array([0.2, 0.7])
>>> print("Cost:", cost(x))
Cost: -0.7472525151031233
>>> print("Gradient:", dcost(x))
File ~/miniconda3/lib/python3.10/site-packages/catalyst/device/verification.py:234, in verify_operations.<locals>._op_checker(op, state)
    232 _mcm_op_checker(op)
    233 if grad_method == "adjoint":
--> 234     _adj_diff_op_checker(op)
    235 elif grad_method == "parameter-shift":
    236     _paramshift_op_checker(op)
File ~/miniconda3/lib/python3.10/site-packages/catalyst/device/verification.py:152, in verify_operations.<locals>._adj_diff_op_checker(op)
    148 op_name = op.name
    149 if not qjit_device.qjit_capabilities.native_ops.get(
    150     op_name, EMPTY_PROPERTIES
    151 ).differentiable:
--> 152     raise DifferentiableCompileError(
    153         f"{op.name} is non-differentiable on '{qjit_device.original_device.name}' device"
    154     )
DifferentiableCompileError: BasisState is non-differentiable on 'lightning.qubit' device
In this case, the circuit fails verification even though BasisState is not being differentiated.
Previously, this example would work fine, since BasisState was always decomposed down to non-parametrized gates (qml.X).
Note that this is currently affecting our VQE + catalyst demos, and they are no longer executable. A temporary workaround I can do is:
qml.BasisState.compute_decomposition(np.array([1, 1, 0, 0, 0, 0]), wires=range(6))
but this is not ideal.
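For intuition on why the decomposition sidesteps the check, here is a minimal pure-Python sketch of what decomposing a basis state into X gates amounts to. It does not call PennyLane; the function name basis_state_to_x_gates and the tuple representation of gates are assumptions made only for illustration.

```python
def basis_state_to_x_gates(bits, wires):
    """Toy stand-in for the decomposition: one X gate per wire whose bit is 1.

    Hypothetical helper, not part of PennyLane or Catalyst.
    """
    wires = list(wires)
    if len(bits) != len(wires):
        raise ValueError("need exactly one bit per wire")
    # No gate here carries a trainable parameter, so there is nothing to differentiate.
    return [("PauliX", wire) for bit, wire in zip(bits, wires) if bit == 1]


ops = basis_state_to_x_gates([1, 1, 0, 0, 0, 0], range(6))
print(ops)  # [('PauliX', 0), ('PauliX', 1)]
```

Since the resulting gates have no parameters, a per-operation differentiability check has nothing to object to, which matches the behaviour described above.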
Note that if I set diff_method="parameter-shift", I get a compilation error:
>>> dcost(x)
dcost:13:3: error: 'func.func' op cloned during the gradient pass is not free of quantum ops:
"func.func"() <{function_type = (tensor<6xi64>, tensor<2xf64>, index) -> tensor<?xf64>, sym_name = "cost.qgrad", sym_visibility = "private"}> ({
^bb0(%arg0: tensor<6xi64>, %arg1: tensor<2xf64>, %arg2: index):
%0 = "arith.constant"() <{value = sparse<15, -1.5707963267948966> : tensor<16xf64>}> : () -> tensor<16xf64>
...
@erick-xanadu I think this is related to our interface discussion on your PR: since the new gates don't implement one of the quantum gate interfaces, the gradient passes would need to remove them explicitly.
Should be an easy fix.
The main problem mentioned here might be difficult to solve quickly: we need a way to track which gate parameters came from differentiated function arguments.
A workaround would be to fully decompose gates that are not supported for differentiation, which we want anyway. I believe that should be fairly quick to implement. While this is inefficient in some cases (unnecessary decomposition), it does match the previous behaviour for StatePrep.
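To make the tracking idea concrete, the following is a rough pure-Python sketch of a check that only rejects a non-differentiable operation when one of its parameters derives from a differentiated argument. Everything here (Param, Op, the DIFFERENTIABLE table, verify) is hypothetical and is not Catalyst's actual verification code; how the "traced" flag would really be obtained is exactly the open problem.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Param:
    value: object
    traced: bool = False  # assumed flag: True if derived from a differentiated argument


@dataclass(frozen=True)
class Op:
    name: str
    params: tuple = ()


# Hypothetical capability table; a real one would come from the device.
DIFFERENTIABLE = {"DoubleExcitation": True, "BasisState": False}


class NonDifferentiableError(Exception):
    pass


def verify(ops):
    """Reject an op only if it must be differentiated and cannot be."""
    for op in ops:
        needs_grad = any(p.traced for p in op.params)
        if needs_grad and not DIFFERENTIABLE.get(op.name, False):
            raise NonDifferentiableError(f"{op.name} is non-differentiable")


circuit = [
    Op("BasisState", (Param((1, 1, 0, 0, 0, 0)),)),      # constant input
    Op("DoubleExcitation", (Param(0.2, traced=True),)),  # depends on params
]
verify(circuit)  # passes: BasisState is present but not differentiated
```

Under these assumptions the circuit from the report would pass, while a BasisState whose input came from the differentiated arguments would still be rejected.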
@erick-xanadu I think this is related to our interface discussion, since the new gates aren't implementing the interface the gradient passes would need to remove them explicitly.
I agree. However, I don't see how we got here because verification should have caught this error similar to above.
I don't know if verification for parameter-shift has been implemented yet.
A naive verification for parameter-shift is implemented, just confirming that op.grad_method in {"A", None} for all the operations. The more thorough verification that was discussed didn't make it in yet.
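As a sketch of what such a naive check amounts to, the snippet below uses mock operations rather than real PennyLane objects. The helper name paramshift_offenders and the mock ops are assumptions for illustration, and the grad_method values given to the mocks are my understanding of what these operations report, not something confirmed in this thread.

```python
from types import SimpleNamespace

ALLOWED_GRAD_METHODS = {"A", None}


def paramshift_offenders(ops):
    """Return names of ops whose grad_method is outside the allowed set."""
    return [op.name for op in ops if op.grad_method not in ALLOWED_GRAD_METHODS]


# Mock operations; the attribute values are assumed for illustration only.
ops = [
    SimpleNamespace(name="BasisState", grad_method=None),
    SimpleNamespace(name="DoubleExcitation", grad_method="A"),
]
print(paramshift_offenders(ops))  # []
```

If BasisState does report a grad_method of None, a check of this shape lets it through, which would be consistent with the parameter-shift failure surfacing only later in the gradient pass.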
If the parameter-shift bug is an easy fix, should I split this into its own issue separate from the verification discussion?
(we can treat this now as two separate bugs)