PennyLaneAI/catalyst

`qml.Hamiltonian` inside of a `pure_callback` doesn't work

isaacdevlugt opened this issue · 6 comments

Issue description

I'm trying to use `qml.qchem.molecular_hamiltonian` inside of a `catalyst.pure_callback`, and it's not working.

  • Expected behavior: I expect that generating a Hamiltonian in a callback should be possible since a Hamiltonian is a pytree.

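  • Actual behavior: It doesn't work. Building the abstract Hamiltonian for the callback's return annotation with `jax.tree_util.tree_unflatten` raises a `TypeError` (traceback below).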

  • Reproduces how often: 100%

  • System information:

Name: PennyLane
Version: 0.37.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/isaac/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info:           macOS-14.6.1-arm64-arm-64bit
Python version:          3.11.9
Numpy version:           1.26.4
Scipy version:           1.12.0
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.37.0)
- nvidia.custatevec (PennyLane-Catalyst-0.7.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.7.0)
- oqc.cloud (PennyLane-Catalyst-0.7.0)
- softwareq.qpp (PennyLane-Catalyst-0.7.0)
- default.clifford (PennyLane-0.37.0)
- default.gaussian (PennyLane-0.37.0)
- default.mixed (PennyLane-0.37.0)
- default.qubit (PennyLane-0.37.0)
- default.qubit.autograd (PennyLane-0.37.0)
- default.qubit.jax (PennyLane-0.37.0)
- default.qubit.legacy (PennyLane-0.37.0)
- default.qubit.tf (PennyLane-0.37.0)
- default.qubit.torch (PennyLane-0.37.0)
- default.qutrit (PennyLane-0.37.0)
- default.qutrit.mixed (PennyLane-0.37.0)
- default.tensor (PennyLane-0.37.0)
- null.qubit (PennyLane-0.37.0)

Source code and tracebacks

coordinates = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.1]])
symbols = ['H', 'H']

# Construct the Molecule object
molecule = qchem.Molecule(symbols, coordinates)

H, qubits = qchem.molecular_hamiltonian(molecule, method='openfermion') 
data, shape = jax.tree_util.tree_flatten(H)

abstract = jax._src.api_util.shaped_abstractify(jnp.array(data))
H_abstract = jax.tree_util.tree_unflatten(shape, abstract)

@catalyst.pure_callback
def get_hamiltonian(coords, molecule) -> (H_abstract, int):
    H, qubits = qchem.molecular_hamiltonian(molecule, method='openfermion') # can't be jit'd because of deep numpy calls
    return H, qubits
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[11], line 11
      8 data, shape = jax.tree_util.tree_flatten(H)
     10 abstract = jax._src.api_util.shaped_abstractify(jnp.array(data))
---> 11 H_abstract = jax.tree_util.tree_unflatten(shape, abstract)
     13 @catalyst.pure_callback
     14 def get_hamiltonian(coords, molecule) -> (H_abstract, int):
     15     H, qubits = qchem.molecular_hamiltonian(molecule, method='openfermion') # can't be jit'd because of deep numpy calls

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/tree_util.py:100, in tree_unflatten(treedef, leaves)
     86 def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
     87   """Reconstructs a pytree from the treedef and the leaves.
     88 
     89   The inverse of :func:`tree_flatten`.
   (...)
     98     described by ``treedef``.
     99   """
--> 100   return treedef.unflatten(leaves)

TypeError: unflatten(): incompatible function arguments. The following argument types are supported:
    1. (self: jaxlib.xla_extension.pytree.PyTreeDef, arg0: Iterable) -> object

Invoked with: PyTreeDef(CustomNode(Sum[(None,)], [CustomNode(SProd[()], [*, CustomNode(Identity[(<Wires = [0]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(PauliZ[(<Wires = [0]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(PauliZ[(<Wires = [1]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(PauliZ[(<Wires = [2]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(PauliZ[(<Wires = [3]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [0]>, ())], []), CustomNode(PauliZ[(<Wires = [1]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliY[(<Wires = [0]>, ())], []), CustomNode(PauliX[(<Wires = [1]>, ())], []), CustomNode(PauliX[(<Wires = [2]>, ())], []), CustomNode(PauliY[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliY[(<Wires = [0]>, ())], []), CustomNode(PauliY[(<Wires = [1]>, ())], []), CustomNode(PauliX[(<Wires = [2]>, ())], []), CustomNode(PauliX[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliX[(<Wires = [0]>, ())], []), CustomNode(PauliX[(<Wires = [1]>, ())], []), CustomNode(PauliY[(<Wires = [2]>, ())], []), CustomNode(PauliY[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliX[(<Wires = [0]>, ())], []), CustomNode(PauliY[(<Wires = [1]>, ())], []), CustomNode(PauliY[(<Wires = [2]>, ())], []), CustomNode(PauliX[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [0]>, ())], []), CustomNode(PauliZ[(<Wires = [2]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [0]>, ())], []), CustomNode(PauliZ[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [1]>, ())], []), CustomNode(PauliZ[(<Wires = [2]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [1]>, 
())], []), CustomNode(PauliZ[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [2]>, ())], []), CustomNode(PauliZ[(<Wires = [3]>, ())], [])])])])), ShapedArray(float64[15])


Your traceback seems to indicate that the error occurs outside of the callback, at the `jax.tree_util.tree_unflatten` call.
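
One possible reading of the traceback, offered as an assumption rather than a confirmed diagnosis: `tree_unflatten` expects an iterable with one entry per leaf slot in the treedef, while the snippet passes a single `ShapedArray(float64[15])`. Below is a minimal sketch of the per-leaf pattern. The dict is a hypothetical stand-in for the Hamiltonian, and the public `jax.ShapeDtypeStruct` is swapped in for the private `shaped_abstractify`. Whether PennyLane operators accept abstract leaves when unflattened is not established in this thread.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the Hamiltonian: any pytree with array leaves.
tree = {"a": jnp.array(1.0), "b": jnp.array([2.0, 3.0])}

leaves, treedef = jax.tree_util.tree_flatten(tree)

# Build one abstract value per leaf, rather than a single array
# that covers all leaves at once.
abstract_leaves = [jax.ShapeDtypeStruct(x.shape, x.dtype) for x in leaves]

# tree_unflatten takes the treedef and an iterable of leaves.
abstract_tree = jax.tree_util.tree_unflatten(treedef, abstract_leaves)
```

The number of entries in `abstract_leaves` has to match `treedef.num_leaves`.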

@dime10 I told @isaacdevlugt to create an issue so that I can narrow it down for him. This is not urgent.

Thanks @dime10! Yes there's a bit of context missing 😅

Note that the following does work (assuming you know in advance how many terms the Hamiltonian has):

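import jax
import catalyst
import pennylane as qml

# H and coordinates are the objects defined in the issue description above
data, shape = jax.tree_util.tree_flatten(H)

def get_hamiltonian(coordinates):
    # Runs on the host, outside of the compiled program
    molecule = qml.qchem.Molecule(["H", "H"], coordinates)
    H, qubits = qml.qchem.molecular_hamiltonian(molecule)
    return H

@qml.qjit
def f(coordinates):
    # One scalar result type per Hamiltonian coefficient (15 terms here)
    return catalyst.pure_callback(get_hamiltonian, result_type=[jax.ShapeDtypeStruct([], dtype=float)] * 15)(coordinates)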
>>> f(coordinates)
[Array(9.7983828, dtype=float64),
 Array(0.3060033, dtype=float64),
 Array(0.3060033, dtype=float64),
 Array(0.19345955, dtype=float64),
 Array(-0.74817268, dtype=float64),
 Array(0.15271652, dtype=float64),
 Array(0.19164807, dtype=float64),
 Array(0.03893155, dtype=float64),
 Array(-0.03893155, dtype=float64),
 Array(-0.03893155, dtype=float64),
 Array(0.03893155, dtype=float64),
 Array(-0.74817268, dtype=float64),
 Array(0.19164807, dtype=float64),
 Array(0.15271652, dtype=float64),
 Array(0.20376722, dtype=float64)]

However, attempting to include the pytree struct of the Hamiltonian in the result_type leads to a failure in Catalyst validation.
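
A possible way around this, sketched here with plain `jax.pure_callback` rather than Catalyst, is to have the callback return only flat leaves and to rebuild the pytree afterwards from a treedef computed ahead of time. Everything in the sketch is an assumption made for illustration: the dict `template` stands in for the Hamiltonian, and `host_coefficients` is a hypothetical host function. Whether the same pattern passes Catalyst validation under `qml.qjit` is not confirmed by this thread.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for the Hamiltonian: a pytree with three scalar leaves.
template = {"identity": 0.0, "z0": 0.0, "z1": 0.0}
leaves, treedef = jax.tree_util.tree_flatten(template)

def host_coefficients(coordinates):
    # Runs on the host with plain NumPy; returns one scalar per leaf,
    # in the flattening order of `template` (sorted dict keys).
    coords = np.asarray(coordinates)
    d = np.linalg.norm(coords[1] - coords[0])
    return [np.float32(d), np.float32(2 * d), np.float32(-d)]

# One scalar result type per leaf, as in the working example above.
result_types = [jax.ShapeDtypeStruct((), jnp.float32)] * len(leaves)

@jax.jit
def f(coordinates):
    flat = jax.pure_callback(host_coefficients, result_types, coordinates)
    # Rebuild the structured object from the flat callback results.
    return jax.tree_util.tree_unflatten(treedef, flat)

out = f(jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.1]], dtype=jnp.float32))
```

As in the working example, this still requires the number of leaves to be known before compilation.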

@isaacdevlugt something to note here: the underlying solution is to make qml.qchem.Molecule and qml.qchem.molecular_hamiltonian jax.jit compatible end-to-end; there is no need for them to convert JAX arrays to NumPy arrays. This should remove the need for a callback altogether.

@josh146 agreed that molecule and molecular_hamiltonian could be jit-compatible!
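
To illustrate the jit-compatibility point in general terms, here is a small sketch of why a NumPy conversion inside a function blocks `jax.jit`, while the `jax.numpy` equivalent traces fine. The two `bond_length_*` functions are hypothetical and are not taken from the PennyLane source.

```python
import jax
import jax.numpy as jnp
import numpy as np

def bond_length_numpy(coords):
    # Converting a traced value to a NumPy array is not possible under jit.
    coords = np.asarray(coords)
    return np.linalg.norm(coords[1] - coords[0])

def bond_length_jax(coords):
    # Staying in jax.numpy keeps the function traceable.
    return jnp.linalg.norm(coords[1] - coords[0])

coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.1]])

jax_result = jax.jit(bond_length_jax)(coords)

try:
    jax.jit(bond_length_numpy)(coords)
    numpy_version_traced = True
except jax.errors.TracerArrayConversionError:
    numpy_version_traced = False
```

Outside of `jax.jit` both functions run; only the traced call to the NumPy version fails.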