

Primary language: Python. License: Apache License 2.0.

JAXopt

Installation | Examples | References

Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.

Installation

JAXopt can be installed with pip directly from GitHub with the following command:

$ pip install git+https://github.com/google/jaxopt

Alternatively, it can be installed from source by running the following command from the root of a local clone of the repository:

$ python setup.py install

References

Our implicit differentiation framework is described in the paper "Efficient and Modular Implicit Differentiation" (arXiv:2105.15183). To cite it:

@article{jaxopt_implicit_diff,
  title={Efficient and Modular Implicit Differentiation},
  author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian and Vert, Jean-Philippe},
  journal={arXiv preprint arXiv:2105.15183},
  year={2021}
}
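The core idea behind implicit differentiation can be sketched in a few lines of plain JAX. This is a minimal illustration under our own assumptions, not JAXopt's implementation: if the solution w*(λ) satisfies an optimality condition F(w*(λ), λ) = 0, the implicit function theorem gives ∂F/∂w · dw*/dλ = -∂F/∂λ, so dw*/dλ follows from a single linear solve, without differentiating through the solver itself. Here F is the gradient of a ridge regression objective, and the helper names `optimality`, `solve_ridge` and `implicit_dw_dlam` are hypothetical.

```python
import jax
import jax.numpy as jnp


def optimality(w, lam, X, y):
    # F(w, lam): gradient of 0.5 * ||X w - y||^2 + 0.5 * lam * ||w||^2 w.r.t. w.
    # It vanishes at the ridge solution w*(lam).
    return X.T @ (X @ w - y) + lam * w


def solve_ridge(lam, X, y):
    # Any solver could be used here; a closed-form solve keeps the sketch short.
    d = X.shape[1]
    return jnp.linalg.solve(X.T @ X + lam * jnp.eye(d), X.T @ y)


def implicit_dw_dlam(lam, X, y):
    w_star = solve_ridge(lam, X, y)
    # A = dF/dw and B = dF/dlam, both evaluated at (w*, lam).
    A = jax.jacobian(optimality, argnums=0)(w_star, lam, X, y)
    B = jax.jacobian(optimality, argnums=1)(w_star, lam, X, y)
    # Implicit function theorem: A @ (dw*/dlam) = -B.
    return jnp.linalg.solve(A, -B)


X = jax.random.normal(jax.random.PRNGKey(0), (20, 5))
y = jax.random.normal(jax.random.PRNGKey(1), (20,))

# Compare against differentiating straight through the closed-form solver.
jac_implicit = implicit_dw_dlam(1.0, X, y)
jac_direct = jax.jacobian(solve_ridge, argnums=0)(1.0, X, y)
```

Both Jacobians should agree up to floating-point error. The implicit version only needs the optimality condition and the final solution, which is what lets it apply to iterative solvers whose iterations would be expensive to unroll.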

Disclaimer

JAXopt is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.