FloatingPointError in jax.scipy.stats
Description
I got a FloatingPointError when using jax.scipy.stats.gamma.pdf. I've tried jax.config.update("jax_enable_x64", True), to no avail.
code for reproduction
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

def mle_loss(data, target):
    '''
    data: (batch_size, history_length, bar_length)
    target: (batch_size,)
    '''
    weights = data[:, :, 0]
    k = data[:, :, 1]
    theta = data[:, :, 2]
    # gamma.pdf's signature is pdf(x, a, loc=0, scale=1), so theta is passed here as loc
    probs = jax.vmap(jax.scipy.stats.gamma.pdf)(target, k, theta)
    prob = jnp.sum(weights * probs, axis=1)
    loss = -jax.lax.log(prob).mean()
    loss = jax.lax.clamp(-1e6, loss, 1e6)
    return loss

data = jnp.array([[[1.        , 0.7096346 , 0.7514472 ]],
                  [[1.        , 0.7194072 , 0.735364  ]],
                  [[1.        , 0.7475644 , 0.7523259 ]],
                  [[1.        , 0.7042354 , 0.7264852 ]],
                  [[1.        , 0.47818542, 1.3681346 ]],
                  [[1.        , 0.6943242 , 0.7313199 ]],
                  [[1.        , 0.687601  , 0.81750506]],
                  [[1.        , 0.72067565, 0.7784166 ]]])
target = jnp.array([1.0, 1.0004145, 0.99964607, 1.0004367, 1.000424, 0.99967706,
                    1.0009085, 1.000621])

mle_loss(data, target)
the error
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
[... skipping hidden 1 frame]
File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/profiler.py:333, in annotate_function.<locals>.wrapper(*args, **kwargs)
332 with TraceAnnotation(name, **decorator_kwargs):
--> 333 return func(*args, **kwargs)
334 return wrapper
File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1292, in ExecuteReplicated.__call__(self, *args)
1291 for arrays in out_arrays:
-> 1292 dispatch.check_special(self.name, arrays)
1293 out = self.out_handler(out_arrays)
File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/dispatch.py:327, in check_special(name, bufs)
326 for buf in bufs:
--> 327 _check_special(name, buf.dtype, buf)
File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/dispatch.py:332, in _check_special(name, dtype, buf)
331 if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
--> 332 raise FloatingPointError(f"invalid value (nan) encountered in {name}")
333 if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):
FloatingPointError: invalid value (nan) encountered in jit(log)
During handling of the above exception, another exception occurred:
FloatingPointError Traceback (most recent call last)
Cell In[1], line 41
23 data = jnp.array([[[1. , 0.7096346 , 0.7514472 ]],
24
25 [[1. , 0.7194072 , 0.735364 ]],
(...)
36
37 [[1. , 0.72067565, 0.7784166 ]]])
38 target = jnp.array([1.0, 1.0004145, 0.99964607, 1.0004367, 1.000424, 0.99967706,
39 1.0009085 ,1.000621])
---> 41 mle_loss(data, target)
Cell In[1], line 15, in mle_loss(data, target)
12 k = data[:, :, 1]
13 theta = data[:, :, 2]
---> 15 probs = jax.vmap(jax.scipy.stats.gamma.pdf)(target, k, theta)
16 prob = jnp.sum(weights * probs, axis=1)
18 loss = -jax.lax.log(prob).mean()
[... skipping hidden 3 frame]
File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/scipy/stats/gamma.py:92, in pdf(x, a, loc, scale)
62 def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
63 r"""Gamma probability distribution function.
64
65 JAX implementation of :obj:`scipy.stats.gamma` ``pdf``.
(...)
90 - :func:`jax.scipy.stats.gamma.logsf`
91 """
---> 92 return lax.exp(logpdf(x, a, loc, scale))
File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/scipy/stats/gamma.py:56, in logpdf(x, a, loc, scale)
54 one = _lax_const(x, 1)
55 y = lax.div(lax.sub(x, loc), scale)
---> 56 log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y)
57 shape_terms = lax.add(gammaln(a), lax.log(scale))
58 log_probs = lax.sub(log_linear_term, shape_terms)
[... skipping hidden 7 frame]
File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/scipy/special.py:498, in xlogy(x, y)
496 x, y = promote_args_inexact("xlogy", x, y)
497 x_ok = x != 0.
--> 498 return jnp.where(x_ok, lax.mul(x, lax.log(y)), jnp.zeros_like(x))
[... skipping hidden 17 frame]
File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/pjit.py:1692, in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, *args)
1675 # If control reaches this line, we got a NaN on the output of `compiled`
1676 # but not `fun.call_wrapped` on the same arguments. Let's tell the user.
1677 msg = (f"{str(e)}. Because "
1678 "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
1679 "de-optimized function (i.e., the function as if the `jit` "
(...)
1690 "If you see this error, consider opening a bug report at "
1691 "https://github.com/jax-ml/jax.")
-> 1692 raise FloatingPointError(msg)
FloatingPointError: invalid value (nan) encountered in jit(log). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`.
It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations.
If you see this error, consider opening a bug report at https://github.com/jax-ml/jax.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.35
jaxlib: 0.4.35
numpy: 2.1.3
python: 3.12.1 (main, Sep 30 2024, 17:05:21) [GCC 9.4.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='codespaces-5a7c09', release='6.5.0-1025-azure', version='#26~22.04.1-Ubuntu SMP Thu Jul 11 22:33:04 UTC 2024', machine='x86_64')
Thanks for the report! The issue here occurs when loc > x in the gamma logpdf. This case is checked at jax/jax/_src/scipy/stats/gamma.py, line 59 in afdc792, so logpdf returns -inf there as expected. However, the NaN appears earlier, at jax/jax/_src/scipy/stats/gamma.py, line 56 in afdc792: y = (x - loc) / scale is negative, and xlogy takes its log before the -inf mask is applied.
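A minimal sketch of the mechanism described above, using the one batch entry from the report where the value passed as loc (1.3681346) exceeds x (1.000424). It recomputes the linear term by hand with jax.scipy.special.xlogy, mirroring line 56 as it appears in the traceback, and compares it with gamma.logpdf while jax_debug_nans is left at its default of False. Because the log of a negative real number is NaN at any precision, enabling jax_enable_x64 cannot avoid it.

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy
from jax.scipy.stats import gamma

# Entry 4 of the report: theta (passed positionally, i.e. as loc) exceeds x.
x, a, loc, scale = 1.000424, 0.47818542, 1.3681346, 1.0

# Hand-written version of the linear term from the traceback (gamma.py line 56).
y = (x - loc) / scale                     # negative, because loc > x
log_linear_term = xlogy(a - 1.0, y) - y   # (a - 1) * log(negative number) -> NaN
print(jnp.isnan(log_linear_term))         # True

# logpdf then masks x < loc with -inf, so the NaN never reaches its output.
print(gamma.logpdf(x, a, loc, scale))     # -inf
```

With jax_debug_nans enabled, the intermediate NaN is reported even though the final value is well defined.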
It's safe to set jax_debug_nans to False in this case (since it's checked later), or make the following workaround change to your code:
- theta = data[:, :, 2]
+ theta = jax.lax.clamp(-jnp.inf, data[:, :, 2], target[:, None])
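A minimal sketch of the clamp workaround applied to two entries from the report (one valid, one with loc > x), run with jax_debug_nans enabled. It uses jnp.minimum(theta, target[:, None]) in place of the lax.clamp call from the diff; clamping with a lower bound of -inf is the same as taking the elementwise minimum. Note that the clamped entry then sits exactly at loc, where a gamma density with a < 1 diverges, so that probability may come out as +inf rather than 0.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import gamma

# Entries 0 and 4 of the report; theta is passed positionally (as loc), as in mle_loss.
target = jnp.array([1.0, 1.000424])
k = jnp.array([[0.7096346], [0.47818542]])
theta = jnp.array([[0.7514472], [1.3681346]])

# Keep loc <= x. lax.clamp(-inf, theta, upper) is equivalent to minimum(theta, upper).
theta_clamped = jnp.minimum(theta, target[:, None])

jax.config.update("jax_debug_nans", True)
try:
    # No FloatingPointError: y = (x - loc) / scale is never negative now.
    probs = jax.vmap(gamma.pdf)(target, k, theta_clamped)
finally:
    jax.config.update("jax_debug_nans", False)  # restore the default

print(probs)
```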
But I think we should definitely fix these leaking NaNs in JAX itself! If you're keen to submit a PR, I'd be happy to help/point you in the right direction, otherwise I can probably fix it myself soon.
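One common way to keep a masked NaN from leaking is the "double where" pattern (the same idea the JAX FAQ recommends for NaN gradients through jnp.where): substitute a harmless value for out-of-support inputs before taking the log, then mask the result. The sketch below applies it to a hypothetical safe_gamma_logpdf helper; it illustrates the pattern and is not JAX's own implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy
from jax.scipy.stats import gamma


@jax.jit
def safe_gamma_logpdf(x, a, loc=0.0, scale=1.0):
    """Hypothetical NaN-free gamma log-density (illustrative, not JAX's implementation)."""
    in_support = x >= loc
    # First where: replace out-of-support points with 1.0 *before* the log,
    # so xlogy never sees a negative argument (log(1.0) == 0).
    y = jnp.where(in_support, (x - loc) / scale, 1.0)
    log_probs = xlogy(a - 1.0, y) - y - gammaln(a) - jnp.log(scale)
    # Second where: out-of-support points get a log-density of -inf.
    return jnp.where(in_support, log_probs, -jnp.inf)


# Entries 0, 4 and 5 of the report (x = target, loc = theta).
x = jnp.array([1.0, 1.000424, 0.99967706])
a = jnp.array([0.7096346, 0.47818542, 0.6943242])
loc = jnp.array([0.7514472, 1.3681346, 0.7313199])

jax.config.update("jax_debug_nans", True)
try:
    out = safe_gamma_logpdf(x, a, loc)  # no FloatingPointError
finally:
    jax.config.update("jax_debug_nans", False)

print(out)
```

For in-support points the result agrees with gamma.logpdf; only the out-of-support path changes, and it never computes a log of a negative number.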
Looking again at my code and the documentation, I noticed that I misunderstood and misused jax.scipy.stats.gamma.pdf, so I don't think I'd be of much help.
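For readers who hit the same error: the signature shown in the traceback is pdf(x, a, loc=0, scale=1), so in the original call the third positional argument, theta, is interpreted as loc. Assuming theta was meant to be the gamma scale parameter (an assumption; the intended model is not stated in the thread), a minimal sketch passes it by keyword and broadcasts target against the (batch, history) parameters instead of using vmap. The name mle_loss_scale is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import gamma


def mle_loss_scale(data, target):
    """Hypothetical variant of mle_loss that treats data[..., 2] as the gamma scale.

    data: (batch_size, history_length, 3) holding (weight, k, theta)
    target: (batch_size,)
    """
    weights, k, theta = data[..., 0], data[..., 1], data[..., 2]
    # target[:, None] broadcasts against the (batch, history) parameters, replacing vmap.
    probs = gamma.pdf(target[:, None], k, scale=theta)
    prob = jnp.sum(weights * probs, axis=1)
    return -jnp.log(prob).mean()


# Entries 0 and 4 of the report (entry 4 is the one that produced the NaN).
data = jnp.array([[[1.0, 0.7096346, 0.7514472]],
                  [[1.0, 0.47818542, 1.3681346]]])
target = jnp.array([1.0, 1.000424])

jax.config.update("jax_debug_nans", True)
try:
    loss = mle_loss_scale(data, target)  # loc defaults to 0, so x - loc stays positive
finally:
    jax.config.update("jax_debug_nans", False)

print(loss)
```

Broadcasting gives the same per-entry densities as the vmap over the batch axis, since gamma.pdf is elementwise.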