PennyLaneAI/catalyst

[BUG] value_and_grad fails when quantum circuit has multiple inputs.

mehrdad2m opened this issue · 6 comments

import pennylane as qml
from catalyst import qjit, value_and_grad

dev = qml.device("lightning.qubit", wires=1)
@qml.qnode(dev)
def circuit(x, y, z):
    qml.RY(x, wires=0)
    qml.RX(y, wires=0)
    qml.RX(z, wires=0)
    return qml.expval(qml.PauliZ(0))
result_val, result_grad = qjit(value_and_grad(circuit, argnum=[1]))(0.1, 0.2, 0.3)
print(result_val, result_grad)

Expected result:

0.8731983044562817 (Array(-0.47703041, dtype=float64),)
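The expected numbers can be checked by hand. RX(y) followed by RX(z) composes to RX(y + z), so the circuit gives <Z> = cos(x)·cos(y + z), and the derivative with respect to y alone is −cos(x)·sin(y + z). The plain-Python check below evaluates these closed forms at (0.1, 0.2, 0.3). It was derived for this write-up and is not part of the original report, and the helper name is illustrative.

```python
import math


def expected_circuit_value_and_dy(x, y, z):
    # Closed form for RY(x) -> RX(y) -> RX(z) on |0>, measured in PauliZ:
    # the two RX gates compose to RX(y + z), giving <Z> = cos(x) * cos(y + z).
    value = math.cos(x) * math.cos(y + z)
    # Partial derivative with respect to y only (the argnum=[1] case).
    d_dy = -math.cos(x) * math.sin(y + z)
    return value, d_dy


value, d_dy = expected_circuit_value_and_dy(0.1, 0.2, 0.3)
print(value, d_dy)  # approximately 0.8731983 and -0.4770304
```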

Actual:

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/test-vmap.py", line 31, in <module>
    result_val, result_grad = qjit(value_and_grad(circuit, argnum=[1]))(0.1, 0.2, 0.3)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 455, in __call__
    requires_promotion = self.jit_compile(args)
                         ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 526, in jit_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args)
                                                              ^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 606, in capture
    jaxpr, out_type, treedef = trace_to_jaxpr(
                               ^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 531, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py", line 555, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2362, in trace_to_jaxpr_dynamic2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2377, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 604, in fn_with_transform_named_sequence
    return self.user_function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 623, in __call__
    gradients = _unflatten_derivatives(
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 711, in _unflatten_derivatives
    intermediate_results = tree_unflatten(in_tree, intermediate_results)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 132, in tree_unflatten
    return treedef.unflatten(leaves)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Too few leaves for PyTreeDef; expected 3, got 1
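The final ValueError is raised by jax.tree_util.tree_unflatten when it receives fewer leaves than the PyTreeDef describes. One plausible reading of the traceback is that _unflatten_derivatives rebuilds the gradients against the tree of all three inputs even though only one gradient was computed. This is an assumption based only on the traceback, not on Catalyst's source. The standalone JAX snippet below reproduces the same kind of error. The helper reproduce_leaf_mismatch is hypothetical.

```python
import jax


def reproduce_leaf_mismatch(args, num_gradients):
    # Tree structure of *all* positional inputs: (x, y, z) -> 3 leaves.
    _, in_tree = jax.tree_util.tree_flatten(args)
    # Pretend only `num_gradients` gradients were computed (default argnum=0 -> 1).
    gradients = [0.0] * num_gradients
    try:
        jax.tree_util.tree_unflatten(in_tree, gradients)
    except ValueError as err:
        return in_tree.num_leaves, str(err)
    return in_tree.num_leaves, None


num_leaves, message = reproduce_leaf_mismatch((0.1, 0.2, 0.3), num_gradients=1)
print(num_leaves, message)
```

If three gradients are supplied instead, the unflatten succeeds. This matches the idea that the tree and the number of computed gradients simply disagree.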

Nice catch @mehrdad2m!

I can reduce this to a more minimal example:

import jax.numpy as jnp
import catalyst
from catalyst import qjit

@qjit
def g(x, y, z):
    def f(x, y, z):
        return x * y ** 2 * jnp.sin(z)
    return catalyst.value_and_grad(f)(x, y, z)
>>> g(0.1, 0.2, 0.3)
ValueError: Too few leaves for PyTreeDef; expected 3, got 1

This works with JAX:

import jax

@qjit
def g(x, y, z):
    def f(x, y, z):
        return x * y ** 2 * jnp.sin(z)
    return jax.value_and_grad(f)(x, y, z)
>>> g(0.1, 0.2, 0.3)
(Array(0.00118208, dtype=float64), Array(0.01182081, dtype=float64))

While reducing this, I noticed that passing the argnum argument to catalyst.value_and_grad also fails, but with a different error:

@qjit
def g(x, y, z):
    def f(x, y, z):
        return x * y ** 2 * jnp.sin(z)
    return catalyst.value_and_grad(f, argnum=[0, 1, 2])(x, y, z)
>>> g(0.1, 0.2, 0.3)
CompileError: Compilation failed:
g:3:12: error: 'func.return' op has 5 operands, but enclosing function (@f.fullgrad012) returns 4
    %0:4 = "gradient.value_and_grad"(%arg0, %arg1, %arg2) {callee = @f, diffArgIndices = dense<[0, 1, 2]> : tensor<3xi64>, method = "auto"} : (tensor<f64>, tensor<f64>, tensor<f64>) -> (tensor<f64>, tensor<f64>, tensor<f64>, tensor<f64>)
           ^
g:3:12: note: see current operation: "func.return"(%5#0, %5#1, %5#2, %5#3, %5#4) : (tensor<f64>, tensor<f64>, tensor<f64>, tensor<f64>, tensor<f64>) -> ()
While processing 'GradientLoweringPass' pass of the 'QuantumCompilationPass' pipeline
Failed to lower MLIR module

@paul0403 I assume there are tests for value_and_grad with multiple arguments, given that it has an argnum argument. From the examples above, do you have a sense of what the edge case might be?

Hmm, shouldn't this example return 3 gradients, namely the partial derivatives with respect to x, y, and z? At least, I think that's what value_and_grad expects. There are indeed tests that cover multiple differentiable arguments (https://github.com/PennyLaneAI/catalyst/blob/main/frontend/test/pytest/test_gradient.py#L265), but they all follow the logic above. I'm not sure why this example returns only one gradient.
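For f(x, y, z) = x·y²·sin(z), the three partial derivatives are y²·sin(z), 2xy·sin(z), and x·y²·cos(z). The snippet below evaluates them directly at (0.1, 0.2, 0.3). This is a hand-derived check added here, not taken from the thread. It gives the numbers a full three-argument gradient should produce. At this point y² = 2xy = 0.04, so ∂f/∂x and ∂f/∂y coincide. That explains why the first two gradients in the outputs below are equal.

```python
import math


def f(x, y, z):
    return x * y**2 * math.sin(z)


def grad_f(x, y, z):
    # Analytic partial derivatives of f with respect to x, y and z.
    return (
        y**2 * math.sin(z),
        2 * x * y * math.sin(z),
        x * y**2 * math.cos(z),
    )


value = f(0.1, 0.2, 0.3)
grads = grad_f(0.1, 0.2, 0.3)
print(value, grads)  # approximately 0.00118208 and (0.01182081, 0.01182081, 0.00382135)
```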

In fact, the tests pack the multiple arguments into a single array. Writing the example that way, instead of passing separate arguments, avoids the issue:

@qjit
def g(vec):
    def f(v):
        return v[0] * v[1] ** 2 * jnp.sin(v[2])
    return catalyst.value_and_grad(f)(vec)

print(g(jnp.array([0.1, 0.2, 0.3])))

>>>
(Array(0.00118208, dtype=float64), Array([0.01182081, 0.01182081, 0.00382135], dtype=float64))
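Until separate arguments are supported, the array workaround could be wrapped so that callers keep the (x, y, z) signature. The sketch below is a hypothetical helper: value_and_grad_packed is not a Catalyst or JAX API. It defaults to jax.value_and_grad so it runs standalone. Passing catalyst.value_and_grad as vag inside a @qjit function is the intended use, but that path is an untested assumption.

```python
import jax
import jax.numpy as jnp


def value_and_grad_packed(f, vag=jax.value_and_grad):
    """Hypothetical helper: differentiate f(*scalars) through one packed array."""
    def wrapped(*scalars):
        # Pack the separate scalar inputs into a single 1-D array.
        vec = jnp.stack([jnp.asarray(s) for s in scalars])
        # Unpack inside the differentiated function so f's signature is unchanged.
        value, grad_vec = vag(lambda v: f(*v))(vec)
        # Split the gradient array back into one entry per original input.
        return value, tuple(grad_vec[i] for i in range(len(scalars)))
    return wrapped


def f(x, y, z):
    return x * y**2 * jnp.sin(z)


value, grads = value_and_grad_packed(f)(0.1, 0.2, 0.3)
print(value, grads)
```

This design keeps the packing and unpacking outside user code. The cost is that every input must be a scalar of a common dtype so it can be stacked.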

@paul0403 the default should be argnum=0, which is why only one gradient is expected. If you instead specify argnum=[0, 1, 2], you should get all three. Here is the same behavior with jax.value_and_grad, where the keyword is argnums:

@qjit(static_argnums=3)
def g(x, y, z, argnum):
    def f(x, y, z):
        return x * y ** 2 * jnp.sin(z)
    return jax.value_and_grad(f, argnums=argnum)(x, y, z)
>>> g(0.1, 0.2, 0.3, 0)
(Array(0.00118208, dtype=float64), Array(0.01182081, dtype=float64))
>>> g(0.1, 0.2, 0.3, (0, 1, 2))
(Array(0.00118208, dtype=float64),
 (Array(0.01182081, dtype=float64),
  Array(0.01182081, dtype=float64),
  Array(0.00382135, dtype=float64)))
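JAX's output above follows a simple rule. An integer argnums yields a single gradient, and a sequence yields a tuple with one gradient per selected argument. One hypothetical way to rebuild that structure is to unflatten against the tree of the selected arguments only, rather than against all inputs. The function unflatten_selected_grads is an illustrative name and not Catalyst's actual code in _unflatten_derivatives. With argnum=[1], it yields a one-element tuple, matching the expected result in the original report.

```python
import jax


def unflatten_selected_grads(args, argnum, flat_grads):
    """Hypothetical sketch: shape flat gradients like JAX's value_and_grad output."""
    if isinstance(argnum, int):
        # A single integer selects one argument, so the result is not wrapped in a tuple.
        selected = args[argnum]
    else:
        # A sequence selects several arguments, so the result is a tuple of gradients.
        selected = tuple(args[i] for i in argnum)
    _, selected_tree = jax.tree_util.tree_flatten(selected)
    return jax.tree_util.tree_unflatten(selected_tree, flat_grads)


args = (0.1, 0.2, 0.3)
single = unflatten_selected_grads(args, 0, [0.0118])
one_of_list = unflatten_selected_grads(args, [1], [-0.477])
all_three = unflatten_selected_grads(args, (0, 1, 2), [0.0118, 0.0118, 0.0038])
print(single, one_of_list, all_three)
```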

> In fact the tests wrap multiple arguments in an array. Writing that instead of raw multiple arguments fixes the issue

Ah, so it sounds like support for multiple separate inputs in value_and_grad was not implemented?