
JAX port of lenstronomy, for parallelized, GPU accelerated, and differentiable gravitational lensing and image simulations.

Primary language: Python. License: BSD 3-Clause "New" or "Revised" License (BSD-3-Clause).

JAXtronomy



Disclaimer: This project is still in an early development phase and serves as a skeleton for someone taking the lead on it :)

The goal of this library is to reimplement lenstronomy functionality in pure JAX, enabling automatic differentiation, GPU acceleration, and batched computations.
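As a rough illustration of what these three capabilities look like in plain JAX, the sketch below uses a hypothetical, standalone singular isothermal sphere (SIS) potential. `sis_potential` is an illustrative helper, not JAXtronomy's API. `jax.grad` turns the lensing potential into the deflection angle, `jax.jit` compiles the function for whichever backend is available, and `jax.vmap` evaluates it over a batch of image-plane positions.

```python
import jax
import jax.numpy as jnp


def sis_potential(x, y, theta_E, center_x=0.0, center_y=0.0):
    """Lensing potential of a singular isothermal sphere (hypothetical helper).

    psi(x, y) = theta_E * r, where r is the distance from the lens center.
    """
    r = jnp.sqrt((x - center_x) ** 2 + (y - center_y) ** 2)
    return theta_E * r


# Automatic differentiation: the deflection angle is the gradient of the potential,
# so (d psi / dx, d psi / dy) = theta_E * (x, y) / r for a centered SIS.
deflection = jax.grad(sis_potential, argnums=(0, 1))

# Just-in-time compilation; XLA targets the available backend (CPU, GPU, or TPU).
deflection_jit = jax.jit(deflection)

# Batching: map over arrays of image-plane coordinates while keeping theta_E fixed.
deflection_batched = jax.vmap(deflection_jit, in_axes=(0, 0, None))

x = jnp.array([1.0, 0.0, -2.0])
y = jnp.array([0.0, 1.0, 0.0])
alpha_x, alpha_y = deflection_batched(x, y, 1.5)
# Analytically: alpha_x = [1.5, 0, -1.5], alpha_y = [0, 1.5, 0]
print(alpha_x, alpha_y)
```

Note that the gradient is undefined at the lens center (r = 0), so the sample points avoid it.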

Guiding Principles:

  • Strive to be a drop-in replacement for lenstronomy, i.e. provide a close match to the lenstronomy API.
  • Each function/feature will be tested against the reference lenstronomy implementation.
  • This package will aim to be a subset of lenstronomy (i.e. it will only contain functions that have a reference lenstronomy implementation).
  • Implementations should be easy to read and understand.
  • Code should be pip-installable on any machine, with no compilation required.
  • Any notable differences between the JAX and reference implementations will be clearly documented.
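A minimal sketch of the "tested against the reference implementation" principle follows. It assumes a NumPy function as a stand-in for the reference; in the actual test suite the reference would be lenstronomy itself. `sis_deflection_numpy` and `sis_deflection_jax` are hypothetical names used only for illustration. The JAX version shares the reference's signature, in the spirit of the drop-in principle, and is compared against it on random positions within a tight tolerance.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Use 64-bit floats so the JAX results can be compared at double precision.
jax.config.update("jax_enable_x64", True)


def sis_deflection_numpy(x, y, theta_E, center_x=0.0, center_y=0.0):
    """Reference implementation (stand-in for lenstronomy), written in NumPy."""
    dx, dy = x - center_x, y - center_y
    r = np.sqrt(dx**2 + dy**2)
    return theta_E * dx / r, theta_E * dy / r


@jax.jit
def sis_deflection_jax(x, y, theta_E, center_x=0.0, center_y=0.0):
    """JAX port with the same signature, so it can act as a drop-in replacement."""
    dx, dy = x - center_x, y - center_y
    r = jnp.sqrt(dx**2 + dy**2)
    return theta_E * dx / r, theta_E * dy / r


def test_sis_deflection_matches_reference():
    # Random image-plane positions; the same keyword arguments go to both versions.
    rng = np.random.default_rng(seed=42)
    x = rng.uniform(-3.0, 3.0, size=1000)
    y = rng.uniform(-3.0, 3.0, size=1000)
    kwargs = {"theta_E": 1.2, "center_x": 0.1, "center_y": -0.3}

    ref_x, ref_y = sis_deflection_numpy(x, y, **kwargs)
    jax_x, jax_y = sis_deflection_jax(x, y, **kwargs)

    np.testing.assert_allclose(jax_x, ref_x, rtol=1e-10, atol=1e-12)
    np.testing.assert_allclose(jax_y, ref_y, rtol=1e-10, atol=1e-12)


test_sis_deflection_matches_reference()
```

Keeping the two signatures identical is what lets a single keyword-argument dictionary drive both implementations, which is the same pattern a lenstronomy-compatible API makes possible.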

Related software packages

The following lensing software packages use JAX-accelerated computing and were, in part, inspired by lenstronomy or make use of its functions: