`qml.Hamiltonian` inside of a `pure_callback` doesn't work
isaacdevlugt opened this issue · 6 comments
Issue description
I'm trying to use `qml.qchem.molecular_hamiltonian` inside a `catalyst.pure_callback`, and it's not working.
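- Expected behavior: I expect that generating a `Hamiltonian` in a callback should be possible, since a `Hamiltonian` is a pytree.
- Actual behavior: It doesn't work; the script fails with a `TypeError` raised by `jax.tree_util.tree_unflatten` (see the traceback below).
- Reproduces how often: 100%
- System information: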
Name: PennyLane
Version: 0.37.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /Users/isaac/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning
Platform info: macOS-14.6.1-arm64-arm-64bit
Python version: 3.11.9
Numpy version: 1.26.4
Scipy version: 1.12.0
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.37.0)
- nvidia.custatevec (PennyLane-Catalyst-0.7.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.7.0)
- oqc.cloud (PennyLane-Catalyst-0.7.0)
- softwareq.qpp (PennyLane-Catalyst-0.7.0)
- default.clifford (PennyLane-0.37.0)
- default.gaussian (PennyLane-0.37.0)
- default.mixed (PennyLane-0.37.0)
- default.qubit (PennyLane-0.37.0)
- default.qubit.autograd (PennyLane-0.37.0)
- default.qubit.jax (PennyLane-0.37.0)
- default.qubit.legacy (PennyLane-0.37.0)
- default.qubit.tf (PennyLane-0.37.0)
- default.qubit.torch (PennyLane-0.37.0)
- default.qutrit (PennyLane-0.37.0)
- default.qutrit.mixed (PennyLane-0.37.0)
- default.tensor (PennyLane-0.37.0)
- null.qubit (PennyLane-0.37.0)
Source code and tracebacks
```python
import jax
import jax.numpy as jnp
import catalyst
from pennylane import qchem

coordinates = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.1]])
symbols = ['H', 'H']
# Construct the Molecule object
molecule = qchem.Molecule(symbols, coordinates)
H, qubits = qchem.molecular_hamiltonian(molecule, method='openfermion')
data, shape = jax.tree_util.tree_flatten(H)
abstract = jax._src.api_util.shaped_abstractify(jnp.array(data))
H_abstract = jax.tree_util.tree_unflatten(shape, abstract)

@catalyst.pure_callback
def get_hamiltonian(coords, molecule) -> (H_abstract, int):
    H, qubits = qchem.molecular_hamiltonian(molecule, method='openfermion')  # can't be jit'd because of deep numpy calls
    return H, qubits
```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[11], line 11
8 data, shape = jax.tree_util.tree_flatten(H)
10 abstract = jax._src.api_util.shaped_abstractify(jnp.array(data))
---> 11 H_abstract = jax.tree_util.tree_unflatten(shape, abstract)
13 @catalyst.pure_callback
14 def get_hamiltonian(coords, molecule) -> (H_abstract, int):
15 H, qubits = qchem.molecular_hamiltonian(molecule, method='openfermion') # can't be jit'd because of deep numpy calls
File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/tree_util.py:100, in tree_unflatten(treedef, leaves)
86 def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
87 """Reconstructs a pytree from the treedef and the leaves.
88
89 The inverse of :func:`tree_flatten`.
(...)
98 described by ``treedef``.
99 """
--> 100 return treedef.unflatten(leaves)
TypeError: unflatten(): incompatible function arguments. The following argument types are supported:
1. (self: jaxlib.xla_extension.pytree.PyTreeDef, arg0: Iterable) -> object
Invoked with: PyTreeDef(CustomNode(Sum[(None,)], [CustomNode(SProd[()], [*, CustomNode(Identity[(<Wires = [0]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(PauliZ[(<Wires = [0]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(PauliZ[(<Wires = [1]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(PauliZ[(<Wires = [2]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(PauliZ[(<Wires = [3]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [0]>, ())], []), CustomNode(PauliZ[(<Wires = [1]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliY[(<Wires = [0]>, ())], []), CustomNode(PauliX[(<Wires = [1]>, ())], []), CustomNode(PauliX[(<Wires = [2]>, ())], []), CustomNode(PauliY[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliY[(<Wires = [0]>, ())], []), CustomNode(PauliY[(<Wires = [1]>, ())], []), CustomNode(PauliX[(<Wires = [2]>, ())], []), CustomNode(PauliX[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliX[(<Wires = [0]>, ())], []), CustomNode(PauliX[(<Wires = [1]>, ())], []), CustomNode(PauliY[(<Wires = [2]>, ())], []), CustomNode(PauliY[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliX[(<Wires = [0]>, ())], []), CustomNode(PauliY[(<Wires = [1]>, ())], []), CustomNode(PauliY[(<Wires = [2]>, ())], []), CustomNode(PauliX[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [0]>, ())], []), CustomNode(PauliZ[(<Wires = [2]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [0]>, ())], []), CustomNode(PauliZ[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [1]>, ())], []), CustomNode(PauliZ[(<Wires = [2]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [1]>, 
())], []), CustomNode(PauliZ[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [2]>, ())], []), CustomNode(PauliZ[(<Wires = [3]>, ())], [])])])])), ShapedArray(float64[15])
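The `TypeError` is raised by `tree_unflatten` before the callback is even defined: `abstract` is a single `ShapedArray(float64[15])`, whereas `treedef.unflatten` expects an iterable holding one leaf per coefficient. A minimal sketch of building the abstract pytree leaf by leaf follows. It assumes a hypothetical `ToyHamiltonian` class registered as a pytree (so it runs without PennyLane) and uses the public `jax.ShapeDtypeStruct` instead of the private `shaped_abstractify`. With the real `H`, the analogous call would be `jax.tree_util.tree_map(abstractify, H)`, but whether PennyLane's operator constructors accept `ShapeDtypeStruct` coefficients when unflattening is not verified here.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_map, tree_unflatten


@register_pytree_node_class
class ToyHamiltonian:
    """Hypothetical stand-in for an operator pytree: coefficients are leaves,
    Pauli-term labels are static auxiliary data."""

    def __init__(self, coeffs, terms):
        self.coeffs = tuple(coeffs)
        self.terms = tuple(terms)

    def tree_flatten(self):
        return self.coeffs, self.terms

    @classmethod
    def tree_unflatten(cls, terms, coeffs):
        return cls(coeffs, terms)


def abstractify(x):
    # One abstract value (shape + dtype, no data) for a single leaf
    return jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x))


H = ToyHamiltonian([jnp.array(0.5), jnp.array(-0.25), jnp.array(0.1)], ["I0", "Z0", "Z0 Z1"])
leaves, treedef = tree_flatten(H)

# Mirrors the failing call in the report (one abstract array standing in for
# all 3 leaves); avoid this:
#   tree_unflatten(treedef, abstractify(jnp.array(leaves)))

# Option 1: abstractify each leaf, then rebuild the pytree from the list
H_abstract = tree_unflatten(treedef, [abstractify(x) for x in leaves])

# Option 2: the same result in a single tree_map call
H_abstract_2 = tree_map(abstractify, H)
```

Both options keep the tree structure (the term labels) and replace each coefficient with its own shape/dtype description, which is the form a pytree-valued result type needs.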
Your traceback seems to indicate that the error occurs outside of the callback.
@dime10 I told @isaacdevlugt to create an issue so that I can narrow it down for him. This is not urgent.
Thanks @dime10! Yes, there's a bit of context missing.
Note that the following does work (assuming you know in advance how many terms the Hamiltonian has):
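```python
import jax
import pennylane as qml
import catalyst

# H is the Hamiltonian from the snippet above; it flattens to 15 coefficient leaves
data, shape = jax.tree_util.tree_flatten(H)

def get_hamiltonian(coordinates):
    molecule = qml.qchem.Molecule(["H", "H"], coordinates)
    H, qubits = qml.qchem.molecular_hamiltonian(molecule)
    return H

@qml.qjit
def f(coordinates):
    # One scalar result type per coefficient leaf of H
    return catalyst.pure_callback(get_hamiltonian, result_type=[jax.ShapeDtypeStruct([], dtype=float)] * 15)(coordinates)
```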
```pycon
>>> f(coordinates)
[Array(9.7983828, dtype=float64),
Array(0.3060033, dtype=float64),
Array(0.3060033, dtype=float64),
Array(0.19345955, dtype=float64),
Array(-0.74817268, dtype=float64),
Array(0.15271652, dtype=float64),
Array(0.19164807, dtype=float64),
Array(0.03893155, dtype=float64),
Array(-0.03893155, dtype=float64),
Array(-0.03893155, dtype=float64),
Array(0.03893155, dtype=float64),
Array(-0.74817268, dtype=float64),
Array(0.19164807, dtype=float64),
Array(0.15271652, dtype=float64),
Array(0.20376722, dtype=float64)]
```
However, attempting to include the Hamiltonian's pytree structure in `result_type` leads to a failure in Catalyst's validation.
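One possible way around the validation failure is to keep the callback's result type flat, as in the working example, and rebuild the pytree on the traced side from a tree structure recorded ahead of time. The sketch below assumes that the number and order of terms do not depend on the input (the same assumption the hard-coded `* 15` makes). It uses plain `jax.pure_callback` under `jax.jit` rather than `catalyst.pure_callback` under `qml.qjit`, and the hypothetical `ToyHamiltonian` and `build_hamiltonian_host` stand in for the PennyLane objects; whether Catalyst accepts the same pattern is untested.

```python
import jax
import numpy as np
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


@register_pytree_node_class
class ToyHamiltonian:
    """Hypothetical stand-in for an operator pytree (coefficients are leaves)."""

    def __init__(self, coeffs, terms):
        self.coeffs = tuple(coeffs)
        self.terms = tuple(terms)

    def tree_flatten(self):
        return self.coeffs, self.terms

    @classmethod
    def tree_unflatten(cls, terms, coeffs):
        return cls(coeffs, terms)


def build_hamiltonian_host(bond_length):
    """Hypothetical host-only builder (plain NumPy), standing in for a
    non-jittable call such as molecular_hamiltonian."""
    r = float(bond_length)
    return ToyHamiltonian([np.float32(1.0 / r), np.float32(-0.5 * r)], ["I0", "Z0"])


# Build once eagerly to record the tree structure and the per-leaf types.
# Assumption: the number and order of terms do not depend on the input.
_leaves, TREEDEF = tree_flatten(build_hamiltonian_host(1.0))
LEAF_TYPES = [jax.ShapeDtypeStruct(np.shape(x), np.result_type(x)) for x in _leaves]


def host_leaves(bond_length):
    # Runs on the host: build the pytree and hand back only its flat leaves,
    # cast to exactly the dtypes declared in LEAF_TYPES
    leaves, _ = tree_flatten(build_hamiltonian_host(bond_length))
    return [np.asarray(x, dtype=t.dtype) for x, t in zip(leaves, LEAF_TYPES)]


@jax.jit
def hamiltonian(bond_length):
    # The callback's declared result type is a flat list of scalars...
    leaves = jax.pure_callback(host_leaves, LEAF_TYPES, bond_length)
    # ...and the pytree is rebuilt on the traced side from the recorded treedef
    return tree_unflatten(TREEDEF, leaves)
```

Calling `hamiltonian(2.0)` returns a `ToyHamiltonian` whose coefficients are JAX arrays (0.5 and -1.0), while the term labels travel as static auxiliary data and never cross the callback boundary.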
@isaacdevlugt something to note here: the underlying solution is to ensure that `qml.qchem.Molecule` and `qml.qchem.molecular_hamiltonian` are `jax.jit`-compatible end-to-end; there is no need for them to convert JAX arrays to NumPy arrays. That would remove the need for a callback altogether.
@josh146 Agreed that `Molecule` and `molecular_hamiltonian` could be made jit-compatible!
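To illustrate what "jit-compatible end-to-end" means in practice, here is a generic sketch that is not PennyLane's code: the nuclear-repulsion energy, a simple coordinate-dependent quantity from molecular calculations, computed two ways. The `jax.numpy` version can be traced by `jax.jit` and differentiated with `jax.grad`; the version that first converts its inputs with `np.asarray` works eagerly but fails under `jax.jit` with `jax.errors.TracerArrayConversionError`, which is the kind of JAX-to-NumPy conversion the comment above suggests removing. Both function names are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def nuclear_repulsion(charges, coordinates):
    """Sum of Z_i * Z_j / |R_i - R_j| over atom pairs i < j (atomic units).

    Uses only jax.numpy, so it can be traced by jax.jit and differentiated.
    """
    n = coordinates.shape[0]  # shapes are static under jit, so this is a Python int
    i, j = jnp.triu_indices(n, k=1)
    distances = jnp.linalg.norm(coordinates[i] - coordinates[j], axis=-1)
    return jnp.sum(charges[i] * charges[j] / distances)


def nuclear_repulsion_numpy(charges, coordinates):
    # Same formula, but converts the inputs to NumPy first. This works eagerly,
    # but under jax.jit the inputs are tracers without concrete values, so
    # np.asarray raises jax.errors.TracerArrayConversionError.
    charges = np.asarray(charges)
    coordinates = np.asarray(coordinates)
    i, j = np.triu_indices(coordinates.shape[0], k=1)
    distances = np.linalg.norm(coordinates[i] - coordinates[j], axis=-1)
    return float(np.sum(charges[i] * charges[j] / distances))
```

For H2 with the nuclei 1.4 bohr apart, `jax.jit(nuclear_repulsion)` returns 1/1.4 (about 0.714 hartree), and `jax.grad(nuclear_repulsion, argnums=1)` gives forces along the bond axis only, with nonzero z-components of magnitude 1/1.4².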