NOTE: Flax is being actively improved and has a growing community of researchers and engineers at Google who happily use Flax for their daily research. Flax is in "early release stage" -- if that's your style, now could be a good time to start using it. We want to smooth out any rough edges so please report any issues, questions or concerns as GitHub issues. Expect changes to the API, but we'll use deprecation warnings when we can, and keep track of them in our Changelog.
In case you need to reach us directly, we're at flax-dev@google.com.
⟶ Full documentation and API reference
⟶ Annotated full end-to-end MNIST example
⟶ The Flax Guide -- a guided walkthrough of the parts of Flax
JAX is NumPy + autodiff + GPU/TPU
It allows for fast scientific computing and machine learning with the familiar NumPy API (plus additional APIs for special accelerator ops when needed).
JAX comes with powerful primitives, which you can compose arbitrarily:
- Autodiff (jax.grad): Efficient any-order gradients w.r.t. any variables
- JIT compilation (jax.jit): Trace any function ⟶ fused accelerator ops
- Vectorization (jax.vmap): Automatically batch code written for individual samples
- Parallelization (jax.pmap): Automatically parallelize code across multiple accelerators (including across hosts, e.g. for TPU pods)
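As a small illustration of how these primitives compose, the sketch below uses plain JAX with a toy least-squares loss chosen only for this example. It jit-compiles a gradient, uses vmap to get per-example gradients, and nests grad for a second derivative. pmap is left out because it only becomes interesting with multiple devices.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Mean squared error of a linear model; x may be one sample or a batch.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# jax.grad differentiates w.r.t. the first argument; jax.jit compiles the result.
grad_loss = jax.jit(jax.grad(loss))

# jax.vmap maps the single-sample gradient over the leading axis of x and y.
# w is shared by all samples, hence in_axes=None for it.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))

# Gradients of any order: the second derivative of t**3 is 6 * t.
second_derivative = jax.grad(jax.grad(lambda t: t ** 3))

w = jnp.zeros(3)
x = jnp.ones((4, 3))
y = jnp.ones(4)
print(grad_loss(w, x, y))                # [-2. -2. -2.]
print(per_example_grads(w, x, y).shape)  # (4, 3)
print(second_derivative(2.0))            # 12.0
```

Averaging the per-example gradients over the batch recovers the batch gradient, which is a quick sanity check that vmap and grad composed as intended.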
Flax is a high-performance neural network library for JAX that is designed for flexibility: Try new forms of training by forking an example and by modifying the training loop, not by adding features to a framework.
Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:
- Common layers (flax.nn): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout
- Optimizers (flax.optim): SGD, Momentum, Adam, LARS
- Utilities and patterns: replicated training, serialization and checkpointing, metrics, prefetching on device
- Educational examples that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging
- HOWTO guides: diffs that add functionality to educational base examples
- Fast, tuned large-scale end-to-end examples: CIFAR10, ResNet on ImageNet, Transformer LM1b
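For intuition about what an optimizer such as Momentum computes, here is a minimal plain-JAX sketch of a heavy-ball momentum update applied leaf by leaf to a parameter pytree. It is not the flax.optim API; momentum_step, the learning rate, the momentum coefficient and the toy quadratic loss are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def momentum_step(params, velocity, grads, learning_rate=0.1, beta=0.9):
    """Hypothetical helper: one heavy-ball momentum update over a pytree."""
    # v <- beta * v + g
    velocity = jax.tree_util.tree_map(lambda v, g: beta * v + g, velocity, grads)
    # p <- p - learning_rate * v
    params = jax.tree_util.tree_map(
        lambda p, v: p - learning_rate * v, params, velocity)
    return params, velocity


def loss(params):
    # Toy quadratic bowl with its minimum at w = 3.
    return jnp.sum((params['w'] - 3.0) ** 2)


params = {'w': jnp.zeros(2)}
velocity = jax.tree_util.tree_map(jnp.zeros_like, params)
for _ in range(300):
    grads = jax.grad(loss)(params)
    params, velocity = momentum_step(params, velocity, grads)
print(params['w'])  # close to [3. 3.]
```

Because the update is expressed with tree_map, the same function works unchanged for nested parameter dicts such as those produced by Module.init.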
Here we keep a limited list of canonical examples maintained by the Flax team. If you are looking for more examples, including ones built by the community, please check the examples folder for further guidance.
⟶ MNIST (also see annotated version)
⟶ CIFAR-10 (Wide ResNet w/ and w/o Shake-Shake, PyramidNet w/ ShakeDrop)
⟶ Sequence tagging on Universal Dependencies
⟶ LM1b language modeling (try on a TPU in Colab)
⟶ LSTM text classifier on SST-2
⟶ LSTM seq2seq on number addition
⟶ Semi-supervised node classification on Zachary's karate club
The core of Flax is the Module abstraction. Modules allow you to write parameterized functions just as if you were writing a normal NumPy function with JAX. The Module API allows you to declare parameters and use them directly with the JAX APIs.
Modules are the one part of Flax with "magic" -- the magic is constrained, and enables a very ergonomic model construction style, where modules are defined in a single function with minimal boilerplate.
A few things to know about Modules:
- Create a new module by subclassing flax.nn.Module and implementing the apply method.
- Within apply, call self.param(name, shape, init_func) to register a new parameter and return its initial value.
- Apply submodules with MySubModule(name=..., ...) within MyModule.apply. Parameters of MySubModule are stored as a nested dictionary within the parameters of MyModule and are accessible via self.get_param(name=...). This applies MySubModule once; to re-use parameters, use Module.shared.
- MyModule.init(rng, ...) is a pure function that calls apply in "init mode" and returns a nested Python dict of initialized parameter values.
- MyModule.call(params, ...) is a pure function that calls apply in "call mode" and returns the output of the module.
For example, you can define a learned linear transformation as follows:
from flax import nn
import jax.numpy as jnp

class Linear(nn.Module):
  def apply(self, x, num_features, kernel_init_fn):
    input_features = x.shape[-1]
    W = self.param('W', (input_features, num_features), kernel_init_fn)
    return jnp.dot(x, W)
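To make kernel_init_fn concrete, here is a plain-JAX sketch of the same computation. It assumes, as with the initializers in jax.nn.initializers, that the initializer is called as kernel_init_fn(rng, shape); linear_init and linear_apply are hypothetical names, not part of Flax.

```python
import jax
import jax.numpy as jnp

# An initializer called as init(key, shape), the form Linear's
# self.param('W', shape, kernel_init_fn) call is assumed to use.
kernel_init_fn = jax.nn.initializers.lecun_normal()


def linear_init(rng, input_features, num_features):
    # Hypothetical counterpart of Linear's parameter registration.
    return {'W': kernel_init_fn(rng, (input_features, num_features))}


def linear_apply(params, x):
    # Same computation as Linear.apply: contract the last axis of x with W.
    return jnp.dot(x, params['W'])


x = jnp.ones((2, 3))
params = linear_init(jax.random.PRNGKey(0), x.shape[-1], 5)
y = linear_apply(params, x)
print(y.shape)  # (2, 5)
```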
You can also use nn.module as a function decorator to create a new module, as long as you don't need access to self for creating parameters directly:
@nn.module
def DenseLayer(x, features):
  x = nn.Dense(x, features)
  x = nn.relu(x)
  return x
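The init/call contract from the bullets above (pure functions, with submodule parameters nested under each submodule's name) can be mimicked in plain JAX. The sketch below is not how flax.nn implements Module.init or Module.call; dense_init, mlp_init and mlp_call are hypothetical names, and the scaled-normal initialization is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp


def dense_init(rng, in_features, out_features):
    # Hypothetical helper: parameters of one dense layer as a flat dict.
    kernel = jax.random.normal(rng, (in_features, out_features)) / jnp.sqrt(in_features)
    return {'kernel': kernel, 'bias': jnp.zeros(out_features)}


def mlp_init(rng, in_features, hidden, out_features):
    # Each submodule's parameters nest under its name, mirroring how
    # MySubModule's parameters live inside MyModule's parameter dict.
    rng1, rng2 = jax.random.split(rng)
    return {'dense1': dense_init(rng1, in_features, hidden),
            'dense2': dense_init(rng2, hidden, out_features)}


def mlp_call(params, x):
    # Pure function: the output depends only on params and x.
    h = jax.nn.relu(x @ params['dense1']['kernel'] + params['dense1']['bias'])
    return h @ params['dense2']['kernel'] + params['dense2']['bias']


params = mlp_init(jax.random.PRNGKey(0), in_features=4, hidden=8, out_features=2)
y = mlp_call(params, jnp.ones((5, 4)))
print(jax.tree_util.tree_map(jnp.shape, params))
print(y.shape)  # (5, 2)
```

Because mlp_call is pure, it can be passed straight to jax.jit, jax.grad or jax.vmap, which is the property the Module API is built to preserve.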
⟶ Read more about Modules in the Flax Guide
A full ResNet implementation (from examples/imagenet/models.py):
from flax import nn
import jax.numpy as jnp


class ResidualBlock(nn.Module):
  def apply(self, x, filters, strides=(1, 1), train=True, dtype=jnp.float32):
    needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)
    batch_norm = nn.BatchNorm.partial(
        use_running_average=not train, momentum=0.9, epsilon=1e-5, dtype=dtype)
    conv = nn.Conv.partial(bias=False, dtype=dtype)

    residual = x
    if needs_projection:
      # Project the shortcut when the channel count or spatial size changes.
      residual = conv(residual, filters * 4, (1, 1), strides, name='proj_conv')
      residual = batch_norm(residual, name='proj_bn')

    # Bottleneck: 1x1 reduce, 3x3 (possibly strided) conv, 1x1 expand.
    y = conv(x, filters, (1, 1), name='conv1')
    y = batch_norm(y, name='bn1')
    y = nn.relu(y)
    y = conv(y, filters, (3, 3), strides, name='conv2')
    y = batch_norm(y, name='bn2')
    y = nn.relu(y)
    y = conv(y, filters * 4, (1, 1), name='conv3')
    # Zero-initialized scale: the residual branch initially contributes nothing.
    y = batch_norm(y, name='bn3', scale_init=nn.initializers.zeros)
    y = nn.relu(residual + y)
    return y


# Blocks per stage for the standard bottleneck ResNet depths.
_block_size_options = {
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
}


class ResNet(nn.Module):
  def apply(self, x, num_classes, num_filters=64, num_layers=50,
            train=True, dtype=jnp.float32):
    if num_layers not in _block_size_options:
      raise ValueError('Please provide a valid number of layers')
    block_sizes = _block_size_options[num_layers]
    x = nn.Conv(
        x, num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)],
        bias=False, dtype=dtype, name='init_conv')
    x = nn.BatchNorm(
        x, use_running_average=not train, momentum=0.9,
        epsilon=1e-5, dtype=dtype, name='init_bn')
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(block_sizes):
      for j in range(block_size):
        # Downsample at the first block of every stage except the first.
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = ResidualBlock(
            x, num_filters * 2 ** i, strides=strides,
            train=train, dtype=dtype)
    x = jnp.mean(x, axis=(1, 2))  # global average pooling over H and W
    x = nn.Dense(x, num_classes)
    x = nn.log_softmax(x)
    return x
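The loop in ResNet.apply fixes the block schedule. The plain-Python sketch below replays that loop for the 50-layer configuration to show where strides and projection shortcuts occur. resnet_block_schedule is a hypothetical helper, not part of the example. It also shows where the name ResNet-50 comes from: counting the stem conv, three convs per bottleneck block and the final Dense layer, and (as is conventional) not counting the projection convs.

```python
def resnet_block_schedule(block_sizes, num_filters=64):
    """Hypothetical helper: (filters, strides, needs_projection) per block."""
    schedule = []
    in_channels = num_filters  # channels after the stem conv
    for i, block_size in enumerate(block_sizes):
        for j in range(block_size):
            filters = num_filters * 2 ** i
            strides = (2, 2) if i > 0 and j == 0 else (1, 1)
            needs_projection = in_channels != filters * 4 or strides != (1, 1)
            schedule.append((filters, strides, needs_projection))
            in_channels = filters * 4  # bottleneck blocks output filters * 4
    return schedule


schedule = resnet_block_schedule([3, 4, 6, 3])
num_blocks = len(schedule)                              # 16
weighted_layers = 1 + 3 * num_blocks + 1                # stem + 3 convs/block + Dense
projections = sum(p for _, _, p in schedule)            # first block of every stage
downsamples = sum(s == (2, 2) for _, s, _ in schedule)  # stages 2 to 4
# For a 224x224 input: stem conv and max-pool each halve, then 3 strided stages.
final_spatial = 224 // 4 // 2 ** downsamples
print(num_blocks, weighted_layers, projections, downsamples, final_spatial)  # 16 50 4 3 7
```

The final stage outputs 512 * 4 = 2048 channels on a 7x7 grid, which jnp.mean(x, axis=(1, 2)) reduces to one 2048-dimensional vector per image before the Dense classifier.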
You will need Python 3.6 or later.
For GPU support, first install jaxlib; please follow the instructions in the JAX readme. If they are not already installed, you will need to install the CUDA and cuDNN runtimes.
Then install flax from PyPI:
> pip install flax
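After installing jaxlib, a quick way to check which accelerator JAX actually picked up is to list its devices. If the CUDA or cuDNN runtimes are missing, JAX typically falls back to the CPU, which this check makes visible.

```python
import jax

# Devices visible to JAX; a working GPU install should list GPU devices here.
print(jax.devices())
# Name of the platform JAX uses by default (for example 'cpu').
print(jax.default_backend())
```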
We currently have an LM1b/Wikitext-2 language model with a tuned Transformer architecture. You can run it directly in Colab.
At present, Cloud TPUs are network-attached, and Flax users typically feed in data from one or more additional VMs.
When working with large-scale input data, it is important to create VMs that are large enough and have sufficient network bandwidth, so that the TPUs are not bottlenecked waiting for input.
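Alongside sizing the input VMs, the training loop can keep a few batches already on the device so that transfers overlap with computation; this is the idea behind the "prefetching on device" utility mentioned earlier. The sketch below is a plain-JAX illustration under that assumption, not Flax's own utility; prefetch_batches and the buffer size of 2 are hypothetical choices.

```python
import collections
import itertools

import jax
import numpy as np


def prefetch_batches(batches, size=2):
    """Hypothetical helper: keep up to `size` batches already placed on device."""
    queue = collections.deque()
    iterator = iter(batches)

    def enqueue(n):
        # Pull up to n host batches and place them on the default device.
        for batch in itertools.islice(iterator, n):
            queue.append(jax.device_put(batch))

    enqueue(size)
    while queue:
        yield queue.popleft()
        enqueue(1)  # refill so the buffer stays `size` batches ahead


host_batches = (np.full((2,), i, dtype=np.float32) for i in range(5))
values = [float(batch[0]) for batch in prefetch_batches(host_batches)]
print(values)  # [0.0, 1.0, 2.0, 3.0, 4.0]
```

Prefetching only hides transfer latency; if the input VMs cannot produce batches as fast as the TPUs consume them, the buffer drains and the accelerators still wait.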
TODO: Add an example for running on Google Cloud.
We welcome pull requests, in particular for those issues marked as PR-ready. For other proposals, we ask that you first open an Issue to discuss your planned contribution.
This is not an official Google product.