[BUG] value_and_grad fails when quantum circuit has multiple inputs.
mehrdad2m opened this issue · 6 comments
import pennylane as qml
from catalyst import qjit, value_and_grad

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(x, y, z):
    qml.RY(x, wires=0)
    qml.RX(y, wires=0)
    qml.RX(z, wires=0)
    return qml.expval(qml.PauliZ(0))

result_val, result_grad = qjit(value_and_grad(circuit, argnum=[1]))(0.1, 0.2, 0.3)
print(result_val, result_grad)
Expected result:
0.8731983044562817 (Array(-0.47703041, dtype=float64),)
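For context, the expected numbers can be checked by hand: for this circuit, RY(x) followed by RX(y) and RX(z) on |0⟩ gives the analytic expectation ⟨Z⟩ = cos(x)·cos(y + z), so the partial derivative with respect to y is -cos(x)·sin(y + z). A quick sketch (plain NumPy, independent of PennyLane):

```python
import numpy as np

# Analytic expectation of PauliZ after RY(x); RX(y); RX(z) applied to |0>
x, y, z = 0.1, 0.2, 0.3
val = np.cos(x) * np.cos(y + z)        # expected value, ~0.8731983
dval_dy = -np.cos(x) * np.sin(y + z)   # d<Z>/dy, ~-0.4770304
print(val, dval_dy)
```

This matches the expected output above, confirming the value and the argnum=[1] gradient.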
Actual result:
Traceback (most recent call last):
File "/Users/mehrdad.malek/tmp/test-vmap.py", line 31, in <module>
result_val, result_grad = qjit(value_and_grad(circuit, argnum=[1]))(0.1, 0.2, 0.3)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 455, in __call__
requires_promotion = self.jit_compile(args)
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 526, in jit_compile
self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args)
^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 606, in capture
jaxpr, out_type, treedef = trace_to_jaxpr(
^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 531, in trace_to_jaxpr
jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py", line 555, in make_jaxpr_f
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 335, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2362, in trace_to_jaxpr_dynamic2
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2377, in trace_to_subjaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers_)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 604, in fn_with_transform_named_sequence
return self.user_function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 623, in __call__
gradients = _unflatten_derivatives(
^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 711, in _unflatten_derivatives
intermediate_results = tree_unflatten(in_tree, intermediate_results)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 132, in tree_unflatten
return treedef.unflatten(leaves)
^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Too few leaves for PyTreeDef; expected 3, got 1
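The final error can be reproduced directly with JAX's pytree utilities, which suggests what is going wrong: the input tree was flattened from three arguments, but only one derivative leaf is handed back for unflattening. A minimal sketch of the mismatch (not Catalyst's actual internals):

```python
import jax.tree_util as jtu

# Flattening the three scalar arguments records a treedef expecting 3 leaves
leaves, treedef = jtu.tree_flatten((0.1, 0.2, 0.3))
assert len(leaves) == 3

# Unflattening with a single leaf reproduces the error from the traceback
try:
    jtu.tree_unflatten(treedef, [0.5])
except ValueError as e:
    print(e)  # complains about too few leaves for the PyTreeDef
```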
Nice catch @mehrdad2m!
I can reduce this to a more minimal example:
@qjit
def g(x, y, z):
    def f(x, y, z):
        return x * y ** 2 * jnp.sin(z)
    return catalyst.value_and_grad(f)(x, y, z)
>>> g(0.1, 0.2, 0.3)
ValueError: Too few leaves for PyTreeDef; expected 3, got 1
This works with JAX:
@qjit
def g(x, y, z):
    def f(x, y, z):
        return x * y ** 2 * jnp.sin(z)
    return jax.value_and_grad(f)(x, y, z)
>>> g(0.1, 0.2, 0.3)
(Array(0.00118208, dtype=float64), Array(0.01182081, dtype=float64))
As part of this, I noticed that the argnum argument of catalyst.value_and_grad also fails, but with a different error:
@qjit
def g(x, y, z):
    def f(x, y, z):
        return x * y ** 2 * jnp.sin(z)
    return catalyst.value_and_grad(f, argnum=[0, 1, 2])(x, y, z)
>>> g(0.1, 0.2, 0.3)
CompileError: Compilation failed:
g:3:12: error: 'func.return' op has 5 operands, but enclosing function (@f.fullgrad012) returns 4
%0:4 = "gradient.value_and_grad"(%arg0, %arg1, %arg2) {callee = @f, diffArgIndices = dense<[0, 1, 2]> : tensor<3xi64>, method = "auto"} : (tensor<f64>, tensor<f64>, tensor<f64>) -> (tensor<f64>, tensor<f64>, tensor<f64>, tensor<f64>)
^
g:3:12: note: see current operation: "func.return"(%5#0, %5#1, %5#2, %5#3, %5#4) : (tensor<f64>, tensor<f64>, tensor<f64>, tensor<f64>, tensor<f64>) -> ()
While processing 'GradientLoweringPass' pass of the 'QuantumCompilationPass' pipeline
Failed to lower MLIR module
@paul0403 I assume there are tests for value_and_grad with multiple arguments (given the presence of the argnum argument). From looking at the examples above, do you have a sense of what the edge case might be?
Hmm, shouldn't this example return 3 gradients, one partial derivative each for x, y, and z? At least, I think that's what value_and_grad expects. There are indeed tests that cover multiple differentiable arguments (https://github.com/PennyLaneAI/catalyst/blob/main/frontend/test/pytest/test_gradient.py#L265), but they all follow the above logic. I'm not sure why this example returns just one gradient.
In fact, the tests wrap the multiple arguments in a single array. Writing the example that way, instead of passing multiple raw arguments, avoids the issue:
@qjit
def g(vec):
    def f(v):
        return v[0] * v[1] ** 2 * jnp.sin(v[2])
    return catalyst.value_and_grad(f)(vec)

print(g(jnp.array([0.1, 0.2, 0.3])))
>>>
(Array(0.00118208, dtype=float64), Array([0.01182081, 0.01182081, 0.00382135], dtype=float64))
@paul0403 the default should be argnum=0, hence why only one gradient is returned. If you instead specify argnum=[0, 1, 2], you should get all three:
@qjit(static_argnums=3)
def g(x, y, z, argnum):
    def f(x, y, z):
        return x * y ** 2 * jnp.sin(z)
    return jax.value_and_grad(f, argnums=argnum)(x, y, z)
>>> g(0.1, 0.2, 0.3, 0)
(Array(0.00118208, dtype=float64), Array(0.01182081, dtype=float64))
>>> g(0.1, 0.2, 0.3, (0, 1, 2))
(Array(0.00118208, dtype=float64),
(Array(0.01182081, dtype=float64),
Array(0.01182081, dtype=float64),
Array(0.00382135, dtype=float64)))
In fact, the tests wrap the multiple arguments in a single array. Writing it that way instead of passing multiple raw arguments avoids the issue

Ah, so it sounds like multiple inputs for value_and_grad were never implemented?