jax-ml/jax

array.at.set is incredibly slow for complex128 dtype


Description

For some reason, `array.at[...].set` is extremely slow for complex128 arrays. Below I show that it is much faster to split the arrays into real and imaginary parts, call `.at[...].set` on each part, and then recombine them into a complex array.

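```python
from time import time

import jax
from jax import numpy as jnp

jax.config.update("jax_enable_x64", True)  # complex128 requires 64-bit mode

@jax.jit
def set(x, x2, inds):
  # Direct complex128 scatter-update.
  return x.at[inds].set(x2)

@jax.jit
def complex_set(x, x2, inds):
  # The same update done as two float64 scatters, then recombined.
  return jax.lax.complex(x.real.at[inds].set(x2.real), x.imag.at[inds].set(x2.imag))

x = jnp.zeros([10000000], dtype=jnp.complex128)
x2 = jnp.zeros([10000], dtype=jnp.complex128)
inds = jnp.arange(10000)

# Warm-up calls: compile both functions and wait for them to finish,
# so neither compilation nor leftover async work is counted below.
jax.block_until_ready(set(x, x2, inds))
jax.block_until_ready(complex_set(x, x2, inds))

t = time()
jax.block_until_ready(set(x, x2, inds))
print('set time=', time() - t)

t = time()
jax.block_until_ready(complex_set(x, x2, inds))
print('complex set time=', time() - t)
```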

```text
set time= 0.07047343254089355
complex set time= 0.0006287097930908203
```
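The timings above come from a single call each. A sketch of a slightly more robust measurement, assuming it is acceptable to report the median of several synchronized runs (the helper name `median_runtime` is hypothetical, not a JAX API, and the array sizes are shrunk for illustration):

```python
import statistics
import time

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # complex128 requires 64-bit mode


def median_runtime(fn, *args, repeats=10):
    """Median wall-clock seconds of fn(*args) over `repeats` runs (hypothetical helper).

    The first call happens outside the timed loop so JIT compilation is not
    counted, and every call is synchronized with jax.block_until_ready because
    JAX dispatches work asynchronously.
    """
    jax.block_until_ready(fn(*args))  # warm-up / compile
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


# Small illustrative sizes; the issue used 10_000_000 and 10_000 elements.
x = jnp.zeros(100_000, dtype=jnp.complex128)
x2 = jnp.ones(1_000, dtype=jnp.complex128)
inds = jnp.arange(1_000)

direct = jax.jit(lambda x, v, i: x.at[i].set(v))
split = jax.jit(
    lambda x, v, i: jax.lax.complex(x.real.at[i].set(v.real), x.imag.at[i].set(v.imag))
)

print("set time=", median_runtime(direct, x, x2, inds))
print("complex set time=", median_runtime(split, x, x2, inds))
```

Taking the median rather than the mean keeps one slow outlier (for example, a stray allocation) from dominating the reported number.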

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.35
jaxlib: 0.4.34
numpy: 2.0.2
python: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]
device info: NVIDIA H100 PCIe-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='workergpu158', release='6.1.97.1.fi', version='#1 SMP Tue Jul 9 06:21:23 EDT 2024', machine='x86_64')

As it happens, we have a workaround in JAX that avoids this slow behavior for scatter-add and scatter-sub, but not for scatter-update. It should be fairly easy to extend it to scatter-update as well.

(The issue is that 128-bit scatters are currently expensive in XLA, because NVIDIA GPUs don't have a 16-byte atomic write operation.)
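To make the size argument concrete: a complex128 element occupies 16 bytes, while each of its real and imaginary components is an 8-byte float64, so splitting turns one 16-byte scatter into two 8-byte scatters. A minimal check of those sizes, plus a sanity check that the split form reproduces the direct scatter-update when no index repeats (the helper name `split_set` is illustrative, not part of JAX):

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # needed for complex128 arrays

# Bytes per element: one complex128 value is a pair of float64 values.
print(np.dtype(np.complex128).itemsize)  # 16
print(np.dtype(np.float64).itemsize)     # 8


def split_set(x, idx, vals):
    # Two float64 scatter-updates instead of one complex128 scatter-update.
    re = x.real.at[idx].set(vals.real)
    im = x.imag.at[idx].set(vals.imag)
    return jax.lax.complex(re, im)


x = jnp.zeros(6, dtype=jnp.complex128)
idx = jnp.array([1, 3, 4])  # no repeated indices
vals = jnp.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=jnp.complex128)

direct = x.at[idx].set(vals)
split = split_set(x, idx, vals)
print(bool(jnp.array_equal(direct, split)))  # True
```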

Actually, thinking about this a bit more, splitting a scatter-update into real and imaginary parts is somewhat problematic.

If there are multiple updates to the same index, it's unspecified which update "wins". If we updated the real and imaginary parts separately, you might get the real part of one update and the imaginary part of another. It would only be safe for us to do that if you promised us the indices are non-overlapping. Is that true in your case?
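The hazard can be illustrated without a GPU by modeling a scatter-update as "apply the updates in some order; the last write wins". If the real-part scatter and the imaginary-part scatter happen to apply two colliding updates in different orders, the recombined value matches neither update. This is a deterministic pure-Python simulation with hand-picked orderings, not a reproduction of actual XLA scheduling:

```python
def scatter_set(values, indices, updates, order):
    """Last-write-wins scatter-update that applies updates in the given order."""
    out = list(values)
    for k in order:
        out[indices[k]] = updates[k]
    return out


x = [0j, 0j]
indices = [0, 0]  # both updates target index 0
updates = [1 + 1j, 2 + 2j]

# Whole complex scatter: whichever update is applied last wins, intact.
whole = scatter_set(x, indices, updates, order=[0, 1])

# Split scatter: suppose the real pass applies update 1 last,
# while the imaginary pass applies update 0 last.
re = scatter_set([v.real for v in x], indices, [u.real for u in updates], order=[0, 1])
im = scatter_set([v.imag for v in x], indices, [u.imag for u in updates], order=[1, 0])
split = [complex(r, i) for r, i in zip(re, im)]

print(whole[0])  # (2+2j): a value that was actually written
print(split[0])  # (2+1j): real part of one update, imaginary part of the other
```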

It's easier for add and sub because complex addition acts on the real and imaginary parts independently, and addition is commutative and associative: we can apply the updates in any order and still get the same result, up to floating-point error.
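For scatter-add the split is safe because (a + bi) + (c + di) = (a + c) + (b + d)i, so the real and imaginary sums never interact. A small check with a repeated index, using integer-valued inputs so the comparison is exact rather than merely within floating-point error:

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(3, dtype=jnp.complex64)
idx = jnp.array([0, 0, 2])  # index 0 receives two updates
vals = jnp.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=jnp.complex64)

# Direct complex scatter-add: duplicate updates at index 0 are summed.
direct = x.at[idx].add(vals)

# Split form: two real-valued scatter-adds, then recombine.
split = jax.lax.complex(
    x.real.at[idx].add(vals.real),
    x.imag.at[idx].add(vals.imag),
)

print(bool(jnp.array_equal(direct, split)))  # True
```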

I see the issue. Yes, in our case the indices are non-overlapping, so the two functions give identical results.

Maybe the solution is to emit a warning that scatter-update is slow for the complex128 dtype and suggest updating the real and imaginary parts separately?

Can you try passing `unique_indices=True` as an argument to `.set()`?

https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

That may well fix the problem.
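One way the split could be gated on the caller's promise, sketched at user level: take the real/imaginary path only when `unique_indices=True` is passed, and otherwise fall back to the ordinary complex scatter-update. The helper `set_maybe_split` is hypothetical; it is not how JAX implements `.at[].set`, and the thread does not say how an eventual fix would work.

```python
import jax
import jax.numpy as jnp


def set_maybe_split(x, idx, vals, unique_indices=False):
    """Scatter-update `vals` into `x` at `idx` (hypothetical helper).

    For complex inputs with unique_indices=True, perform the update as two
    real-valued scatters. This is only correct if no index repeats; otherwise
    the real and imaginary parts may come from different updates.
    """
    if unique_indices and jnp.issubdtype(x.dtype, jnp.complexfloating):
        re = x.real.at[idx].set(vals.real, unique_indices=True)
        im = x.imag.at[idx].set(vals.imag, unique_indices=True)
        return jax.lax.complex(re, im)
    return x.at[idx].set(vals, unique_indices=unique_indices)


# unique_indices selects a Python branch, so it must be static under jit.
set_jit = jax.jit(set_maybe_split, static_argnames="unique_indices")

x = jnp.zeros(8, dtype=jnp.complex64)
idx = jnp.array([2, 5, 7])
vals = jnp.array([1 + 1j, 2 - 1j, -3 + 0.5j], dtype=jnp.complex64)

fast = set_jit(x, idx, vals, unique_indices=True)
slow = set_jit(x, idx, vals, unique_indices=False)
print(bool(jnp.array_equal(fast, slow)))  # True
```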

It doesn't fix the issue.

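```python
from time import time

import jax
from jax import numpy as jnp

jax.config.update("jax_enable_x64", True)  # complex128 requires 64-bit mode

@jax.jit
def set(x, x2, inds):
  # Direct complex128 scatter-update, now promising unique indices.
  return x.at[inds].set(x2, unique_indices=True)

@jax.jit
def complex_set(x, x2, inds):
  # The same update done as two float64 scatters, then recombined.
  return jax.lax.complex(x.real.at[inds].set(x2.real), x.imag.at[inds].set(x2.imag))

x = jnp.zeros([10000000], dtype=jnp.complex128)
x2 = jnp.zeros([10000], dtype=jnp.complex128)
inds = jnp.arange(10000)

# Warm-up calls: compile both functions and wait for them to finish,
# so neither compilation nor leftover async work is counted below.
jax.block_until_ready(set(x, x2, inds))
jax.block_until_ready(complex_set(x, x2, inds))

t = time()
jax.block_until_ready(set(x, x2, inds))
print('set time=', time() - t)

t = time()
jax.block_until_ready(complex_set(x, x2, inds))
print('complex set time=', time() - t)
```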

```text
set time= 0.07504415512084961
complex set time= 0.0005309581756591797
```

I agree that this might be a natural way to implement the fix.