Laplax

Laplace approximations in JAX.

What is laplax?

The laplax package aims to provide a performant, minimal, and practical implementation of Laplace approximation techniques in jax. This package is designed to support a wide range of scientific libraries, initially focusing on compatibility with popular neural network libraries such as equinox, flax.linen, and flax.nnx. Our goal is to create a flexible tool for both practical applications and research, enabling rapid iteration and comparison of new approaches.
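For context: a Laplace approximation fits a Gaussian to the posterior over model parameters, centered at a trained (MAP) estimate and with covariance given by the inverse curvature (Hessian) of the negative log-joint. A minimal, self-contained sketch in plain jax; this is illustrative only and does not use the laplax API:

```python
import jax
import jax.numpy as jnp

# Toy data: a noisy linear relationship.
key = jax.random.PRNGKey(0)
x = jnp.linspace(-1.0, 1.0, 20)
y = 2.0 * x + 0.1 * jax.random.normal(key, x.shape)

def neg_log_joint(params):
    # Gaussian likelihood plus a standard Gaussian prior on the parameters.
    w, b = params
    residuals = y - (w * x + b)
    return 0.5 * jnp.sum(residuals**2) + 0.5 * jnp.sum(params**2)

# Pretend this is the MAP estimate found by an optimizer.
params_map = jnp.array([2.0, 0.0])

# Laplace approximation: posterior ~ N(params_map, H^{-1}), with H the
# Hessian of the negative log-joint at the MAP estimate.
hessian = jax.hessian(neg_log_joint)(params_map)
posterior_cov = jnp.linalg.inv(hessian)
```

For realistic networks the Hessian is far too large to build explicitly, which is what motivates the matrix-vector product design described below.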

Design Philosophy

The development of laplax is guided by the following principles:

  • Minimal Dependencies: The package only depends on jax, ensuring compatibility and ease of integration.

  • Matrix-Vector Product Focus: The core of our implementation revolves around efficient matrix-vector products. By passing around callables, we maintain loose coupling between components, allowing easy interaction with other packages, including linear operator libraries in jax (see the first sketch after this list).

  • Performance and Practicality: We prioritize a performant and minimal implementation that serves practical needs. The package offers a simple API for basic use cases while primarily serving as a reference implementation for researchers to compare new methods or iterate quickly on experiments.

  • PyTree-Centric Structure: Internally, the package is structured around PyTrees. This design choice lets us defer materialization until necessary, optimizing performance and memory usage (see the materialization sketch after this list).
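To make the matrix-vector product idea concrete: curvature (e.g. the Hessian of a loss) can be represented purely as a callable that maps a PyTree of tangents to the corresponding product, without ever building the matrix. A sketch in plain jax; `make_hvp` is a hypothetical helper, not part of the laplax API:

```python
import jax

def make_hvp(loss_fn, params):
    # Build a callable v -> H @ v, where H is the Hessian of loss_fn at params.
    # Forward-over-reverse differentiation: the Hessian itself is never
    # materialized, and params/v may be arbitrary (matching) PyTrees.
    def hvp(v):
        return jax.jvp(jax.grad(loss_fn), (params,), (v,))[1]
    return hvp
```

Because the curvature is exposed as a plain callable, it composes directly with iterative solvers such as jax.scipy.sparse.linalg.cg or with linear operator libraries.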
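And to illustrate deferred materialization: a PyTree-valued matrix-vector product only needs to become a dense array at the point where a dense matrix is genuinely required. A sketch under the same assumptions (`materialize` is again a hypothetical name):

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def materialize(mv, params):
    # Turn a PyTree -> PyTree matrix-vector product into a dense array,
    # column by column. Only do this once a dense matrix is truly needed.
    flat, unravel = ravel_pytree(params)
    eye = jnp.eye(flat.size)
    cols = jax.vmap(lambda e: ravel_pytree(mv(unravel(e)))[0])(eye)
    return cols.T  # row i of `cols` is M @ e_i, so transposing gives M
```

For small models, something like `materialize(make_hvp(loss_fn, params), params)` would recover the dense Hessian, while everything upstream keeps working with the callable alone.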

Roadmap and Contributions

We're developing this package in the open; discussions about the roadmap and feature priorities take place in the Issues section. If you're interested in contributing or want to see what's planned, please check them out.