jax-ml/jax

Small Einsum is hanging


Description

Here's a small case where np.einsum returns a result but jnp.einsum never completes:

```python
import numpy as np
import jax.numpy as jnp

formula = 'a,c,d,db,ab,cb,ac,cd,ad,b->dbc'

# One random operand per input term; every index has size 2.
arrays = [np.random.rand(*(2,)*len(key)) for key in formula.split('->')[0].split(',')]

np.einsum(formula, *arrays)
# array([[[6.26532636e-05, 9.94054312e-04],
#         [3.24902199e-05, 2.90052489e-03]],
#
#        [[1.21862902e-05, 9.85561040e-05],
#         [2.81959491e-06, 1.77314102e-04]]])

jnp.einsum(formula, *arrays)  # this hangs and does not complete
```

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.33
jaxlib: 0.4.33
numpy: 1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='849fd340451c', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')

Thanks for the report! It looks like a bug in opt_einsum, which jnp.einsum uses to choose its contraction path. Here's a more direct reproduction of the issue:

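```python
import opt_einsum

# `formula` and `arrays` are defined in the reproduction above.
# This path search hangs without going through JAX at all.
opt_einsum.contract_path(
    formula, *arrays, einsum_call=True, use_blas=True, optimize='optimal')
```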

I think it would be worth reporting upstream. Would you like to report the issue there, or would you like us to take over?

Ah, interesting. In that case I can unblock myself immediately by changing the optimize kwarg for now. I went ahead and reported it upstream: dgasmith/opt_einsum#243
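A minimal sketch of that workaround, assuming a JAX version whose `jnp.einsum` forwards its `optimize` argument to opt_einsum's path search (as 0.4.33 does). The thread does not say which value was used; `'greedy'` here is just one of opt_einsum's cheaper path strategies, and the seeded RNG is an addition so the check is reproducible.

```python
import numpy as np
import jax.numpy as jnp

formula = 'a,c,d,db,ab,cb,ac,cd,ad,b->dbc'

# Same operand shapes as the report, but seeded for reproducibility.
rng = np.random.default_rng(0)
inputs = formula.split('->')[0].split(',')
arrays = [rng.random((2,) * len(key)) for key in inputs]

expected = np.einsum(formula, *arrays)

# Request a greedy contraction path instead of the exhaustive
# 'optimal' search that hangs on this 10-operand expression.
result = jnp.einsum(formula, *arrays, optimize='greedy')

# JAX computes in float32 by default, so compare with a loose tolerance.
np.testing.assert_allclose(np.asarray(result), expected, rtol=1e-4)
print(result.shape)  # (2, 2, 2)
```

The path strategy only affects the order of pairwise contractions, not the mathematical result, so comparing against `np.einsum` is a reasonable sanity check.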

According to dgasmith/opt_einsum#243, optimize='auto' might be a preferable default. As far as I understand, 'auto' uses 'optimal' when the number of operands is small and switches to a cheaper strategy when an exhaustive search would not finish in a reasonable amount of time.
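For a rough sense of why an exhaustive search gets expensive as operands are added, the sketch below counts the ordered sequences of pairwise contractions that reduce n operands to one: at each step any 2 of the k remaining operands can be merged, giving the product of C(k, 2) for k = n down to 2. This is a back-of-the-envelope illustration only. opt_einsum's 'optimal' strategy prunes its search and does not enumerate every sequence, and many sequences yield the same contraction tree. The formula in this issue has 10 operands.

```python
from math import comb, factorial


def num_pairwise_orders(n: int) -> int:
    """Count ordered sequences of pairwise merges that reduce n operands to one."""
    total = 1
    for k in range(n, 1, -1):
        total *= comb(k, 2)  # choose which 2 of the k remaining operands to merge
    return total


def closed_form(n: int) -> int:
    """Same count in closed form: n! * (n - 1)! / 2**(n - 1)."""
    return factorial(n) * factorial(n - 1) // 2 ** (n - 1)


for n in (3, 4, 6, 10):
    print(n, num_pairwise_orders(n))
# 3 3
# 4 18
# 6 2700
# 10 2571912000
```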

#25055 switches multi_dot to optimize='auto'; perhaps we should do the same for einsum.