/refx

Primary LanguagePythonMIT LicenseMIT

Refx

A reference system for jax

What is Refx?

Refx defines a simple reference system that allows you to create DAGs on top of JAX PyTrees. It enables some key features:

  • Shared state
  • Tractable mutability
  • Semantic partitioning

Refx is intended to be a low-level library that can be used as a building block within other JAX libraries.

Why Refx?

Functional systems like flax and haiku are powerful but add a lot of complexity that is often transfered to the user. On the other hand, pytree-based systems like equinox are simpler but lack the ability to share parameters and modules.

Refx aims to create a system that can be used to build neural networks libraries that has the simplicity of pytree-based systems while also having the power of functional systems.

Installation

pip install refx

Getting Started

Refx's main data structure is the Ref class. It is a wrapper around a value that can be used as leaves in a pytree. It also has a value attribute that can be used to access and mutate the value.

import jax
import refx

r1 = refx.Ref(1)
r2 = refx.Ref(2)

pytree = {
    'a': [r1, r1, r2],
    'b': r2
}

pytree['a'][0].value = 10

assert pytree['a'][1].value == 10

Ref is not a pytree node, therefore you cannot pass pytrees containing Refs to JAX functions. To interact with JAX, refx provides the following functions:

  • deref: converts a pytree of references to a pytree of values and indexes.
  • reref: converts a pytree of values and indexes to a pytree of references.

deref must be called before crossing a JAX boundary and reref must be called after crossing a JAX boundary.

pytree = refx.deref(pytree)

@jax.jit
def f(pytree):
    pytree = refx.reref(pytree)
    # perform some computation / update the references
    pytree['a'][2].value = 50
    return jax.deref(pytree)

pytree = f(pytree)
pytree = refx.reref(pytree)

assert pytree['b'].value == 50

As you see in the is example, we've effectively implemented shared state and tracktable mutability with pure pytrees.

Trace-level awareness

In JAX, unconstrained mutability can lead to tracer leakage. To prevent this, refx only allows mutating references from the same trace level they were created in.

r = refx.Ref(1)

@jax.jit
def f():
    # ValueError: Cannot mutate ref from different trace level
    r.value = jnp.array(1.0)
    ...

Partitioning

Each reference has a collection: Hashable attribute that can be used to partition references into different groups. refx provides the tree_partition and to partition a pytree based a predicate function.

r1 = refx.Ref(1, collection="params")
r2 = refx.Ref(2, collection="batch_stats")

pytree = {
    'a': [r1, r1, r2],
    'b': r2
}

(params, rest), treedef = refx.tree_partition(
    pytree, lambda x: isinstance(x, refx.Ref) and x.collection == "params")

assert params == {
    ('a', '0'): r1,
    ('a', '1'): r1,
    ('a', '2'): refx.NOTHING,
    ('b',): refx.NOTHING
}
assert rest == {
    ('a', '0'): refx.NOTHING,
    ('a', '1'): refx.NOTHING,
    ('a', '2'): r2,
    ('b',): r2,
}

You can use more partitioning functions to partition a pytree into multiple groups, tree_partition will always return one more partition than the number of functions passed to it. The last partition (rest) will contain all the remaining elements of the pytree. tree_partition also returns a treedef that can be used to reconstruct the pytree by using the merge_partitions function:

pytree = refx.merge_partitions(partitions, treedef)

If you only need a single partition, you can use the get_partition function:

r1 = refx.Ref(1, collection="params")
r2 = refx.Ref(2, collection="batch_stats")

pytree = {
    'a': [r1, r1, r2],
    'b': r2
}

params = refx.get_partition(
    pytree, lambda x: isinstance(x, refx.Ref) and x.collection == "params")

assert params == {
    ('a', '0'): r1,
    ('a', '1'): r1,
    ('a', '2'): refx.NOTHING,
    ('b',): refx.NOTHING
}

Updating references

refx provides the update_refs function to update references in a pytree. It updates a target pytree with references reference with the values from a source pytree. The source pytree can be either a pytree of references or a pytree of values. As an example, here is how you can use update_refs to perform gradient descent on a pytree of references:

def by_collection(c):
    return lambda x: isinstance(x, refx.Ref) and x.collection == c

(params, rest), treedef = refx.tree_partition(
    refx.deref(pytree), by_collection("params"))

def loss(params):
    pytree = refx.merge_partitions((params, rest), treedef)
    ...

grads = jax.grad(loss)(params)

# gradient descent
params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grads)

refx.update_refs(
    refx.get_partition(pytree, by_collection("params")), params)