[Bug] New interaction issue with tree flattening / Equinox functionality
adam-hartshorne opened this issue
It appears that changes to the code over the past couple of days have broken compatibility with pytree flattening when a model containing CoLA operators is passed through `equinox.filter`:
```
    opt_state = opt_init(eqx.filter(model, eqx.is_inexact_array))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_filters.py", line 130, in filter
    return jtu.tree_map(
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 251, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 251, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Mismatch custom node data: [('A',), ('dtype', dtype('float64')), ('shape', (45, 45)), ('xnp', <module 'cola.jax_fns' from '/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py'>), ('annotations', set()), ('device', gpu(id=0))] != [('device', gpu(id=0)), ('A',), ('dtype', dtype('float64')), ('shape', (45, 45)), ('xnp', <module 'cola.jax_fns' from '/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py'>), ('annotations', set())]; value: <45x45 Dense with dtype=float64>.
```
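Both lists in the error contain the same entries; only the position of `('device', gpu(id=0))` differs. The sketch below shows one way such an ordering difference can produce this exact kind of `ValueError`. It assumes, hypothetically, that an operator's flatten function derives its node data from `vars(self)` (which follows attribute assignment order) and that its unflatten function assigns attributes in a different order than `__init__`. `ToyOp` and `toy_filter` are illustrative names, not CoLA or Equinox code; `toy_filter` is only loosely modeled on the `jtu.tree_map` call in `equinox/_filters.py` seen in the traceback, and its two-pass structure is an assumption.

```python
import jax.numpy as jnp
from jax import tree_util as jtu


@jtu.register_pytree_node_class
class ToyOp:
    """Hypothetical stand-in for a linear operator (NOT CoLA's implementation)."""

    def __init__(self, A, device="cpu"):
        self.A = A
        self.shape = A.shape
        self.device = device  # assigned last on a freshly constructed object

    def tree_flatten(self):
        # Node data follows vars(self), i.e. the attribute assignment order.
        aux = tuple((k, v) for k, v in vars(self).items() if k != "A")
        return (self.A,), aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        obj = object.__new__(cls)
        obj.device = dict(aux)["device"]  # assigned first on a rebuilt object
        obj.A = children[0]
        for k, v in aux:
            setattr(obj, k, v)  # re-setting "device" keeps its earlier position
        return obj


def toy_filter(tree, keep):
    """Simplified two-pass filter: build a mask tree, then zip it with `tree`."""
    mask = jtu.tree_map(keep, tree)  # rebuilds every node via tree_unflatten
    return jtu.tree_map(lambda m, x: x if m else None, mask, tree)


op = ToyOp(jnp.ones((3, 2)))
try:
    toy_filter(op, lambda x: True)
except ValueError as err:
    print("ValueError:", err)
```

The mask tree is rebuilt through `tree_unflatten`, so its node data lists `device` first, while the original `op` lists it last; the second `tree_map` then compares the two and rejects them even though they hold identical values.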
The problem appears whenever a linear operator is stored as a field of an Equinox module. Here is a contrived minimal example:
```python
import equinox as eqx
import jax
import cola


class Linear(eqx.Module):
    weight: cola.ops.LinearOperator  # jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = cola.ops.Dense(jax.random.normal(wkey, (out_size, in_size)))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        y = cola.ops.Dense(x) @ cola.ops.Transpose(self.weight) + cola.ops.Dense(self.bias)
        return y.to_dense()
        # return self.weight @ x + self.bias


@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)


model = Linear(2, 3, key=jax.random.PRNGKey(0))
eqx.filter(model, eqx.is_inexact_array)
```

```
ValueError: Mismatch custom node data: [('A',), ('dtype', dtype('float32')), ('shape', (3, 2)), ('xnp', <module 'cola.jax_fns' from '/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py'>), ('annotations', set()), ('device', gpu(id=0))] != [('device', gpu(id=0)), ('A',), ('dtype', dtype('float32')), ('shape', (3, 2)), ('xnp', <module 'cola.jax_fns' from '/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py'>), ('annotations', set())]; value: <3x2 Dense with dtype=float32>.
```
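If the mismatch does come from node data whose order depends on attribute assignment order, one possible remedy is to emit that data in a canonical order. The sketch below assumes that cause; `SortedToyOp` and `structure_is_stable` are hypothetical names, not CoLA or JAX API. `structure_is_stable` round-trips a pytree through `jtu.tree_map` and checks that the rebuilt tree has the same structure, which is the property the `jtu.tree_map` call inside `eqx.filter` appears to rely on.

```python
import jax.numpy as jnp
from jax import tree_util as jtu


def structure_is_stable(tree):
    """Round-trip `tree` through flatten/unflatten and compare the treedefs."""
    rebuilt = jtu.tree_map(lambda x: x, tree)
    return jtu.tree_structure(tree) == jtu.tree_structure(rebuilt)


@jtu.register_pytree_node_class
class SortedToyOp:
    """Hypothetical operator whose node data has a canonical (sorted) order."""

    def __init__(self, A, device="cpu"):
        self.A = A
        self.shape = A.shape
        self.device = device

    def tree_flatten(self):
        # Sorting by attribute name makes the node data independent of the
        # order in which attributes were assigned on this instance.
        items = ((k, v) for k, v in vars(self).items() if k != "A")
        aux = tuple(sorted(items, key=lambda kv: kv[0]))
        return (self.A,), aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        obj = object.__new__(cls)
        # Deliberately assign in a different order than __init__; the sorting
        # in tree_flatten makes this harmless.
        obj.device = dict(aux)["device"]
        obj.A = children[0]
        for k, v in aux:
            setattr(obj, k, v)
        return obj


op = SortedToyOp(jnp.ones((3, 2)))
print(structure_is_stable(op))  # True
```

Applied to `model` from the example above, `structure_is_stable(model)` would be expected to return `False` for as long as the operator's node data ordering is unstable, which makes it a cheap check to run before reaching for `eqx.filter`.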