
JAX implementation of softclip

Primary language: Python. License: MIT.


A simple JAX implementation of softclip, inspired by TensorFlow Probability.

softclip is a differentiable bijector from the real numbers to an interval. This is useful when you want to optimize a parameter that must stay inside the interval [low, high]: you optimize an unconstrained real value and map it through the bijector.


softclip can be installed from PyPI with pip, using the following command:

python -m pip install softclip


The forward method maps the real numbers to the interval [low, high]. The inverse method maps the interval [low, high] back to the real numbers and is the inverse of forward.

from softclip import SoftClip

bij = SoftClip(low=1.0, high=3.0, hinge_softness=0.5)
y = bij.forward(2.0) # y = 2.9640274
bij.inverse(y) # 1.9999975 ≈ 2.0
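
As a rough illustration of the optimization use case, the sketch below runs plain gradient descent on an unconstrained value and maps it into [low, high] on every step. To stay self-contained it does not call softclip itself: `to_interval` and `from_interval` are hypothetical stand-ins built from a scaled sigmoid, and they may not match what `SoftClip.forward` and `SoftClip.inverse` compute internally.

```python
import jax
import jax.numpy as jnp

LOW, HIGH = 1.0, 3.0  # interval the parameter must stay in


def to_interval(x):
    """Hypothetical stand-in for a forward map: real line -> (LOW, HIGH)."""
    return LOW + (HIGH - LOW) * jax.nn.sigmoid(x)


def from_interval(y):
    """Inverse of to_interval: (LOW, HIGH) -> real line (a scaled logit)."""
    p = (y - LOW) / (HIGH - LOW)
    return jnp.log(p) - jnp.log1p(-p)


def loss(x):
    # The unconstrained optimum is at 5.0, outside the interval,
    # so the constrained optimum sits at the upper bound HIGH.
    theta = to_interval(x)
    return (theta - 5.0) ** 2


grad_fn = jax.jit(jax.grad(loss))

x = from_interval(2.0)  # start the search from theta = 2.0
for _ in range(500):
    x = x - 0.1 * grad_fn(x)  # gradient step on the unconstrained value

theta = to_interval(x)
print(theta)  # approaches HIGH while staying inside the interval
```

The optimizer only ever sees the unconstrained `x`, so no projection or clipping step is needed. With the real library, `bij.forward` and `bij.inverse` would presumably take the place of the two stand-in functions.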

Set only low=0.0 to create a bijector onto the positive numbers, or only high=0.0 to create a bijector onto the negative numbers.

bij_positive = SoftClip(low=0.0)
bij_negative = SoftClip(high=0.0)
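
For intuition about the one-sided case, a softplus is a common way to map the real line onto the positive numbers, and its mirror image covers the negative numbers. The functions `to_positive`, `to_negative`, and `positive_inverse` below are hypothetical helpers for illustration only. They are not claimed to be what softclip does internally.

```python
import jax
import jax.numpy as jnp


def to_positive(x):
    """Illustrative map from the real line to (0, inf) via softplus."""
    return jax.nn.softplus(x)


def to_negative(x):
    """Mirror image: maps the real line to (-inf, 0)."""
    return -jax.nn.softplus(-x)


def positive_inverse(y):
    """Inverse of softplus for y > 0: log(exp(y) - 1)."""
    return jnp.log(jnp.expm1(y))


xs = jnp.array([-3.0, 0.0, 3.0])
pos = to_positive(xs)              # every entry is > 0
neg = to_negative(xs)              # every entry is < 0
roundtrip = positive_inverse(pos)  # close to xs, up to float error
```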

Call to_distrax to convert a SoftClip instance into a distrax bijector:

bij_distrax = bij.to_distrax()