Treex

Main features:

  • Modules contain their parameters
  • Easy transfer learning
  • Simple initialization
  • No metaclass magic
  • No apply method
  • No need for special versions of vmap, jit, and friends.

To demonstrate the above we will start by creating a very contrived but complete module that uses everything from parameters to states and random state:

from typing import Tuple
import jax
import jax.numpy as jnp
import numpy as np

import treex as tx


class NoisyStatefulLinear(tx.Module):
    # tree parts are defined by treex annotations
    w: tx.Parameter
    b: tx.Parameter
    count: tx.State
    rng: tx.Rng

    # annotations with non-tree types are possible but are ignored by treex
    name: str

    def __init__(self, din, dout, name="noisy_stateful_linear"):
        self.name = name

        # Initializers only expect an RNG key
        self.w = tx.Initializer(lambda k: jax.random.uniform(k, shape=(din, dout)))
        self.b = tx.Initializer(lambda k: jax.random.uniform(k, shape=(dout,)))

        # random state is JUST state, we can keep it locally
        self.rng = tx.Initializer(lambda k: k)

        # if the value is known there is no need for an Initializer
        self.count = jnp.array(1)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        assert isinstance(self.count, jnp.ndarray)
        assert isinstance(self.rng, jnp.ndarray)

        # state can easily be updated
        self.count = self.count + 1

        # random state is no different :)
        key, self.rng = jax.random.split(self.rng, 2)

        # your typical linear operation
        y = jnp.dot(x, self.w) + self.b

        # add noise for fun
        state_noise = 1.0 / self.count
        random_noise = 0.8 * jax.random.normal(key, shape=y.shape)

        return y + state_noise + random_noise

    def __repr__(self) -> str:
        return f"NoisyStatefulLinear(w={self.w}, b={self.b}, count={self.count}, rng={self.rng})"


linear = NoisyStatefulLinear(1, 1)

linear
NoisyStatefulLinear(w=Initializer, b=Initializer, count=1, rng=Initializer)

Initialization

As advertised, initialization is easy: the only thing you need to do is call init on your module with a random key:

import jax

linear = linear.init(key=jax.random.PRNGKey(42))
linear
NoisyStatefulLinear(w=[[0.91457367]], b=[0.42094743], count=1, rng=[1371681402 3011037117])

Modules are Pytrees

It's fundamentally important that modules are also Pytrees; we can check that they are by using tree_map with an arbitrary function:

# it's a pytree alright
doubled = jax.tree_map(lambda x: 2 * x, linear)
doubled
NoisyStatefulLinear(w=[[1.8291473]], b=[0.84189487], count=2, rng=[2743362804 1727106938])

Modules can be sliced

An important feature of this Module system is that it can be sliced based on the type of its parameters; the slice method does exactly that:

params = linear.slice(tx.Parameter)
states = linear.slice(tx.State)

print(f"{params=}")
print(f"{states=}")
params=NoisyStatefulLinear(w=[[0.91457367]], b=[0.42094743], count=Nothing, rng=Nothing)
states=NoisyStatefulLinear(w=Nothing, b=Nothing, count=1, rng=[1371681402 3011037117])

Notice the following:

  • Both params and states are NoisyStatefulLinear objects, their type doesn't change after being sliced.
  • The fields that are filtered out by the slice get a special value of type tx.Nothing.

Why is this important? As we will see later, it is useful to keep parameters and states separate as they will crucially flow through different parts of value_and_grad.
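For instance, we can peek at the leaves of each slice (a small illustrative check; it assumes tx.Nothing contributes no leaves to the pytree, which is consistent with params being fed directly to optax later on):

# each slice is still a valid pytree, but only the selected fields
# show up as leaves (tx.Nothing is assumed to flatten to nothing)
print(jax.tree_leaves(params))  # expected: just w and b
print(jax.tree_leaves(states))  # expected: just count and rng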

Modules can be merged

merge is just the inverse operation to slice: it behaves like dict's update but returns a new module, leaving the original modules intact:

linear = params.merge(states)
linear
NoisyStatefulLinear(w=[[0.91457367]], b=[0.42094743], count=1, rng=[1371681402 3011037117])

Modules compose

As you'd expect, you can have modules inside other modules; as before, the key is to annotate the class fields. Here we will create an MLP class that uses two NoisyStatefulLinear modules:

class MLP(tx.Module):
    linear1: NoisyStatefulLinear
    linear2: NoisyStatefulLinear

    def __init__(self, din, dmid, dout):
        self.linear1 = NoisyStatefulLinear(din, dmid, name="linear1")
        self.linear2 = NoisyStatefulLinear(dmid, dout, name="linear2")

    def __call__(self, x: np.ndarray) -> np.ndarray:
        x = jax.nn.relu(self.linear1(x))
        x = self.linear2(x)
        return x

    def __repr__(self) -> str:
        return f"MLP(linear1={self.linear1}, linear2={self.linear2})"


model = MLP(din=1, dmid=2, dout=1).init(key=42)
model
MLP(linear1=NoisyStatefulLinear(w=[[0.95598125 0.4032725 ]], b=[0.5371039  0.10409856], count=1, rng=[1371681402 3011037117]), linear2=NoisyStatefulLinear(w=[[0.7236692]
 [0.8625636]], b=[0.5354074], count=1, rng=[3818536016 1640990408]))

Full Example

Using the previous model we will show how to train it with the proposed Module system. First let's get some data:

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)


def get_data(dataset_size: int) -> Tuple[np.ndarray, np.ndarray]:
    x = np.random.normal(size=(dataset_size, 1))
    y = 5 * x - 2 + 0.4 * np.random.normal(size=(dataset_size, 1))
    return x, y


def get_batch(
    data: Tuple[np.ndarray, np.ndarray], batch_size: int
) -> Tuple[np.ndarray, np.ndarray]:
    idx = np.random.choice(len(data[0]), batch_size)
    return jax.tree_map(lambda x: x[idx], data)


data = get_data(1000)

plt.scatter(data[0], data[1])
plt.show()

[figure: scatter plot of the generated training data]

Now we will reuse the previous MLP model and create an optax optimizer that will be used to train it:

import optax

optimizer = optax.adam(1e-2)

params = model.slice(tx.Parameter)
states = model.slice(tx.State)

opt_state = optimizer.init(params)

Notice that we are already splitting the model into params and states since we need to pass only the params to the optimizer. Next we will create the loss function; it will take the model parts (params and states) and the data, and return the loss plus the new states:

from functools import partial


@partial(jax.value_and_grad, has_aux=True)
def loss_fn(params: MLP, states: MLP, x, y):

    # merge params and states to get a full model
    model: MLP = params.merge(states)

    # apply model
    pred_y = model(x)

    # MSE loss
    loss = jnp.mean((y - pred_y) ** 2)

    # new states
    states = model.slice(tx.State)

    return loss, states

Notice that the first thing we are doing is merging the params and states into the complete model since we need everything in place to perform the forward pass. Also, we return the updated states from the model; this is needed because JAX's functional API requires us to be explicit about state management.

Note: inside loss_fn (which is wrapped by value_and_grad) the module can behave like a regular mutable Python object; however, every time it is treated as a pytree a new reference is created, as happens in jit, grad, vmap, etc. It's important to take this into account when using functions like vmap inside a module, as certain bookkeeping will be needed to manage state correctly.
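To make the "new reference" behavior concrete, here is a small illustrative sketch (the forward helper and sample input are just for demonstration): a module passed through jit comes back as a different object carrying the updated state, while the original object is left untouched.

def forward(module: MLP, x):
    y = module(x)      # mutates the traced copy's count and rng in place
    return y, module   # so the updated module must be returned explicitly

x_sample = np.ones((1, 1))
y_sample, new_model = jax.jit(forward)(model, x_sample)

# jit flattened and rebuilt the pytree, so this is a different object...
assert new_model is not model
# ...and only the returned copy carries the updated state
print(model.linear1.count, new_model.linear1.count)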

Next we will implement the update function; it will look indistinguishable from a standard Haiku update function, which also separates the weights into params and states:

@jax.jit
def update(params: MLP, states: MLP, opt_state, x, y):
    (loss, states), grads = loss_fn(params, states, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)

    # use regular optax
    params = optax.apply_updates(params, updates)

    return params, states, opt_state, loss

Finally we create a simple training loop that performs a few thousand updates and merges params and states back into a single model at the end:

steps = 10_000

for step in range(steps):
    x, y = get_batch(data, batch_size=32)

    params, states, opt_state, loss = update(params, states, opt_state, x, y)

    if step % 1000 == 0:
        print(f"[{step}] loss = {loss}")

# get the final model
model = params.merge(states)
[0] loss = 36.88694763183594
[1000] loss = 2.011059045791626
[2000] loss = 5.2326812744140625
[3000] loss = 1.7426897287368774
[4000] loss = 1.2130391597747803
[5000] loss = 1.6681632995605469
[6000] loss = 1.029949426651001
[7000] loss = 1.301844835281372
[8000] loss = 0.878564715385437
[9000] loss = 1.4557385444641113

Now let's generate some test data and see how our model performed:

import matplotlib.pyplot as plt

X_test = np.linspace(data[0].min(), data[0].max(), 100)[:, None]
y_pred = model(X_test)

plt.scatter(data[0], data[1], label="data", color="k")
plt.plot(X_test, y_pred, label="prediction")
plt.legend()
plt.show()

[figure: model predictions plotted against the training data]

As you can see, the model learned the general trend, but because of the NoisyStatefulLinear modules there is a bit of noise in the predictions.
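A quick way to see this is to run the same input through the model twice; since each call consumes random state (and updates count), the two predictions differ slightly:

# illustrative: two forward passes on the same input give different outputs
# because each call advances the module's rng and count state
y_a = model(X_test)
y_b = model(X_test)
print(jnp.abs(y_a - y_b).mean())  # non-zero due to the injected noise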