Qazalbash/jaxampler

Incorrect implementation of `rvs.Triangular.logcdf_x`

mahausmani opened this issue · 1 comment

Description

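The `logcdf_x` method of the triangular distribution computes the CDF incorrectly when `x == self._mode`. It hard-codes `jnp.log(0.5)`, which is correct only when the mode lies exactly midway between `self._low` and `self._high`: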

@partial(jit, static_argnums=(0,))
def logcdf_x(self, x: Numeric) -> Numeric:
    conditions = [
        x < self._low,
        (self._low <= x) & (x < self._mode),
        x == self._mode,
        (self._mode < x) & (x < self._high),
        x >= self._high,
    ]
    choices = [
        -jnp.inf,
        2 * jnp.log(x - self._low) - jnp.log(self._high - self._low) - jnp.log(self._mode - self._low),
        jnp.log(0.5),
        jnp.log(1 - ((self._high - x) ** 2 / ((self._high - self._low) * (self._high - self._mode)))),
        jnp.log(1),
    ]
    return jnp.select(conditions, choices)

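At `x == self._mode` it should instead return the same expression as the `self._low <= x < self._mode` branch, which at the mode reduces to `jnp.log(self._mode - self._low) - jnp.log(self._high - self._low)`: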

2 * jnp.log(x - self._low) - jnp.log(self._high - self._low) - jnp.log(self._mode - self._low),
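
For illustration, here is a minimal standalone sketch of the proposed behaviour. It is not the jaxampler implementation and not necessarily what the eventual fix looks like. The function name `triangular_logcdf` and its plain `low`, `mode`, `high` arguments are hypothetical, and it assumes `low < mode < high`. The `x == mode` case is folded into the rising branch rather than handled separately.

```python
import jax.numpy as jnp


def triangular_logcdf(x, low, mode, high):
    """Log-CDF of a triangular distribution on [low, high] peaking at mode.

    Hypothetical standalone helper; assumes low < mode < high.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    # Rising branch, low <= x <= mode:
    #   F(x) = (x - low)^2 / ((high - low) * (mode - low))
    rising = 2 * jnp.log(x - low) - jnp.log(high - low) - jnp.log(mode - low)
    # Falling branch, mode < x < high:
    #   F(x) = 1 - (high - x)^2 / ((high - low) * (high - mode))
    falling = jnp.log1p(-((high - x) ** 2) / ((high - low) * (high - mode)))
    # jnp.select returns the choice for the first condition that is True,
    # so each later condition only applies where the earlier ones failed.
    conditions = [
        x < low,    # below the support
        x <= mode,  # low <= x <= mode (includes the mode itself)
        x < high,   # mode < x < high
        x >= high,  # at or above the upper bound
    ]
    choices = [
        jnp.full_like(x, -jnp.inf),  # log(0)
        rising,
        falling,
        jnp.zeros_like(x),           # log(1)
    ]
    return jnp.select(conditions, choices)


# With low=0, mode=1, high=4 the CDF at the mode is (1 - 0) / (4 - 0) = 0.25,
# so the log-CDF there is log(0.25), not log(0.5).
at_mode = triangular_logcdf(1.0, 0.0, 1.0, 4.0)
```

The unselected branches may evaluate to `nan` outside their own range (for example `jnp.log` of a negative number), but `jnp.select` discards those values in the forward pass.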

What jaxampler version are you using?

nightly

Which accelerator(s) are you using?

CPU

Additional system info?

Linux

NVIDIA GPU info

No response

@mahausmani Thanks for the report. It should be fixed by #61.