jax-ml/jax

Allow einsum to support naive contraction strategy


I would like to compute an einsum according to the following formula:

import jax

n = 8192
arrays = [jax.random.normal(key=jax.random.PRNGKey(0), shape=(n, n)) for _ in range(6)]
formula = 'ij,ik,il,jk,jl,kl->ij'

I want to express the computation as 4 nested for loops over indices i, j, k, l, without creating any intermediate arrays; written out, the loops look like the sketch below. As far as einsum_path is concerned, I can request this by passing the path directly as [(0, 1, 2, 3, 4, 5)] via the optimize kwarg.
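Here is that reference implementation (semantics only; obviously far too slow to actually run at n = 8192):

import numpy as np

def naive_einsum(A, B, C, D, E, F):
  # out[i, j] = sum over k, l of A[i,j] * B[i,k] * C[i,l] * D[j,k] * E[j,l] * F[k,l]
  n = A.shape[0]
  out = np.zeros((n, n))
  for i in range(n):
    for j in range(n):
      total = 0.0
      for k in range(n):
        for l in range(n):
          total += A[i, j] * B[i, k] * C[i, l] * D[j, k] * E[j, l] * F[k, l]
      out[i, j] = total
  return out

And einsum_path is happy to describe this single-step contraction: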

>>> jax.numpy.einsum_path(formula, *arrays, optimize=[(0,1,2,3,4,5)])
Complete contraction:  ij,ik,il,jk,jl,kl->ij
          Naive scaling:  4
      Optimized scaling:  4
       Naive FLOP count:  2.702e+16
   Optimized FLOP count:  2.702e+16
    Theoretical speedup:  1.000e+0
   Largest intermediate:  6.711e+7 elements
 --------------------------------------------------------------------------------
 scaling        BLAS                current                             remaining
 --------------------------------------------------------------------------------
    4              0  kl,jl,jk,il,ik,ij->ij                                ij->ij

However, when I try to actually perform the einsum, I hit a NotImplementedError at the line below, which carries a comment that says "# if this is actually reachable, open an issue!":

https://github.com/jax-ml/jax/blob/main/jax/_src/numpy/lax_numpy.py#L9775

>>> ans = jnp.einsum(formula, *arrays, optimize=[(0,1,2,3,4,5)])
>>> ans.block_until_ready()

I think your path specification is invalid. For example, if you pass it to NumPy, you get this error:

np.einsum(formula, *arrays, optimize=[(0,1,2,3,4,5)])
Traceback (most recent call last):
  File "/Users/vanderplas/github/google/jax/tmp.py", line 9, in <module>
    np.einsum(formula, *arrays, optimize=[(0,1,2,3,4,5)])
  File "/Users/vanderplas/.local/share/virtualenvs/jax-LBbfM5ix/lib/python3.12/site-packages/numpy/_core/einsumfunc.py", line 1441, in einsum
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/vanderplas/.local/share/virtualenvs/jax-LBbfM5ix/lib/python3.12/site-packages/numpy/_core/einsumfunc.py", line 878, in einsum_path
    raise TypeError("Did not understand the path: %s" % str(path_type))
TypeError: Did not understand the path: [(0, 1, 2, 3, 4, 5)]
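(For reference, I believe the explicit-path spelling NumPy expects is prefixed with the string 'einsum_path':

np.einsum(formula, *arrays, optimize=['einsum_path', (0, 1, 2, 3, 4, 5)])

though I haven't checked how NumPy handles a single six-operand contraction step.)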

Thank you for taking a look! My understanding is that this path corresponds to NumPy's default (unoptimized) behavior, i.e. the basic implementation that you have in

https://github.com/jax-ml/jax/blob/main/tests/lax_numpy_einsum_test.py#L295

It is much more memory-efficient than doing the einsum as a sequence of dot_general calls, which, from my investigation, is hard-coded into the JAX implementation. That choice makes sense because dot_general is very highly optimized, but being able to opt into the more memory-efficient behavior seems desirable in some settings.
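Back-of-the-envelope, assuming float32: every index in this formula appears in three operands (if I've counted right), so no single pairwise contraction can sum an index away, and any pairwise path must materialize at least one rank-3 intermediate:

n = 8192
print(n**2 * 4 / 1e6)   # largest naive intermediate (n x n): ~268 MB
print(n**3 * 4 / 1e12)  # a rank-3 pairwise intermediate (n x n x n): ~2.2 TB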

I prototyped a version of this using a sequence of nested jax.lax.scan calls, but it was ugly and I don't think it was the most performant option. I also played around with using jax.vmap over the indices (i, j) and calling jnp.einsum with the per-element formula:

Complete contraction: ij,ik,il,jk,jl,kl->ij
[vmap] Per-row contraction: j,k,l,jk,jl,kl->j
[double vmap] Per-element contraction: ,k,l,k,l,kl->

It was pretty cool to use JAX's abstractions to achieve this, and the vmap implementation did have better memory characteristics than jnp.einsum in this case, but I still think it uses more memory than the fully naive approach.

If jax.lax.map supported the in_axes argument, I think that would help, since I could just replace my usage of vmap with map.
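For example, a helper along these lines would let me swap the vmaps for sequential maps. This is purely a hypothetical sketch (the name map_with_in_axes is mine, and it only handles in_axes entries of 0 or None):

import jax

def map_with_in_axes(f, in_axes, *args):
  # Sequentially map f over axis 0 of the args whose in_axes entry is 0,
  # closing over the args whose entry is None.
  mapped = tuple(a for a, ax in zip(args, in_axes) if ax == 0)
  def apply_f(slices):
    it = iter(slices)
    full = tuple(next(it) if ax == 0 else a for a, ax in zip(args, in_axes))
    return f(*full)
  return jax.lax.map(apply_f, mapped)

With that, the outer vmap in the code below could in principle become map_with_in_axes(inner, (0, 0, 0, None, None, None), *arrays), trading speed for memory.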

Here is a basic implementation of the naive strategy in terms of jax.vmap and jax.lax.scan, specialized to the formula 'ij,ik,il,jk,jl,kl->ij'.

import jax
import jax.numpy as jnp
import time

def inner_einsum(*arrays):
  # Computes the einsum ',k,l,k,l,kl->': A is a scalar, B, C, D, E are
  # vectors, and F is a matrix. Does not create any intermediate arrays.

  A, B, C, D, E, F = arrays
  K, L = B.size, C.size

  def sum_over_k(partial_k, k):
    def sum_over_l(partial_l, l):
      return partial_l + C[l] * E[l] * F[k, l], ()
    inner = jax.lax.scan(sum_over_l, 0.0, jnp.arange(L))[0]
    return partial_k + B[k] * D[k] * inner, ()

  return A * jax.lax.scan(sum_over_k, 0.0, jnp.arange(K))[0]


@jax.jit
def vmap_einsum(*arrays):
  # computes einsum for ij,ik,il,jk,jl,kl->ij naively
  # No memory overhead.  Vectorized across output cells.

  return jax.vmap(
      jax.vmap(inner_einsum, in_axes=(0, None, None, 0, 0, None)),
      in_axes=(0, 0, 0, None, None, None)
  )(*arrays)

@jax.jit
def default_einsum(*arrays):
  return jnp.einsum('ij,ik,il,jk,jl,kl->ij', *arrays)
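The timing loop (reconstructed here from the traceback further down; the original colab cell isn't shown, but it followed this warm-up-then-time pattern) was essentially:

import time

EINSUM_IMPLS = [vmap_einsum, default_einsum]

for n in [128, 256, 512, 1024]:
  print(f'n={n}')
  arrays = [jax.random.normal(jax.random.PRNGKey(0), (n, n)) for _ in range(6)]

  for einsum_fn in EINSUM_IMPLS:
    jax.block_until_ready(einsum_fn(*arrays))  # warm up / compile
    t0 = time.time()
    jax.block_until_ready(einsum_fn(*arrays))
    print(einsum_fn.__name__, time.time() - t0)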

When I benchmark it using n x n arrays for n in [128, 256, 512, 1024], here is what I get for timing information (measured in seconds, not counting JIT compilation). The story is that jnp.einsum is faster up to n=512 but fails at n=1024, while the naive approach implemented above still runs, though it takes more time than I'd like.

n=128
vmap_einsum 0.14367246627807617
default_einsum 0.002198457717895508

n=256
vmap_einsum 0.7639327049255371
default_einsum 0.017670154571533203

n=512
vmap_einsum 4.290320158004761
default_einsum 0.24642205238342285

n=1024
vmap_einsum 35.70246410369873
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
<ipython-input-5-cc97c95bcc5f> in <cell line: 0>()
      4 
      5   for einsum_fn in EINSUM_IMPLS:
----> 6     jax.block_until_ready(einsum_fn(*arrays))
      7     t0 = time.time()
      8     jax.block_until_ready(einsum_fn(*arrays))

    [... skipping hidden 5 frame]

Here's another implementation one can throw into the mix: scan_einsum, where we strip out a non-output axis and sequentially compute and add up the resulting smaller einsums, as follows:

@jax.jit
def scan_einsum(*arrays):
  # we will scan over k and build up a running sum

  A, B, C, D, E, F = arrays
  K = B.shape[1]
  zeros = jnp.zeros(A.shape)

  def add_small_einsum(partial, k):
    # einsum with k stripped out: 'ij,i,il,j,jl,l->ij'
    return partial + jnp.einsum('ij,i,il,j,jl,l->ij', A, B[:, k], C, D[:, k], E, F[k, :]), ()

  return jax.lax.scan(add_small_einsum, zeros, jnp.arange(K))[0]

Benchmarks show that this is significantly better than the vmap_einsum above, and it's even better than jnp.einsum beyond n=256:

n=128
vmap_einsum 0.13236498832702637
scan_einsum 0.0034575462341308594
default_einsum 0.0014224052429199219

n=256
vmap_einsum 0.7413990497589111
scan_einsum 0.011484861373901367
default_einsum 0.018535137176513672

n=512
vmap_einsum 4.2713000774383545
scan_einsum 0.04682159423828125
default_einsum 0.23777055740356445

n=1024
vmap_einsum 35.49849033355713
scan_einsum 0.47335124015808105
default_einsum XlaRuntimeError

If anyone is interested, I typed up this exploration on my blog:

https://www.ryanhmckenna.com/2024/11/exploring-multi-input-einsums-in-jax.html

Thanks for exploring this – are you running benchmarks on GPU/TPU as well, or just CPU? The reason I ask is that scan has a pretty big performance penalty on accelerators (essentially each iteration is its own kernel launch) so I expect any efficiency gains on CPU will not transfer to GPU or TPU.

These tests were done in a Colab sandbox with a GPU; happy to do some more benchmarking if there's something specific you'd like to see.

OK, thanks.

Overall, I tend to be -1 on changes like this. It greatly complicates things on the JAX side in order to make up for deficiencies in the compiler. The compiler behavior may be improved in the future, at which point we would needlessly be generating more complicated code with no clear way of alerting ourselves that this is the case.

Is this a compiler deficiency, though? My understanding is that it's a JAX implementation choice that leads to this behavior, specifically https://github.com/jax-ml/jax/blob/main/jax/_src/numpy/lax_numpy.py#L9773, which implements einsum in terms of the dot_general primitive, which I believe means the einsum is calculated as a sequence of pairwise contractions. Even if the compiler were better at dot_general, it wouldn't get around the intractability of storing the required n^3-sized intermediates in this case.
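For what it's worth, this is visible without executing anything (using the formula and arrays from the top of the thread); I'd expect the "Largest intermediate" line for JAX's default path to be on the order of n^3 elements:

# Inspect the pairwise path and its reported intermediate sizes
# without actually running the contraction:
print(jnp.einsum_path(formula, *arrays)[1])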

Happy to keep this alternate implementation local to where I need it, though, to keep the JAX implementation simpler.

The compiler often fuses sequences of operations into single kernels to avoid storing intermediates. There may already be fusion paths for sequences of dot_general in some situations, but I'm not sure. scan is a much less specific primitive than dot_general, so emitting scan would hamper the compiler's ability to make such optimizations in the future.

I'm not saying your code is not useful; I think the approach probably makes sense in some situations. I just don't think it's a good fit for JAX's einsum implementation. (If @mattjj disagrees though, I'm happy to defer to his judgment here).

Ah, I see, that makes sense. Do you think I should open an issue at https://github.com/openxla/xla in that case?