wilson-labs/cola

[Bug] New interaction issue between tree flattening and Equinox functionality

adam-hartshorne opened this issue · 1 comment

It appears that the changes to the code over the past couple of days have broken compatibility with pytree flattening when a linear operator is passed through equinox's filter functionality (eqx.filter).

    opt_state = opt_init(eqx.filter(model, eqx.is_inexact_array))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_filters.py", line 130, in filter
    return jtu.tree_map(
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 251, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 251, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Mismatch custom node data: [('A',), ('dtype', dtype('float64')), ('shape', (45, 45)), ('xnp', <module 'cola.jax_fns' from '/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py'>), ('annotations', set()), ('device', gpu(id=0))] != [('device', gpu(id=0)), ('A',), ('dtype', dtype('float64')), ('shape', (45, 45)), ('xnp', <module 'cola.jax_fns' from '/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py'>), ('annotations', set())]; value: <45x45 Dense with dtype=float64>.
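For context, this mismatch arises because jtu.tree_map flattens the second tree against the first tree's treedef, and flatten_up_to requires each custom node's aux data to compare equal, including element order. Here is a minimal sketch, independent of cola (the Box class and its fields are hypothetical, purely for illustration), of how two nodes with identical contents but differently ordered aux data trigger the same error:

import jax.tree_util as jtu

class Box:
    def __init__(self, value, meta):
        self.value = value
        self.meta = meta  # dict whose key order can differ between instances

def box_flatten(box):
    # The aux data preserves the meta dict's insertion order, so two boxes
    # with equal contents can still produce different aux tuples.
    return (box.value,), tuple(box.meta.items())

def box_unflatten(aux, children):
    return Box(children[0], dict(aux))

jtu.register_pytree_node(Box, box_flatten, box_unflatten)

a = Box(1.0, {"x": 1, "y": 2})
b = Box(2.0, {"y": 2, "x": 1})  # same contents, different insertion order

# tree_map flattens `b` against `a`'s treedef; the aux tuples differ only
# in order, raising: ValueError: Mismatch custom node data: ... != ...
jtu.tree_map(lambda u, v: u + v, a, b)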

The issue appears to arise when you store a linear operator as a field on an equinox Module. Here is a minimal verifiable example (MVE):

import equinox as eqx
import jax
import cola

class Linear(eqx.Module):
    weight: cola.ops.LinearOperator  # would ordinarily be jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = cola.ops.Dense(jax.random.normal(wkey, (out_size, in_size)))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        y = cola.ops.Dense(x) @ cola.ops.Transpose(self.weight) + cola.ops.Dense(self.bias)
        return y.to_dense()
        # return self.weight @ x + self.bias

@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)


model = Linear(2, 3, key=jax.random.PRNGKey(0))

eqx.filter(model, eqx.is_inexact_array)  # raises the ValueError below

ValueError: Mismatch custom node data: [('A',), ('dtype', dtype('float32')), ('shape', (3, 2)), ('xnp', <module 'cola.jax_fns' from '/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py'>), ('annotations', set()), ('device', gpu(id=0))] != [('device', gpu(id=0)), ('A',), ('dtype', dtype('float32')), ('shape', (3, 2)), ('xnp', <module 'cola.jax_fns' from '/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py'>), ('annotations', set())]; value: <3x2 Dense with dtype=float32>.

mfinzi commented

Ah, I see: the ordering of the vars(obj) dict is not necessarily an invariant across instances. Fixed in #39.
Incidentally, we've also changed flattening so that the flattened pytrees no longer include ints and other non-array leaf nodes; these now live in the static fields instead.
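For illustration, here is a rough sketch of the kind of flattening that addresses both points (this is not the actual change from #39; the Operator class and its fields are hypothetical): sort attributes by name so the flattening order is an invariant, and route non-array values into the static aux data.

import jax
import jax.tree_util as jtu

class Operator:
    """Hypothetical stand-in for a linear operator with assorted attributes."""
    def __init__(self, A, dtype, shape):
        self.A = A
        self.dtype = dtype
        self.shape = shape

def op_flatten(op):
    d = vars(op)
    # Sort by attribute name so flattening no longer depends on the order
    # attributes happened to be assigned in __init__.
    keys = sorted(d)
    dynamic = [k for k in keys if isinstance(d[k], jax.Array)]
    static = [k for k in keys if k not in dynamic]
    children = tuple(d[k] for k in dynamic)
    # Non-array values (dtypes, shapes, ints, ...) become static aux data
    # rather than pytree leaves.
    aux = (tuple(dynamic), tuple((k, d[k]) for k in static))
    return children, aux

def op_unflatten(aux, children):
    dynamic, static = aux
    op = object.__new__(Operator)  # rebuild without re-running __init__
    vars(op).update(zip(dynamic, children))
    vars(op).update(static)
    return op

jtu.register_pytree_node(Operator, op_flatten, op_unflatten)

op = Operator(jax.numpy.eye(3), dtype=jax.numpy.float32, shape=(3, 3))
leaves, treedef = jtu.tree_flatten(op)  # leaves contain only the array A

With sorted keys, two operators whose attributes were assigned in different orders still flatten to identical treedefs, which avoids the "Mismatch custom node data" error above.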