Wrong output from JAX test after oneDNN change
Closed this issue · 4 comments
hawkinsp commented
One of our JAX users reported a miscompilation in the following program when run on CPU:
```python
import jax
import jax.numpy as jnp

w = jnp.array([1., 2., 3.], dtype=jnp.float32)
# w = jnp.array([10., 10., 10.], dtype=jnp.float32)  # this works
# w = jnp.array([1e-3, 1e-2, 1e-1], dtype=jnp.float32)  # this does not work
dim = len(w)

x = jnp.ones((dim,))
ys = jnp.ones((10**6, dim))
# ys = jnp.ones((10**5, dim))  # this works

def fun(x, y):
    diff = y - x
    out = jnp.dot(diff, w)
    # The following line should have no effect on the output or the derivatives
    # of this function. Removing it solves the bug.
    out = out + jnp.dot(jnp.dot(diff, jnp.zeros((dim, dim))), diff)
    return out

def mean_fun(x):
    outs = jax.vmap(lambda yy: fun(x, yy))(ys)
    return jnp.mean(outs)

jitted_grad = jax.jit(jax.grad(mean_fun))(x)
grad = jax.grad(mean_fun)(x)
print(f'Difference grads: {grad - jitted_grad}')
print(f'All close: {jnp.sum(jnp.abs(grad - jitted_grad)) < 1e-3}')
# Difference grads: [ 0.        -1.0000002 -2.0000007]
# All close: False
```
When working correctly, it should print:

```
Difference grads: [0. 0. 0.]
All close: True
```
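For reference, the extra term multiplies `diff` by an all-zero matrix, so it is exactly zero and cannot affect the value or the gradient; analytically, `fun(x, y) = (y - x) · w`, so the gradient of `mean_fun` is `-w` regardless of `ys`. A minimal NumPy sketch (independent of JAX and the XLA CPU backend) checking both facts:

```python
import numpy as np

# Same constants as the repro above.
w = np.array([1., 2., 3.], dtype=np.float32)
dim = len(w)
rng = np.random.default_rng(0)
diff = rng.standard_normal(dim).astype(np.float32)

# The zero-matrix term contributes exactly 0.0, so both expressions
# are bitwise equal -- the extra line is a mathematical no-op.
out_plain = diff @ w
out_with_zero = diff @ w + diff @ np.zeros((dim, dim), np.float32) @ diff
assert out_plain == out_with_zero

# fun(x, y) = (y - x) . w, so the gradient of mean_fun w.r.t. x is -w
# for every x; the un-jitted JAX gradient above agrees with this.
expected_grad = -w
print(expected_grad)  # [-1. -2. -3.]
```

This makes the bug report concrete: the jitted and un-jitted gradients differ only when the zero-matmul term is present, even though that term cannot change the result.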
I bisected the problem to XLA commit 0dc5563.
HLO dumps from that commit (bad) and the immediately preceding commit (good): dumps.tgz
We should either fix or revert that commit.
penpornk commented
Also cc: @TensorFlow-MKL @agramesh1
Reverting isn't straightforward because there have been further related changes since that commit. I tried disabling just the MatMul + Mul rewrite, but the error persists. Could you please take a look?