
Elegy

Python 3.6 Python 3.7 Python 3.8 Release v0.1.3 Contributions welcome

Elegy is a Neural Networks framework based on Jax and Haiku.

Elegy implements the Keras API but makes changes to play better with Jax & Haiku and give more flexibility around losses and metrics (more on this soon). Elegy is still in a very early stage, feel free to test it and send us your feedback!

Main Features

  • Familiar: Elegy should feel very familiar to Keras users.
  • Flexible: Elegy improves upon the basic Keras API by letting users optionally take more control over the definition of losses and metrics.
  • Easy-to-use: Elegy maintains all the simplicity and ease of use that Keras brings with it.
  • Compatible: Elegy strives to be compatible with the rest of the Jax and Haiku ecosystem.

For more information take a look at the Documentation.

Installation

Install Elegy using pip:

pip install elegy

Quick Start

Elegy greatly simplifies the training of Deep Learning models compared to pure Jax / Haiku, where, due to Jax's functional nature, users have to do a lot of bookkeeping around the state of the model. In Elegy you just have to follow 3 basic steps:

1. Define the architecture inside an elegy.Module:

import elegy
import haiku as hk
import jax
import jax.numpy as jnp


class MLP(elegy.Module):
    def call(self, image: jnp.ndarray) -> jnp.ndarray:
        mlp = hk.Sequential([
            hk.Flatten(),
            hk.Linear(300),
            jax.nn.relu,
            hk.Linear(10),
        ])
        return mlp(image)

2. Create a Model from this module and specify additional things like losses, metrics, optimizers, and callbacks:

from jax.experimental import optix

model = elegy.Model(
    module=MLP.defer(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy.defer(),
    optimizer=optix.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
)
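The fit call above assumes that X_train, y_train, X_test, and y_test are already loaded as arrays. For a quick smoke test without downloading a real dataset, randomly generated arrays with MNIST-like shapes work; the shapes and array names below are assumptions for illustration, not part of Elegy:

```python
import numpy as np

# Hypothetical placeholder data with MNIST-like shapes; swap in a real
# dataset loader for actual training.
rng = np.random.default_rng(0)
X_train = rng.uniform(size=(256, 28, 28)).astype(np.float32)
y_train = rng.integers(0, 10, size=(256,))
X_test = rng.uniform(size=(64, 28, 28)).astype(np.float32)
y_test = rng.integers(0, 10, size=(64,))

print(X_train.shape, y_train.shape)  # (256, 28, 28) (256,)
```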

And you are done! For more information, check out the Documentation.

Why Jax + Haiku?

Jax is a linear algebra library with the perfect recipe:

  • Numpy's familiar API
  • The speed and hardware support of XLA
  • Automatic Differentiation

The awesome thing about Jax is that Deep Learning is just one use case it happens to excel at; you can use it for most tasks you would use Numpy for.
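As a small sketch of that Numpy-plus-autodiff combination (the function here is an arbitrary example, not from Elegy):

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Plain Numpy-style code: a quadratic with its minimum at w = 2.
    return jnp.sum((w - 2.0) ** 2)

# jax.grad returns a new function that computes d(loss)/dw.
grad_fn = jax.grad(loss)
w = jnp.array([0.0, 1.0, 3.0])
print(grad_fn(w))  # analytically 2 * (w - 2) -> [-4. -2.  2.]
```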

On the other hand, Haiku is a Neural Networks library built on top of Jax that implements a Module system, common Neural Network layers, and even some full architectures. Compared to other Jax-based libraries like Trax or Flax, Haiku is very minimal, polished, well documented, and makes it super easy / clean to implement Deep Learning code!

We believe that Elegy can offer the best experience for coding Deep Learning applications by leveraging the power and familiarity of the Jax API and the ease of use of Haiku's Module system, and packaging everything on top of a convenient Keras-like API.

Features

  • Model estimator class
  • losses module
  • metrics module
  • regularizers module
  • callbacks module
  • nn layers module

For more information, check out the Reference API section in the Documentation.

Contributing

Deep Learning is evolving at an incredible rate; there is so much to do and so few hands. If you wish to contribute anything, from a loss or metric to an awesome new feature for Elegy, just open an issue or send a PR!

About Us

We are a couple of friends passionate about ML.

License

Apache