openxla/xla

Wrong output from JAX test after onednn change

One of our JAX users reported a miscompilation in the following program when run on CPU:

import jax
import jax.numpy as jnp

w = jnp.array([1., 2., 3.], dtype=jnp.float32)
# w = jnp.array([10., 10., 10.], dtype=jnp.float32)  # this works
# w = jnp.array([1e-3, 1e-2, 1e-1], dtype=jnp.float32)  # this does not work
dim = len(w)
x = jnp.ones((dim,))
ys = jnp.ones((10**6, dim))
# ys = jnp.ones((10**5, dim)) # this works

def fun(x, y):
  diff = y - x
  out = jnp.dot(diff, w)
  # The following line should have no effect on the output or the derivatives
  # of this function. Removing it solves the bug.
  out = out + jnp.dot(jnp.dot(diff, jnp.zeros((dim, dim))), diff)
  return out

def mean_fun(x):
  outs = jax.vmap(lambda yy: fun(x, yy))(ys)
  return jnp.mean(outs)

jitted_grad = jax.jit(jax.grad(mean_fun))(x)
grad = jax.grad(mean_fun)(x)
print(f'Difference grads: {grad-jitted_grad}')
print(f'All close: {jnp.sum(jnp.abs(grad-jitted_grad)) < 1e-3}')
# Difference grads: [ 0.        -1.0000002 -2.0000007]
# All close: False

If working correctly, the program should print:

Difference grads: [0. 0. 0.]
All close: True
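
For context: since the second dot product in fun multiplies by jnp.zeros((dim, dim)), that term is identically zero, so fun(x, y) reduces to jnp.dot(y - x, w) and the gradient of mean_fun with respect to x is -w = [-1., -2., -3.] regardless of ys. A minimal sketch of that analytic check, reusing the names from the repro above (run it right after the snippet):

# The extra term is (y - x) @ 0 @ (y - x) == 0, so fun(x, y) == dot(y - x, w)
# and d(mean_fun)/dx == -w for any x and ys.
expected_grad = -w
print('Non-jitted matches analytic gradient:',
      bool(jnp.allclose(grad, expected_grad, atol=1e-5)))
print('Jitted matches analytic gradient:',
      bool(jnp.allclose(jitted_grad, expected_grad, atol=1e-5)))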

I bisected the problem to XLA commit 0dc5563 (@jmaksymc).

HLO dumps from that commit (bad) and the immediately preceding commit (good): dumps.tgz
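
For reference, a minimal sketch of one way to produce per-pass HLO dumps like these from JAX; the dump directory and the toy function below are arbitrary, and XLA_FLAGS must be set before JAX initializes its backend:

import os
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_pass_re=.*'

import jax
import jax.numpy as jnp

def f(x):
  return jnp.dot(x, x)

# Compiling writes before/after-optimization HLO files into /tmp/xla_dump.
jax.jit(f)(jnp.ones((3,)))

# The optimized HLO for a specific input shape can also be printed directly:
print(jax.jit(f).lower(jnp.ones((3,))).compile().as_text())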

We should either fix or revert that commit.

Also cc: @TensorFlow-MKL @agramesh1

Reverting isn't straightforward because there have been further related changes since that commit. I tried disabling just the MatMul + Mul rewrite, but the error persists. Could you please help take a look?

@hawkinsp @penpornk Thanks for reporting this issue. We are working on it.

@hawkinsp @penpornk A PR fixing the bug has been submitted (#13301).

PR #13301 has been merged. I have verified that the latest XLA commit (7d12719) produces the expected output. (I used a nightly JAX build, e.g. 0.4.29.dev20240603, because of the recent ducc FFT removal from jaxlib.)
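
For anyone re-verifying, a short sketch of one way to confirm which jax/jaxlib build is in use (jax.print_environment_info() exists in recent JAX releases):

import jax
import jaxlib

print(jax.__version__)        # e.g. a nightly such as 0.4.29.dev20240603
print(jaxlib.__version__)
jax.print_environment_info()  # prints jax/jaxlib versions, platform, and device info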