We developed Finax, a quantitative finance library based on JAX. It provides high-performance financial computation by leveraging JAX's hardware acceleration, parallel scientific computing, and automatic differentiation.
Finax is still a research project under active development, so some APIs may change in the future. We welcome any suggestions and contributions.
With Finax, you can:
- run financial workloads on CPU/GPU/TPU with XLA acceleration
- compute derivatives of financial models, i.e. the Greeks, via automatic differentiation
- distribute workloads across multiple devices
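These capabilities map onto JAX's function transformations. The sketch below illustrates them with plain JAX rather than Finax itself: `black_call_price` is a hypothetical helper written here (an undiscounted Black-Scholes call on a forward), `jax.grad` yields delta and vega, `jax.vmap` and `jax.jit` batch and compile the computation, and `jax.pmap` spreads shards across whatever local devices are present (a CPU-only machine typically reports a single device).

```python
# Hypothetical sketch: Greeks by automatic differentiation, batching with
# vmap and multi-device parallelism with pmap. `black_call_price` is an
# illustrative stand-in written here, NOT part of the Finax API.
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

jax.config.update("jax_enable_x64", True)


def black_call_price(forward, strike, volatility, expiry):
    """Undiscounted Black-Scholes call price written on a forward."""
    total_vol = volatility * jnp.sqrt(expiry)
    d1 = (jnp.log(forward / strike) + 0.5 * total_vol**2) / total_vol
    d2 = d1 - total_vol
    return forward * norm.cdf(d1) - strike * norm.cdf(d2)


# Greeks are partial derivatives of the price: delta w.r.t. the forward
# (argument 0) and vega w.r.t. the volatility (argument 2).
delta_fn = jax.grad(black_call_price, argnums=0)
vega_fn = jax.grad(black_call_price, argnums=2)

# vmap turns the scalar delta into a batched one; jit compiles it with XLA.
batched_delta = jax.jit(jax.vmap(delta_fn, in_axes=(0, 0, 0, None)))

forwards = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
strikes = jnp.full(5, 3.0)
volatilities = jnp.array([0.3, 0.2, 0.2, 0.1, 0.4])
deltas = batched_delta(forwards, strikes, volatilities, 1.0)
print(deltas)

# pmap maps a leading axis over the local devices (one shard per device).
n_dev = jax.local_device_count()
per_device = 4
shard_forwards = jnp.linspace(2.0, 4.0, n_dev * per_device).reshape(
    n_dev, per_device
)


def price_shard(f):
    # Price one device's shard with strike 3.0, volatility 0.25, expiry 1.0.
    return jax.vmap(black_call_price, in_axes=(0, None, None, None))(
        f, 3.0, 0.25, 1.0
    )


sharded_prices = jax.pmap(price_shard)(shard_forwards)
print(sharded_prices.shape)  # (number of local devices, 4)
```

`pmap` is used here for brevity; recent JAX versions also offer sharding-based alternatives, and how Finax itself distributes work may differ.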
The `examples` directory contains several demonstrations of using the Finax library.
First, follow JAX's installation guide to install JAX for your device architecture (CPU/GPU/TPU). Then install Finax:
```bash
pip install finax --upgrade
```
Here is an example of option pricing with the Black-Scholes model.
```python
import jax.numpy as jnp
from jax import config, jit

# Enable 64-bit precision before importing finax.
config.update("jax_enable_x64", True)

from finax.black_sholes.vanilla_prices import option_price

# JIT-compile the pricing function with XLA.
option_price_fn = jit(option_price)

dtype = jnp.float64
forwards = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=dtype)
strikes = jnp.array([3.0, 3.0, 3.0, 3.0, 3.0], dtype=dtype)
volatilities = jnp.array([0.0001, 102.0, 2.0, 0.1, 0.4], dtype=dtype)
# Scalar expiry for the whole batch.
expiries = jnp.array(1.0, dtype=dtype)

computed_prices = option_price_fn(
    volatilities=volatilities,
    strikes=strikes,
    expiries=expiries,
    forwards=forwards)
print(computed_prices)
```
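For reference, the standard Black-Scholes price written in terms of the forward $F$ (the undiscounted Black-76 form) is given below, with strike $K$, volatility $\sigma$, expiry $T$, and $N$ the standard normal CDF. That Finax's `option_price` follows exactly this convention (e.g. no discounting) is an assumption here; its inputs `forwards`, `strikes`, `volatilities`, and `expiries` correspond to these symbols.

```latex
C = F\,N(d_1) - K\,N(d_2), \qquad P = K\,N(-d_2) - F\,N(-d_1),
\\
d_1 = \frac{\ln(F/K) + \tfrac{1}{2}\sigma^{2}T}{\sigma\sqrt{T}}, \qquad
d_2 = d_1 - \sigma\sqrt{T}
```

As $\sigma \to 0$ the call price tends to $\max(F-K, 0)$, and as $\sigma \to \infty$ it tends to $F$. These limits are useful sanity checks for extreme inputs such as the volatilities 0.0001 and 102.0 in the example.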
To enable 64-bit precision, set the `jax_enable_x64` flag before importing finax
(see the JAX guide), for example:
```python
from jax import config
config.update("jax_enable_x64", True)
```
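A quick way to confirm that the flag took effect is to inspect the dtype of a freshly created array. This minimal sketch uses only JAX; without the flag, JAX defaults to 32-bit and truncates explicit `float64` requests to `float32` (emitting a warning).

```python
import jax

# Set this at startup, before creating arrays (and before importing finax).
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
print(x.dtype)  # float64
```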