
CPU vs GPU performance snippet

First of all, thanks for the brilliant library! IMHO, JAX + sparse math is a real game changer, as there are a LOT of practical tasks (e.g. finite element methods) where it's not feasible to work with dense matrices...

I would like to contribute the following examples for your library, maybe by opening a PR:

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.experimental.sparse
import sparsejac

N = 10000
print(f"Jacobian {N} x {N} (DENSE) size: {N * N * 8 / (1<<30):.3f} GiB")
# Diagonal sparsity pattern: indices has shape (N, 2), one (row, col) pair per nonzero
indices = np.array((range(N), range(N))).T
data = np.ones(N)
fn = lambda x: x**4  # elementwise, so the Jacobian is diagonal

for i, device in enumerate(("CPU", "GPU")):
    print(f"*** {i+1}. Processing with {device} ***")
    with jax.default_device(jax.devices(device.lower())[0]):
        x_rnd = jax.random.uniform(jax.random.PRNGKey(0), shape=(N,))
        # simplified a bit: removed ensure_compile_time_eval
        # sparsity is generated explicitly, not from a dense matrix, as that is more common (IMHO)
        sparsity = jax.experimental.sparse.BCOO((data, indices), shape=(N, N))

        sparse_jacrev_fn = jax.jit(sparsejac.jacrev(fn, sparsity=sparsity))
        dense_jacrev_fn = jax.jit(jax.jacrev(fn))

        assert jnp.all(sparse_jacrev_fn(x_rnd).todense() == dense_jacrev_fn(x_rnd))

        for jacrev_fn, desc in zip((dense_jacrev_fn, sparse_jacrev_fn), ("DENSE", "SPARSE")):
            print(f"Jacobian evaluation {desc}:")
            # %timeit is IPython/Jupyter magic: run this cell in a notebook
            %timeit jacrev_fn(x_rnd).block_until_ready()
# Jacobian 10000 x 10000 (DENSE) size: 0.745 GiB
# *** 1. Processing with CPU ***
# Jacobian evaluation DENSE:
# 240 ms ± 9.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Jacobian evaluation SPARSE:
# 50.6 µs ± 1.56 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# *** 2. Processing with GPU ***
# Jacobian evaluation DENSE:
# 10.8 ms ± 22.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Jacobian evaluation SPARSE:
# 105 µs ± 4.01 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
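Why is the SPARSE evaluation so much cheaper on CPU? Here is a minimal sketch of the underlying idea in plain JAX, not sparsejac's actual code (as far as I know it groups structurally independent rows/columns via graph coloring, but its internals may differ). Dense `jax.jacrev` needs one VJP per output row, while for a diagonal Jacobian every row touches a different column, so a single VJP with a vector of ones already recovers the whole diagonal:

```python
import jax
import jax.numpy as jnp

def fn(x):
    return x**4  # elementwise, so the Jacobian is diagonal

N = 5
x = jnp.arange(1.0, N + 1)

# Dense reverse mode: one VJP per output row (N in total, vmapped by jax.jacrev)
J_dense = jax.jacrev(fn)(x)

# Compressed reverse mode: one VJP with a ones vector returns ones @ J,
# i.e. the sum of all rows; for a diagonal J that is exactly the diagonal.
_, vjp_fn = jax.vjp(fn, x)
(diag,) = vjp_fn(jnp.ones(N))

assert jnp.allclose(diag, jnp.diag(J_dense))
print(diag)  # equals 4 * x**3
```

The cost drops from N backward passes to one, which is consistent with the roughly three orders of magnitude gap between DENSE and SPARSE on CPU above.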

It's interesting that even for this super sparse Jacobian, the GPU implementation is about 2x SLOWER than the CPU one (105 µs vs. 50.6 µs). Is it worth opening a JAX bug, by the way?
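Before filing a bug, it may help to time both devices in a plain-Python script with inputs explicitly committed to each device and compilation excluded. A minimal sketch, assuming a hypothetical helper `time_on_device` (not part of JAX or sparsejac). One plausible explanation, not verified here, is that for such a tiny elementwise workload, kernel-launch and dispatch overhead dominates on the GPU:

```python
import time
import jax
import jax.numpy as jnp

def time_on_device(f, x, device, n_iter=100):
    """Hypothetical helper: mean seconds per call of jit(f) on `device`, compile excluded."""
    x_dev = jax.device_put(x, device)     # commit the input to the target device
    f_jit = jax.jit(f)
    f_jit(x_dev).block_until_ready()      # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(n_iter):
        f_jit(x_dev).block_until_ready()  # wait for asynchronous dispatch to finish
    return (time.perf_counter() - start) / n_iter

devices = jax.devices("cpu")[:1]
try:
    devices += jax.devices("gpu")[:1]
except RuntimeError:
    pass  # no GPU backend available

x = jnp.ones(10_000)
for device in devices:
    print(device, f"{time_on_device(lambda v: v**4, x, device) * 1e6:.1f} µs")
```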

And the second part: using sparse computations in an optimization, with a great speedup:

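```python
# Non-linear least squares example (residuals are x**4), N = 1000.
# Assumes fn, x_rnd, dense_jacrev_fn and sparse_jacrev_fn from the snippet above, re-created with N = 1000.
import time
import numpy as np
import jax.numpy as jnp
import scipy.optimize
import scipy.sparse

tol = 1e-10

def scipy_jacrev_fn(x):
    # Convert the JAX BCOO Jacobian into a SciPy COO array that least_squares accepts
    J_sparse = sparse_jacrev_fn(x)
    rows, cols = np.asarray(J_sparse.indices).T
    return scipy.sparse.coo_array((np.asarray(J_sparse.data), (rows, cols)), shape=J_sparse.shape)

print("Jacobian computation with SCIPY-sparse conversion:")
# %timeit is IPython/Jupyter magic
%timeit scipy_jacrev_fn(x_rnd)

for i, desc in enumerate(("DENSE", "SPARSE")):
    start_time = time.time()
    res = scipy.optimize.least_squares(
        fn, x_rnd,  # residual function and initial guess (assumed: fn and x_rnd)
        jac=dense_jacrev_fn if i == 0 else scipy_jacrev_fn,
        ftol=tol, xtol=tol, gtol=tol,
    )
    print(f"{desc} Non-Linear Least Squares {res.njev} iterations done in {time.time() - start_time:3.3f} sec")
    assert jnp.allclose(res.fun, 0, atol=1e-6)  # residuals x**4 should vanish at the solution

# Jacobian computation with SCIPY-sparse conversion:
# 2.41 ms ± 959 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# DENSE Non-Linear Least Squares 14 iterations done in 13.038 sec
# SPARSE Non-Linear Least Squares 14 iterations done in 0.361 sec
```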
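The BCOO-to-SciPy conversion inside `scipy_jacrev_fn` can be checked on its own, without sparsejac. A minimal sketch, assuming a hypothetical helper `bcoo_to_scipy` (not part of JAX or SciPy) and a small hand-written matrix; the out-of-bounds filter is only a precaution in case a BCOO matrix carries padding entries:

```python
import numpy as np
import scipy.sparse
import jax.numpy as jnp
from jax.experimental import sparse as jsparse

def bcoo_to_scipy(J):
    """Hypothetical helper: convert a 2-D JAX BCOO matrix into a SciPy COO array."""
    rows, cols = np.asarray(J.indices).T  # BCOO stores indices with shape (nse, 2)
    data = np.asarray(J.data)
    # Drop any out-of-bounds (padding) entries, if present
    keep = (rows < J.shape[0]) & (cols < J.shape[1])
    return scipy.sparse.coo_array((data[keep], (rows[keep], cols[keep])), shape=J.shape)

dense = jnp.array([[1.0, 0.0, 2.0],
                   [0.0, 3.0, 0.0]])
J = jsparse.BCOO.fromdense(dense)
J_scipy = bcoo_to_scipy(J)
assert np.allclose(J_scipy.toarray(), np.asarray(dense))
```

Passing `shape=` explicitly matters: without it SciPy infers the shape from the largest stored index, which would be wrong if the last rows or columns of the Jacobian happen to be empty.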

I wonder if jaxopt will support sparse matrices...


After #11 is solved, IMHO it would be reasonable to switch from jacrev to jacfwd, as forward mode is generally faster when the Jacobian has at least as many rows as columns (rows >= cols).
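The rows-vs-columns rule of thumb comes from how many passes each mode needs. A minimal sketch for dense Jacobians, building them by hand with `jax.jvp` and `jax.vjp` (the function `f` is a made-up tall example; `jax.jacfwd`/`jax.jacrev` vmap these passes instead of looping, and for sparse Jacobians the relevant count is presumably the number of compressed groups rather than the full dimension):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Tall Jacobian: 3 inputs -> 6 outputs
    return jnp.concatenate([jnp.sin(x), x**2])

x = jnp.array([0.1, 0.2, 0.3])
n_out, n_in = f(x).shape[0], x.shape[0]

# Forward mode: one JVP per input direction yields one Jacobian COLUMN (n_in passes)
cols = [jax.jvp(f, (x,), (e,))[1] for e in jnp.eye(n_in)]
J_fwd = jnp.stack(cols, axis=1)

# Reverse mode: one VJP per output direction yields one Jacobian ROW (n_out passes)
_, vjp_fn = jax.vjp(f, x)
rows = [vjp_fn(e)[0] for e in jnp.eye(n_out)]
J_rev = jnp.stack(rows, axis=0)

assert jnp.allclose(J_fwd, J_rev) and jnp.allclose(J_fwd, jax.jacfwd(f)(x))
print(f"forward passes: {n_in}, reverse passes: {n_out}")  # forward passes: 3, reverse passes: 6
```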

Thanks for your comment. I would be open to a new example if it demonstrates something substantially new. With respect to performance and jacfwd vs. jacrev, a comment in the readme might suffice?

Thanks for the fast reaction.

> With respect to performance and jacfwd vs. jacrev, a comment in the readme might suffice?

Yeah, that might help engineers who are not aware of the "Jacobian construction mechanics" (which library users shouldn't need to know, IMHO) to write much more efficient code.
A readme update works! Maybe a warning message, or even an assert, would be better still...
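A warning is gentler than an assert, since the choice is only a heuristic. A minimal sketch of what such a check could look like, assuming a hypothetical function `warn_if_suboptimal_mode` (sparsejac has no such API); it uses the dense pass counts, so for sparse Jacobians it can be misleading and benchmarking both modes remains the safer advice:

```python
import warnings

def warn_if_suboptimal_mode(n_out, n_in, mode):
    """Hypothetical check for a Jacobian of shape (n_out, n_in).

    Dense forward mode needs n_in JVPs, dense reverse mode needs n_out VJPs.
    """
    if mode == "rev" and n_out > n_in:
        warnings.warn(f"Tall Jacobian ({n_out} x {n_in}): jacfwd may be faster than jacrev.")
    elif mode == "fwd" and n_in > n_out:
        warnings.warn(f"Wide Jacobian ({n_out} x {n_in}): jacrev may be faster than jacfwd.")

warn_if_suboptimal_mode(1000, 10, "rev")  # emits a UserWarning
warn_if_suboptimal_mode(10, 10, "rev")    # square: no warning
```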

> if it demonstrates something substantially new

The key (yet small) difference I'm suggesting is constructing the sparsity pattern directly in sparse form:

```python
sparsity = jax.experimental.sparse.BCOO((data, indices), shape=(N, N))
```

which makes it possible to work with really large N, where dense math fails.
I can adjust the snippet to gradually increase N up to the point where CPU or GPU memory allocation fails.
However, feel free to ignore the proposal.
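To make the memory argument concrete, here is a minimal sketch that builds the diagonal pattern for growing N without ever materializing a dense matrix, and compares its storage with the float64 dense Jacobian estimate used in the first snippet. `diagonal_sparsity` is a hypothetical helper; actual memory use during Jacobian evaluation will be higher than the pattern alone:

```python
import numpy as np
from jax.experimental import sparse as jsparse

def diagonal_sparsity(n):
    """Hypothetical helper: n x n diagonal sparsity pattern built directly in BCOO form."""
    idx = np.arange(n, dtype=np.int32)
    indices = np.stack([idx, idx], axis=1)  # shape (n, 2): one (row, col) pair per nonzero
    data = np.ones(n, dtype=np.float32)
    return jsparse.BCOO((data, indices), shape=(n, n))

for n in (10_000, 100_000, 1_000_000):
    s = diagonal_sparsity(n)
    dense_gib = n * n * 8 / (1 << 30)  # float64 dense Jacobian, as in the first snippet
    sparse_mib = (s.data.nbytes + s.indices.nbytes) / (1 << 20)
    print(f"N={n}: dense {dense_gib:,.1f} GiB vs. sparse pattern {sparse_mib:.2f} MiB")
```

The dense size grows as N², while the pattern grows linearly (one value plus one index pair per nonzero).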

I added some comments to the readme. Based on some quick benchmarking, jacrev and jacfwd perform similarly on CPU. In general I suggest users test both functions to see what works best on their hardware and their problem. This is now stated in the readme.
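Following the advice to test both functions, a minimal benchmarking sketch for plain `jax.jacfwd` vs. `jax.jacrev` (the helper `compare_modes` and both test functions are hypothetical; the same pattern should apply to sparsejac's `jacfwd`/`jacrev` with a `sparsity` argument). Timings depend heavily on hardware and problem size, so no numbers are claimed here:

```python
import timeit
import jax
import jax.numpy as jnp

def compare_modes(f, x, number=20):
    """Hypothetical helper: mean seconds per call of jit(jacfwd(f)) and jit(jacrev(f))."""
    results = {}
    for name, jac in (("jacfwd", jax.jacfwd), ("jacrev", jax.jacrev)):
        jf = jax.jit(jac(f))
        jf(x).block_until_ready()  # compile once before timing
        t = timeit.timeit(lambda: jf(x).block_until_ready(), number=number)
        results[name] = t / number
    return results

def square(x):  # R^n -> R^n
    return jnp.tanh(x) * jnp.sum(x)

def tall(x):  # R^n -> R^(8n)
    return jnp.outer(x, x[:8]).ravel()

x = jnp.linspace(0.0, 1.0, 256)
print("square:", compare_modes(square, x))
print("tall:  ", compare_modes(tall, x))
```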

Here is the colab benchmark I put together:

Yep, I've checked the colab, and it looks like that! For square matrices there seems to be no difference.

P.S. Unless we have tall Jacobians, where forward mode could be much faster...