jax-ml/jax

np.isnan doesn't work on CPU in fast-math mode

noahstier opened this issue · 14 comments

As a result, np.nan_to_num, np.nanmean, etc. don't work either.

import jax.numpy as np
a = np.zeros(1) / np.zeros(1)  # 0/0 produces NaN
print(a.__array__())
print(np.isnan(a).__array__())

[nan]
[False]

This bug only happens with the CPU-only build of JAX, when I see this warning: "warnings.warn('No GPU found, falling back to CPU.')"
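The knock-on effect on the NaN-aware functions can be sketched in plain Python. This assumes, for illustration, that a nanmean-style function relies on an isnan predicate to drop NaNs; `nanmean` and `broken_isnan` are hypothetical helpers, not JAX's implementation, and `broken_isnan` stands in for an isnan that always answers False.

```python
import math
from typing import Callable, Iterable


def nanmean(xs: Iterable[float], isnan: Callable[[float], bool]) -> float:
    """Mean of the entries that `isnan` does not flag (hypothetical helper)."""
    kept = [x for x in xs if not isnan(x)]
    return sum(kept) / len(kept) if kept else float("nan")


def broken_isnan(x: float) -> bool:
    # Hypothetical stand-in for an isnan whose check has been optimized away.
    return False


data = [1.0, float("nan"), 3.0]
print(nanmean(data, math.isnan))    # 2.0
print(nanmean(data, broken_isnan))  # nan: the NaN is no longer filtered out
```

Once the predicate is wrong, every function built on it inherits the bug, which matches the nan_to_num/nanmean symptoms above.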

This happens because XLA's CPU backend enables fast math mode by default, which does not preserve NaN/Inf semantics. The GPU backend does not enable it by default. Note the comment here:
https://github.com/google/jax/blob/master/jax/numpy/lax_numpy.py#L699

# Caution: If fast math mode is enabled, the semantics of inf and nan are not
# preserved by XLA/LLVM, and the behavior of inf/nan values is unpredictable.
# To disable fast math mode on CPU, set the environment variable
# XLA_FLAGS=--xla_cpu_enable_fast_math=false
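One way to see how fast math can break isnan: a common implementation is the self-comparison x != x, because under IEEE 754 NaN is the only value that compares unequal to itself. If the optimizer is allowed to assume there are no NaNs (LLVM's nnan fast-math flag), it may fold that comparison to a constant False. The plain-Python sketch below mimics both versions; whether XLA's CPU backend takes exactly this path is an assumption, and `isnan_folded` is a hypothetical stand-in for the optimized code.

```python
import math


def isnan_self_compare(x: float) -> bool:
    # IEEE 754: NaN is the only value that is not equal to itself.
    return x != x


def isnan_folded(x: float) -> bool:
    # Hypothetical: what `x != x` may become once the optimizer assumes
    # NaNs cannot occur. The comparison is folded to a constant.
    return False


nan = 0.0 * math.inf  # 0 * inf is one operation that produces NaN
print(isnan_self_compare(nan), math.isnan(nan))  # True True
print(isnan_folded(nan))                         # False, as in the report
```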

The XLA_FLAGS environment variable above makes your example pass.
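If other XLA flags are already in use, overwriting XLA_FLAGS outright would drop them. A small helper that appends the flag instead might look like the sketch below; `disable_xla_cpu_fast_math` is a hypothetical name, and it assumes (as suggested later in this thread) that the variable has to be set before JAX is imported.

```python
import os

FAST_MATH_OFF = "--xla_cpu_enable_fast_math=false"


def disable_xla_cpu_fast_math() -> str:
    """Add FAST_MATH_OFF to XLA_FLAGS without discarding other flags."""
    flags = os.environ.get("XLA_FLAGS", "").split()
    # Drop any earlier setting of the same flag so the new value is unambiguous.
    flags = [f for f in flags if not f.startswith("--xla_cpu_enable_fast_math")]
    flags.append(FAST_MATH_OFF)
    os.environ["XLA_FLAGS"] = " ".join(flags)
    return os.environ["XLA_FLAGS"]


# Usage (assumed ordering): call the helper first, then import JAX.
# disable_xla_cpu_fast_math()
# import jax.numpy as np
```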

I guess the important question is: should we disable fast math mode by default? Are exact NaN semantics important to you?

I think consistency between CPU and GPU is more important than performance in this case. There can still be a performance tips section that explains how to activate fast math.

I just got surprised by this, too. Maybe another option is to print a warning at startup, in addition to the "No GPU found" warning?

A brief update on this bug: we tried disabling fast math in XLA/CPU by default, but found that it significantly regressed performance on some neural network benchmarks, because it prevents vectorization in some important cases.
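Why fast math matters for vectorization: a SIMD reduction typically keeps several partial sums (one per lane) and combines them at the end, which adds the elements in a different order than a sequential loop. Floating-point addition is not associative, so a compiler may only make this change when reassociation is permitted. The plain-Python model below (a simplification, not XLA's actual code generation; `lane_sum` is a hypothetical helper) shows that the two orders can give different answers.

```python
from typing import Sequence


def sequential_sum(xs: Sequence[float]) -> float:
    total = 0.0
    for x in xs:
        total += x
    return total


def lane_sum(xs: Sequence[float], lanes: int = 4) -> float:
    # Model of a vectorized reduction: lane i accumulates xs[i], xs[i + lanes], ...
    partial = [0.0] * lanes
    for i, x in enumerate(xs):
        partial[i % lanes] += x
    # Combine the per-lane partial sums at the end.
    return sequential_sum(partial)


xs = [1e16, 1.0, -1e16, 1.0]
print(sequential_sum(xs))     # 1.0: the first 1.0 is rounded away next to 1e16
print(lane_sum(xs, lanes=2))  # 2.0: the large terms cancel within one lane
```

Neither order is "wrong" in general, but they differ, which is exactly the freedom a vectorizer needs and strict IEEE semantics forbid.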

https://reviews.llvm.org/D57728 apparently fixes the performance problem, but it hasn't landed yet. I'm hoping we can simply disable fast math by default once that change makes it into LLVM.

A warning makes sense until we do so, I guess.

I also got surprised by this (I am using a CPU). Here is a simple example:

import numpy as onp # original numpy
import jax.numpy as np
print(np.isnan(np.nan)) #F
print(onp.isnan(np.nan)) #T
print(np.isnan(onp.nan)) #F
print(onp.isnan(onp.nan)) #T

Maybe it's worth mentioning the issue on the JAX homepage (the comment is currently buried deep in the gotchas Colab).

I also tried to set the environment flag, but to no avail (is my syntax correct?):

import os
os.environ["XLA_FLAGS"]="--xla_cpu_enable_fast_math=false"

print(np.isnan(np.nan)) #F
print(onp.isnan(np.nan)) #T
print(np.isnan(onp.nan)) #F
print(onp.isnan(onp.nan)) #T

Did that os.environ come before importing anything from jax? That might be necessary.

Great idea re: mentioning it in the readme. I'll add it now.

Yes, I did the os.environ step first. I am running inside the Spyder IDE.
Full script:

import os
os.environ["XLA_FLAGS"]="--xla_cpu_enable_fast_math=false"

import numpy as onp # original numpy
import jax.numpy as np

print(np.isnan(np.nan)) #F
print(onp.isnan(np.nan)) #T
print(np.isnan(onp.nan)) #F
print(onp.isnan(onp.nan)) #T

Thanks. Hmm, I was unable to reproduce this in my local environment (which I tried before my earlier guess about os.environ needing to come first):


In [1]: import os

In [2]: os.environ["XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false"

In [3]: import jax.numpy as np

In [4]: print(np.isnan(np.nan))
True

Not sure how to chase that down further. In any case, we'll fix CPU nan issues ASAP.

We just pushed out jaxlib 0.1.13, which should fix this problem.

Parts of fast math are still enabled by default for performance, but the semantics of NaNs and Infs should now be honored. Please file a new issue if you see any further problems!

This issue has been closed for a while now, so I don't know if anyone will see this. I just tripped over this while trying to implement high-accuracy sum and dot products (see here), and found that they don't work unless the fast-math flag referenced above is turned off. Looking at the LLVM fast-math options here, it's the reassoc flag that's the problem, since it allows value-unsafe reassociation of floating-point expressions.

Is there any possibility of a finer-grained approach to setting these flags? One might hope that the contract flag (possibly with a few others) would be enough to keep most of the performance benefits. Allowing associativity transformations makes CPU and GPU execution differ in a way that is surprising to users, and makes it impossible to implement many backward-stable numerical algorithms on the CPU.
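The reassoc problem shows up even in the simplest compensated algorithm, Kahan summation (used here as a stand-in; it is not necessarily the algorithm linked above). Its correction term (t - total) - y is algebraically zero because t = total + y, so an optimizer allowed to reassociate may simplify it away, silently turning the algorithm back into naive summation. The `fold_compensation` switch below is a hypothetical way to mimic that rewrite in plain Python.

```python
from typing import Iterable


def naive_sum(xs: Iterable[float]) -> float:
    total = 0.0
    for x in xs:
        total += x
    return total


def kahan_sum(xs: Iterable[float], fold_compensation: bool = False) -> float:
    total = 0.0
    c = 0.0  # running compensation: low-order bits lost so far
    for x in xs:
        y = x - c
        t = total + y
        if fold_compensation:
            # What a reassociating optimizer may produce: (t - total) - y -> 0.
            c = 0.0
        else:
            # In floating point this recovers the rounding error of total + y.
            c = (t - total) - y
        total = t
    return total


xs = [1.0] + [1e-16] * 10  # exact sum is about 1 + 1e-15
print(naive_sum(xs))                          # 1.0: each 1e-16 is rounded away
print(kahan_sum(xs))                          # close to 1 + 1e-15
print(kahan_sum(xs, fold_compensation=True))  # 1.0 again
```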

Just a suggestion. Thanks for the work; JAX has been an amazing platform to develop computational tools in!

Thanks for the input @btalami. What you're doing sounds similar to the precision doubling experiment I tried a while ago (#3465). I gave up on it primarily due to the difficulties you describe.