
Flax: A neural network ecosystem for JAX designed for flexibility

Overview | Quickstart | Trying Flax | Installation | Full documentation


NOTE: Flax is used daily by a growing community of researchers and engineers at Google for their research. The new Flax "Linen" module API is now stable, and we recommend it for all new projects. The old flax.nn API will be deprecated. Please report any feature requests, issues, questions, or concerns in our discussion forum, or just let us know what you're working on!

Expect changes to the API, but we'll use deprecation warnings when we can, and keep track of them in our Changelog.

In case you need to reach us directly, we're at flax-dev@google.com.

Background: JAX

JAX is NumPy + autodiff + GPU/TPU

It allows for fast scientific computing and machine learning with the normal NumPy API (plus additional APIs for special accelerator ops when needed).

JAX comes with powerful primitives, which you can compose arbitrarily:

  • Autodiff (jax.grad): Efficient any-order gradients w.r.t. any variables
  • JIT compilation (jax.jit): Trace any function ⟶ fused accelerator ops
  • Vectorization (jax.vmap): Automatically batch code written for individual samples
  • Parallelization (jax.pmap): Automatically parallelize code across multiple accelerators (including across hosts, e.g. for TPU pods)
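
As a tiny illustration of how these primitives compose, here is a sketch; the toy loss function and data below are made up for the example, not taken from JAX itself:

import jax
import jax.numpy as jnp

def loss(w, x, y):
  # Toy least-squares loss for a linear model.
  return jnp.mean((jnp.dot(x, w) - y) ** 2)

grad_fn = jax.jit(jax.grad(loss))                        # compiled gradient w.r.t. w
per_example_loss = jax.vmap(loss, in_axes=(None, 0, 0))  # batch the per-sample loss

w = jnp.zeros(3)
x = jnp.ones((8, 3))
y = jnp.ones(8)
print(grad_fn(w, x, y).shape)           # (3,)
print(per_example_loss(w, x, y).shape)  # (8,)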

Overview

Flax is a high-performance neural network library for JAX that is designed for flexibility: Try new forms of training by forking an example and by modifying the training loop, not by adding features to a framework.

Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:

  • Neural network API (flax.linen): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout

  • Optimizers (flax.optim): SGD, Momentum, Adam, LARS, Adagrad, LAMB, RMSprop

  • Utilities and patterns: replicated training, serialization and checkpointing, metrics, prefetching on device

  • Educational examples that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging

  • Fast, tuned large-scale end-to-end examples: CIFAR10, ResNet on ImageNet, Transformer LM1b
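
As a rough sketch of how some of these pieces fit together, here is a minimal training step written against flax.linen and flax.optim. The toy model, loss function, and hyperparameters are illustrative assumptions for this sketch, not taken from the Flax examples:

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn

# Illustrative toy setup: a single Dense layer with a squared-error loss.
model = nn.Dense(features=1)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))

# flax.optim optimizers are built from a definition plus the target to optimize.
optimizer = flax.optim.Momentum(learning_rate=0.1, beta=0.9).create(variables)

def loss_fn(variables, batch):
  preds = model.apply(variables, batch['x'])
  return jnp.mean((preds - batch['y']) ** 2)

@jax.jit
def train_step(optimizer, batch):
  loss, grad = jax.value_and_grad(loss_fn)(optimizer.target, batch)
  return optimizer.apply_gradient(grad), loss

batch = {'x': jnp.ones((8, 4)), 'y': jnp.zeros((8, 1))}
optimizer, loss = train_step(optimizer, batch)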

Trying Flax

Below is a curated list of canonical examples, maintained by the Flax team, that you can fork to get started. They cover the following domains; for more examples, including ones built by the community, please check the linen_examples folder.

  • Image classification
  • Reinforcement learning
  • Natural language processing
  • Generative models

What does Flax look like?

We provide here two examples using the Flax API: a simple multi-layer perceptron and a CNN. To learn more about the Module abstraction, please check our docs.

from typing import Sequence

import flax.linen as nn

class SimpleMLP(nn.Module):
  """A simple MLP model."""
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat)(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
    return x

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x
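
Modules do not hold parameters themselves; you initialize them with a PRNG key and then apply them to data. As a quick sketch (the MNIST-like input shape and the seed below are just illustrative assumptions):

import jax
import jax.numpy as jnp

# Illustrative input: a batch of one 28x28 grayscale image.
x = jnp.ones((1, 28, 28, 1))

model = CNN()
variables = model.init(jax.random.PRNGKey(0), x)  # create the model parameters
logits = model.apply(variables, x)                # run the forward pass
print(logits.shape)  # (1, 10)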

Installation

You will need Python 3.6 or later.

For GPU support, first install jaxlib by following the instructions in the JAX readme. You will also need the CUDA and cuDNN runtimes if they are not already installed.

Then install flax from PyPI:

> pip install flax
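
To quickly check which devices JAX picked up after installation (an optional sanity check, not part of the official instructions):

> python -c "import jax; import flax.linen; print(jax.devices())"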

TPU support

We currently have a tuned Transformer language model for LM1b/Wikitext-2 that you can run directly via Colab.

At present, Cloud TPUs are network-attached, and Flax users typically feed in data from one or more additional VMs.

When working with large-scale input data, it is important to create VMs that are large enough and have sufficient network bandwidth, so the TPUs are not bottlenecked waiting for input.

TODO: Add an example for running on Google Cloud.

Getting involved

Currently, developing Flax requires Python 3.6, and the run_all_tests.sh script requires svn. After installing these prerequisites, you can clone the repository, set up your local environment, and run all tests with the following commands:

git clone https://github.com/google/flax
cd flax
python3.6 -m virtualenv env
. env/bin/activate
pip install -e ".[testing]"
./tests/run_all_tests.sh

Alternatively, you can develop inside a Docker container; see dev/README.md.

We welcome pull requests, in particular for those issues marked as PR-ready. For other proposals, we ask that you first open an Issue to discuss your planned contribution.

Note

This is not an official Google product.