mfschubert/sparsejac

CPU vs GPU performance snippet

agudym opened this issue · 5 comments

Hi there!

First of all, thanks for the brilliant library! IMHO, JAX + sparse math is a real game changer, as there are a LOT of practical tasks (e.g. finite element methods) where it's not feasible to work with dense matrices...

I would like to contribute the following examples of using your library, maybe as a PR:

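import jax
import jax.experimental.sparse  # makes jax.experimental.sparse.BCOO available
import jax.numpy as jnp
import numpy as np
import sparsejac

N = 10000
# Assumes 8 bytes per entry (float64); with JAX's default float32 the dense Jacobian is half this size.
print(f"Jacobian {N} x {N} (DENSE) size: {N * N * 8 / (1<<30):.3f} GiB")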
indices = np.array((range(N), range(N))).T
data = np.ones(N)
fn = lambda x: x**4

for i, device in enumerate(("CPU", "GPU")):
    print(f"*** {i+1}. Processing with {device} ***")
    with jax.default_device(jax.devices(device.lower())[0]):
        x_rnd = jax.random.uniform(jax.random.PRNGKey(0), shape=(N,))
        
        # simplified a bit: removed ensure_compile_time_eval
        # sparsity is generated explicitly, not from dense mat as it's more common (imho)
        sparsity = jax.experimental.sparse.BCOO((data, indices), shape=(N,N)) 

        sparse_jacrev_fn = jax.jit(sparsejac.jacrev(fn, sparsity=sparsity))
        dense_jacrev_fn = jax.jit(jax.jacrev(fn))

        assert jnp.all(sparse_jacrev_fn(x_rnd).todense() == dense_jacrev_fn(x_rnd))

        for jacrev_fn, desc in zip((dense_jacrev_fn, sparse_jacrev_fn), ("DENSE", "SPARSE")):
            print(f"Jacobian evaluation {desc}:")
            %timeit jacrev_fn(x_rnd).block_until_ready()
# Jacobian 10000 x 10000 (DENSE) size: 0.745 GiB
# *** 1. Processing with CPU ***
# Jacobian evaluation DENSE:
# 240 ms ± 9.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Jacobian evaluation SPARSE:
# 50.6 µs ± 1.56 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# *** 2. Processing with GPU ***
# Jacobian evaluation DENSE:
# 10.8 ms ± 22.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Jacobian evaluation SPARSE:
# 105 µs ± 4.01 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

It's interesting that even for this super sparse Jacobian, the GPU implementation is 2x SLOWER than the CPU one (105 µs vs. 50.6 µs). Is it worth opening a JAX bug, btw?
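One thing worth ruling out before filing a bug is the measurement itself. Below is a minimal sketch of a fair per-device timing using only standard JAX APIs (`time_jitted` and `compare_devices` are hypothetical helper names). It excludes the first (compiling) call and blocks on the result, since JAX dispatches work asynchronously. For a Jacobian this sparse the actual arithmetic is tiny, so a per-call time of roughly 50 to 100 µs is plausibly dominated by dispatch and kernel-launch overhead, which tends to be higher on a GPU.

```python
import time

import jax


def time_jitted(fn, *args, repeats=10):
    """Mean wall-clock seconds per call of `fn(*args)`, excluding compilation.

    `fn` is assumed to be a jitted function returning a single JAX array.
    """
    # The first call traces and compiles; keep it out of the measurement.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously, so block to time the real computation.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


def compare_devices(fn, n=10_000):
    """Time `fn` on the first CPU device and, if available, the first GPU device."""
    results = {}
    for backend in ("cpu", "gpu"):
        try:
            device = jax.devices(backend)[0]
        except RuntimeError:
            continue  # this backend is not available on the current machine
        # Committing the input to a device makes the jitted call run there.
        x = jax.device_put(jax.random.uniform(jax.random.PRNGKey(0), (n,)), device)
        results[backend] = time_jitted(jax.jit(fn), x)
    return results
```

For example, `compare_devices(lambda x: x**4)` returns a dict such as `{"cpu": ..., "gpu": ...}` with seconds per call; the same helper can wrap the dense and sparse Jacobian functions from the snippet above.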

And the second part: using sparse computations in an optimization, with a great speedup:

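# Non-linear least squares example (x**4 is used), N = 1000;
# reuses fn, x_rnd, sparse_jacrev_fn and dense_jacrev_fn from the snippet above.
import time

import numpy as np
import scipy.optimize
import scipy.sparse

tol = 1e-10
def scipy_jacrev_fn(x):
    J_sparse = sparse_jacrev_fn(x)
    # Convert JAX arrays to NumPy and pass the shape explicitly for SciPy.
    rows, cols = np.asarray(J_sparse.indices).T
    return scipy.sparse.coo_array((np.asarray(J_sparse.data), (rows, cols)), shape=J_sparse.shape)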

print("Jacobian computation with SCIPY-sparse conversion:")
%timeit scipy_jacrev_fn(x_rnd)

for i, desc in enumerate(("DENSE", "SPARSE")):
    start_time = time.time()
    res = scipy.optimize.least_squares(
        jax.jit(fn),
        x0=x_rnd,
        jac= dense_jacrev_fn if i == 0 else scipy_jacrev_fn,
        ftol=tol, xtol=tol, gtol=tol
    )
    print(f"{desc} Non-Linear Least Squares {res.njev} iterations done in {time.time() - start_time:3.3f} sec")
    assert jnp.allclose(res.fun, 0, atol=1e-6)

#Jacobian computation with SCIPY-sparse conversion:
#2.41 ms ± 959 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
#DENSE Non-Linear Least Squares 14 iterations done in 13.038 sec
#SPARSE Non-Linear Least Squares 14 iterations done in 0.361 sec
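The SciPy conversion step can be factored into a small standalone helper. This is a sketch assuming a 2-D BCOO with no batch or dense dimensions (`bcoo_to_scipy` is a hypothetical name). It passes `shape` explicitly because SciPy otherwise infers the shape from the largest index present and silently drops trailing all-zero rows or columns.

```python
import numpy as np
import scipy.sparse
from jax.experimental import sparse


def bcoo_to_scipy(mat):
    """Convert a 2-D jax.experimental.sparse.BCOO matrix to a SciPy COO array."""
    data = np.asarray(mat.data)
    rows, cols = np.asarray(mat.indices).T  # BCOO indices have shape (nse, 2)
    # BCOO may pad with out-of-range indices; drop them defensively.
    keep = (rows < mat.shape[0]) & (cols < mat.shape[1])
    # Explicit shape keeps trailing all-zero rows/columns.
    return scipy.sparse.coo_array(
        (data[keep], (rows[keep], cols[keep])), shape=mat.shape
    )


# The last row is all zeros: without `shape=`, SciPy would infer (2, 3).
m = sparse.BCOO.fromdense(np.array([[1.0, 0.0, 0.0],
                                    [0.0, 0.0, 2.0],
                                    [0.0, 0.0, 0.0]]))
print(bcoo_to_scipy(m).shape)  # (3, 3)
```

A callable returning this `coo_array` can be passed as `jac=` to `scipy.optimize.least_squares`. With a sparse Jacobian and the default `trf` method, SciPy picks its iterative `lsmr` trust-region solver.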

I wonder if jaxopt will support sparse matrices...

Thanks!

After #11 is resolved, IMHO it would be reasonable to switch from jacrev to jacfwd, as it's generally faster when rows >= cols (i.e. for tall Jacobians).
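For the dense case, the rule of thumb works like this. `jax.jacfwd` builds the Jacobian one column at a time, using one JVP per input (batched with `vmap`). `jax.jacrev` builds it one row at a time, using one VJP per output. The sketch below uses a hypothetical tall function from R^3 to R^6, where forward mode needs 3 JVPs and reverse mode needs 6 VJPs. Both return the same matrix.

```python
import jax
import jax.numpy as jnp


def tall_fn(x):
    """Hypothetical example mapping R^3 -> R^6, so its Jacobian is 6 x 3 (tall)."""
    return jnp.concatenate([jnp.sin(x), x**2])


x = jnp.arange(1.0, 4.0)  # [1., 2., 3.]

# Forward mode: one JVP per input column (3 here), batched via vmap.
J_fwd = jax.jacfwd(tall_fn)(x)
# Reverse mode: one VJP per output row (6 here), batched via vmap.
J_rev = jax.jacrev(tall_fn)(x)

print(J_fwd.shape)  # (6, 3)
```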

Thanks for your comment. I would be open to a new example if it demonstrates something substantially new. With respect to performance and jacfwd vs. jacrev, a comment in the readme might suffice?

Thanks for the fast reaction.

With respect to performance and jacfwd vs. jacrev, a comment in the readme might suffice?

Yeah, that might help engineers who are not aware of the "Jacobian construction mechanics" (which library users shouldn't need to know, IMHO) to get much more efficient code.
A README update works! Maybe a warning message or even an assert would do even better...

if it demonstrates something substantially new

The key (yet small) difference I'm suggesting is constructing the sparsity pattern directly in sparse form:

sparsity = jax.experimental.sparse.BCOO((data, indices), shape=(N,N))

which makes it possible to work with really big N, where dense math fails.
I can adjust the snippet to gradually increase N until memory allocation fails on the CPU/GPU.
However feel free to ignore the proposal.
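The same direct construction extends to other patterns. Below is a sketch for a tridiagonal pattern, such as one produced by a 1-D finite-difference stencil (`tridiagonal_sparsity` is a hypothetical helper). Only the 3n - 2 structural nonzeros are ever materialized; no n x n dense array is built.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse


def tridiagonal_sparsity(n):
    """Build an n x n tridiagonal sparsity pattern directly as a BCOO matrix."""
    i = np.arange(n, dtype=np.int32)
    rows = np.concatenate([i, i[:-1], i[1:]])  # main, upper, lower diagonals
    cols = np.concatenate([i, i[1:], i[:-1]])
    indices = np.stack([rows, cols], axis=1)  # shape (nse, 2), as BCOO expects
    data = np.ones(len(rows), dtype=np.float32)
    return sparse.BCOO((jnp.asarray(data), jnp.asarray(indices)), shape=(n, n))


pattern = tridiagonal_sparsity(5)
print(pattern.nse)  # 13 == 3 * 5 - 2
```

The result could be passed as `sparsity=` to `sparsejac.jacrev`, like the diagonal pattern in the first snippet, for a function whose output i depends on x[i-1], x[i] and x[i+1].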

I added some comments to the readme. Based on some quick benchmarking, jacrev and jacfwd perform similarly on CPU. In general I suggest users test both functions to see what works best on their hardware and their problem. This is now stated in the readme.

Here is the colab benchmark I put together:
https://colab.research.google.com/drive/1h1AV2Y_L6zWDprpIKNsi6KEu0QgQimSM#scrollTo=Xty93PY6Qtdy

Yep, I've checked the colab and it looks like that! For square matrices there seems to be no difference.

P.S. Unless we have tall Jacobians, where jacfwd could be much faster...
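For sparse Jacobians, the rows-vs-columns rule is only a proxy. With coloring-based compression, the cost is the number of VJPs or JVPs left after grouping structurally independent rows (reverse mode) or columns (forward mode). The NumPy-only sketch below illustrates the general technique with greedy coloring; it is not sparsejac's actual implementation, and `greedy_coloring` and `passes_needed` are hypothetical names. For a tall pattern made of three stacked identity blocks, reverse mode needs 3 VJPs while forward mode needs a single JVP.

```python
import numpy as np


def greedy_coloring(pattern):
    """Greedily assign a color to each row of a boolean pattern so that rows
    sharing a color have no nonzero column in common."""
    colors = []
    covered = []  # covered[c] = set of columns already used by color c
    for row in pattern:
        cols = set(np.flatnonzero(row))
        for c, used in enumerate(covered):
            if used.isdisjoint(cols):
                used |= cols  # in-place update of that color's column set
                colors.append(c)
                break
        else:
            covered.append(set(cols))
            colors.append(len(covered) - 1)
    return colors


def passes_needed(pattern):
    """(number of VJPs via row coloring, number of JVPs via column coloring)."""
    n_vjp = len(set(greedy_coloring(pattern)))
    n_jvp = len(set(greedy_coloring(pattern.T)))
    return n_vjp, n_jvp


# Tall 12 x 4 pattern: three stacked identity blocks.
stacked = np.vstack([np.eye(4, dtype=bool)] * 3)
print(passes_needed(stacked))  # (3, 1)
```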