Main features:
- Modules contain their parameters
- Easy transfer learning
- Simple initialization
- No metaclass magic
- No apply method
- No need for special versions of `vmap`, `jit`, and friends.
To demonstrate the above, we will start by creating a very contrived but complete module that uses everything: parameters, states, and random state:
from typing import Tuple
import jax
import jax.numpy as jnp
import numpy as np
import treex as tx
class NoisyStatefulLinear(tx.Module):
# tree parts are defined by treex annotations
w: tx.Parameter
b: tx.Parameter
count: tx.State
rng: tx.Rng
# other annotations are possible but are ignored by the type system
name: str
def __init__(self, din, dout, name="noisy_stateful_linear"):
self.name = name
# Initializers only expect an RNG key
self.w = tx.Initializer(lambda k: jax.random.uniform(k, shape=(din, dout)))
self.b = tx.Initializer(lambda k: jax.random.uniform(k, shape=(dout,)))
# random state is JUST state, we can keep it locally
self.rng = tx.Initializer(lambda k: k)
# if the value is known there is no need for an Initializer
self.count = jnp.array(1)
def __call__(self, x: np.ndarray) -> np.ndarray:
assert isinstance(self.count, jnp.ndarray)
assert isinstance(self.rng, jnp.ndarray)
# state can easily be updated
self.count = self.count + 1
# random state is no different :)
key, self.rng = jax.random.split(self.rng, 2)
# your typical linear operation
y = jnp.dot(x, self.w) + self.b
# add noise for fun
state_noise = 1.0 / self.count
random_noise = 0.8 * jax.random.normal(key, shape=y.shape)
return y + state_noise + random_noise
def __repr__(self) -> str:
return f"NoisyStatefulLinear(w={self.w}, b={self.b}, count={self.count}, rng={self.rng})"
linear = NoisyStatefulLinear(1, 1)
linear
NoisyStatefulLinear(w=Initializer, b=Initializer, count=1, rng=Initializer)
As advertised, initialization is easy: the only thing you need to do is call `init` on your module with a random key:
import jax
linear = linear.init(key=jax.random.PRNGKey(42))
linear
NoisyStatefulLinear(w=[[0.91457367]], b=[0.42094743], count=1, rng=[1371681402 3011037117])
It's fundamentally important that modules are also pytrees; we can check that they are by using `tree_map` with an arbitrary function:
# its a pytree alright
doubled = jax.tree_map(lambda x: 2 * x, linear)
doubled
NoisyStatefulLinear(w=[[1.8291473]], b=[0.84189487], count=2, rng=[2743362804 1727106938])
An important feature of this Module system is that it can be sliced based on the type annotations of its fields; the `slice` method does exactly that:
params = linear.slice(tx.Parameter)
states = linear.slice(tx.State)
print(f"{params=}")
print(f"{states=}")
params=NoisyStatefulLinear(w=[[0.91457367]], b=[0.42094743], count=Nothing, rng=Nothing)
states=NoisyStatefulLinear(w=Nothing, b=Nothing, count=1, rng=[1371681402 3011037117])
Notice the following:
- Both `params` and `states` are `NoisyStatefulLinear` objects; their type doesn't change after being sliced.
- The fields that are filtered out by the `slice` are replaced by a special value of type `tx.Nothing`.
Why is this important? As we will see later, it is useful to keep parameters and state separate, as they will crucially flow through different parts of `value_and_grad`.
`merge` is just the inverse operation to `slice`; it behaves like dict's `update` but returns a new module, leaving the original modules intact:
linear = params.merge(states)
linear
NoisyStatefulLinear(w=[[0.91457367]], b=[0.42094743], count=1, rng=[1371681402 3011037117])
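The `slice`/`merge` round trip can be sketched with a plain-dict analogue. Everything here is illustrative, not Treex's actual implementation: `NOTHING` stands in for `tx.Nothing`, and the field "kinds" play the role of the type annotations:

```python
# Toy analogue of slice/merge semantics on a plain dict "module".
# NOTHING is a stand-in sentinel, not Treex's actual tx.Nothing.
NOTHING = object()

def slice_fields(module: dict, kinds: dict, kind: str) -> dict:
    # keep fields of the requested kind, replace the rest with the sentinel
    return {k: (v if kinds[k] == kind else NOTHING) for k, v in module.items()}

def merge_fields(a: dict, b: dict) -> dict:
    # like dict.update, but returns a new dict; sentinel fields are filled from b
    return {k: (b[k] if a[k] is NOTHING else a[k]) for k in a}

kinds = {"w": "param", "b": "param", "count": "state"}
module = {"w": 1.0, "b": 0.5, "count": 3}

params = slice_fields(module, kinds, "param")   # count becomes NOTHING
states = slice_fields(module, kinds, "state")   # w and b become NOTHING
restored = merge_fields(params, states)
assert restored == module                       # the round trip loses nothing
```

The key property is that slicing never changes the structure, only the leaves, which is what lets the two halves be merged back losslessly.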
As you'd expect, you can have modules inside other modules; as before, the key is to annotate the class fields. Here we will create an `MLP` class that uses two `NoisyStatefulLinear` modules:
class MLP(tx.Module):
linear1: NoisyStatefulLinear
linear2: NoisyStatefulLinear
def __init__(self, din, dmid, dout):
self.linear1 = NoisyStatefulLinear(din, dmid, name="linear1")
self.linear2 = NoisyStatefulLinear(dmid, dout, name="linear2")
def __call__(self, x: np.ndarray) -> np.ndarray:
x = jax.nn.relu(self.linear1(x))
x = self.linear2(x)
return x
def __repr__(self) -> str:
return f"MLP(linear1={self.linear1}, linear2={self.linear2})"
model = MLP(din=1, dmid=2, dout=1).init(key=42)
model
MLP(linear1=NoisyStatefulLinear(w=[[0.95598125 0.4032725 ]], b=[0.5371039 0.10409856], count=1, rng=[1371681402 3011037117]), linear2=NoisyStatefulLinear(w=[[0.7236692]
[0.8625636]], b=[0.5354074], count=1, rng=[3818536016 1640990408]))
Using the previous `model`, we will show how to train it using the proposed Module system. First let's get some data:
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(0)
def get_data(dataset_size: int) -> Tuple[np.ndarray, np.ndarray]:
x = np.random.normal(size=(dataset_size, 1))
y = 5 * x - 2 + 0.4 * np.random.normal(size=(dataset_size, 1))
return x, y
def get_batch(
data: Tuple[np.ndarray, np.ndarray], batch_size: int
) -> Tuple[np.ndarray, np.ndarray]:
idx = np.random.choice(len(data[0]), batch_size)
return jax.tree_map(lambda x: x[idx], data)
data = get_data(1000)
plt.scatter(data[0], data[1])
plt.show()
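Note how `get_batch` applies `jax.tree_map` over the `(x, y)` tuple so that the same indices are used for both arrays, keeping the pairs aligned. A minimal NumPy-only check of that property (standalone, independent of the code above):

```python
import numpy as np

np.random.seed(0)
x = np.random.normal(size=(1000, 1))
y = 5 * x - 2                     # noiseless version of the dataset relation

idx = np.random.choice(len(x), 32)
xb, yb = x[idx], y[idx]           # same indices applied to both arrays

assert xb.shape == (32, 1) and yb.shape == (32, 1)
# pairs stay aligned: each yb is still the transform of its own xb
assert np.allclose(yb, 5 * xb - 2)
```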
Now we will reuse the previous MLP model and create an optax optimizer to train it:
import optax
optimizer = optax.adam(1e-2)
params = model.slice(tx.Parameter)
states = model.slice(tx.State)
opt_state = optimizer.init(params)
Notice that we are already splitting the model into `params` and `states`, since we need to pass only the `params` to the optimizer. Next we will create the loss function; it will take the model parts and the data and return the loss plus the new states:
from functools import partial
@partial(jax.value_and_grad, has_aux=True)
def loss_fn(params: MLP, states: MLP, x, y):
# merge params and states to get a full model
model: MLP = params.merge(states)
# apply model
pred_y = model(x)
# MSE loss
loss = jnp.mean((y - pred_y) ** 2)
# new states
states = model.slice(tx.State)
return loss, states
Notice that the first thing we do is merge the `params` and `states` back into the complete model, since we need everything in place to perform the forward pass. We also return the updated states from the model; this is needed because JAX's functional API requires us to be explicit about state management.
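The `has_aux=True` pattern used above can be seen in isolation with a toy function (the names here are hypothetical and assume nothing from Treex): `value_and_grad` differentiates only the first output, while the second output, here a step counter, is passed through untouched:

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.value_and_grad, has_aux=True)
def toy_loss(w, count):
    loss = jnp.sum(w ** 2)     # differentiated output
    return loss, count + 1     # auxiliary output, returned as-is

# with has_aux=True the result is ((value, aux), grad)
(loss, count), grads = toy_loss(jnp.array([1.0, 2.0]), 0)
# loss = 1^2 + 2^2 = 5, grads = 2w = [2, 4], count incremented to 1
```

This is exactly the shape `loss_fn` above relies on: the loss drives the gradients while the new states ride along as the auxiliary output.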
Note: inside `loss_fn` (which is wrapped by `value_and_grad`) the module can behave like a regular mutable Python object; however, every time it is treated as a pytree a new reference is created, as happens in `jit`, `grad`, `vmap`, etc. It's important to take this into account when using functions like `vmap` inside a module, as certain bookkeeping will be needed to manage state correctly.
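The "new reference per pytree traversal" behavior is easy to verify with a plain dict, since any pytree container behaves the same way:

```python
import jax

tree = {"w": 1.0, "count": 0}
same_values = jax.tree_map(lambda x: x, tree)

assert same_values == tree       # the leaves are equal...
assert same_values is not tree   # ...but the container is a fresh object
```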
Next we will implement the `update` function; it will look indistinguishable from your standard Haiku update, which also separates weights into `params` and `states`:
@jax.jit
def update(params: MLP, states: MLP, opt_state, x, y):
(loss, states), grads = loss_fn(params, states, x, y)
updates, opt_state = optimizer.update(grads, opt_state, params)
# use regular optax
params = optax.apply_updates(params, updates)
return params, states, opt_state, loss
Finally, we create a simple training loop that performs a few thousand updates and merges `params` and `states` back into a single `model` at the end:
steps = 10_000
for step in range(steps):
x, y = get_batch(data, batch_size=32)
params, states, opt_state, loss = update(params, states, opt_state, x, y)
if step % 1000 == 0:
print(f"[{step}] loss = {loss}")
# get the final model
model = params.merge(states)
[0] loss = 36.88694763183594
[1000] loss = 2.011059045791626
[2000] loss = 5.2326812744140625
[3000] loss = 1.7426897287368774
[4000] loss = 1.2130391597747803
[5000] loss = 1.6681632995605469
[6000] loss = 1.029949426651001
[7000] loss = 1.301844835281372
[8000] loss = 0.878564715385437
[9000] loss = 1.4557385444641113
Now let's generate some test data and see how our model performed:
import matplotlib.pyplot as plt
X_test = np.linspace(data[0].min(), data[0].max(), 100)[:, None]
y_pred = model(X_test)
plt.scatter(data[0], data[1], label="data", color="k")
plt.plot(X_test, y_pred, label="prediction")
plt.legend()
plt.show()
As you can see, the model learned the general trend, but because of the `NoisyStatefulLinear` modules there is a bit of noise in the predictions.