Small Einsum is hanging
Description
Here's a small case where np.einsum works but jnp.einsum hangs:
import numpy as np
import jax.numpy as jnp
formula = 'a,c,d,db,ab,cb,ac,cd,ad,b->dbc'
arrays = [np.random.rand(*(2,)*len(key)) for key in formula.split('->')[0].split(',')]
np.einsum(formula, *arrays)
array([[[6.26532636e-05, 9.94054312e-04],
        [3.24902199e-05, 2.90052489e-03]],

       [[1.21862902e-05, 9.85561040e-05],
        [2.81959491e-06, 1.77314102e-04]]])
jnp.einsum(formula, *arrays) # this hangs and does not complete
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.33
jaxlib: 0.4.33
numpy: 1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='849fd340451c', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')
Thanks for the report! It looks like a bug in opt_einsum, which jnp.einsum uses. Here's a more direct reproduction of the issue:
import opt_einsum
# formula and arrays as defined in the example above
opt_einsum.contract_path(
    formula, *arrays, einsum_call=True, use_blas=True, optimize='optimal')
It would be worth reporting upstream, I think. Would you like to report the issue there, or would you like us to take over?
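For contrast, here is a minimal sketch that asks opt_einsum for a 'greedy' path on the same contraction instead of an 'optimal' one. It assumes a seeded generator (np.random.default_rng(0)) in place of np.random.rand so the run is reproducible; the seed is arbitrary. The returned path is then executed with opt_einsum.contract and checked against np.einsum.

```python
import numpy as np
import opt_einsum

# Same 10-operand contraction as in the report, with a fixed (arbitrary) seed.
formula = 'a,c,d,db,ab,cb,ac,cd,ad,b->dbc'
rng = np.random.default_rng(0)
arrays = [rng.random((2,) * len(key)) for key in formula.split('->')[0].split(',')]

# 'greedy' builds the path one locally cheap contraction at a time instead of
# searching over contraction orders, so it returns quickly on this input.
path, info = opt_einsum.contract_path(formula, *arrays, optimize='greedy')
print(path)  # list of operand-index tuples, one per contraction step
print(info)  # PathInfo: FLOP estimates, largest intermediate, per-step einsum

# Executing that explicit path should agree with NumPy's result.
result = opt_einsum.contract(formula, *arrays, optimize=path)
print(np.allclose(result, np.einsum(formula, *arrays)))
```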
Ah, I see, interesting. In that case I guess I can get unblocked immediately by changing the optimize kwarg for now. I went ahead and reported it upstream: dgasmith/opt_einsum#243
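A sketch of that workaround, assuming 'greedy' as the replacement strategy (any of opt_einsum's cheaper strategies would serve the same purpose) and a seeded generator instead of np.random.rand. Because JAX computes in float32 unless 64-bit mode is enabled, the comparison against NumPy's float64 result uses a loose relative tolerance.

```python
import numpy as np
import jax.numpy as jnp

# Same contraction as in the report; the seed is an arbitrary choice.
formula = 'a,c,d,db,ab,cb,ac,cd,ad,b->dbc'
rng = np.random.default_rng(0)
arrays = [rng.random((2,) * len(key)) for key in formula.split('->')[0].split(',')]

expected = np.einsum(formula, *arrays)

# Pass a cheaper path strategy through to opt_einsum instead of 'optimal',
# which is the setting that hangs on this 10-operand input.
result = jnp.einsum(formula, *arrays, optimize='greedy')

# float32 vs float64: compare with a relative tolerance.
print(np.allclose(np.asarray(result), expected, rtol=1e-4))
```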
According to dgasmith/opt_einsum#243, using the 'auto' path (optimize='auto') might be a preferable default. As far as I understand, 'auto' uses 'optimal' when the number of operands is small and switches to cheaper heuristics for larger numbers of operands, where an optimal search would not finish in a reasonable amount of time.
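To give a rough sense of scale, the sketch below counts the ways to contract n operands one pair at a time: with k operands left there are comb(k, 2) pairs to choose from, and each step leaves k - 1. This is only a naive bound on the search space; opt_einsum's 'optimal' strategy prunes branches whose cost already exceeds the best path found, so the numbers illustrate the growth rather than measure its actual work. The formula in this issue has 10 operands.

```python
from math import comb, prod


def num_pairwise_orders(n: int) -> int:
    """Count the orders in which n operands can be contracted pairwise.

    With k operands remaining there are comb(k, 2) candidate pairs, and each
    contraction reduces the count to k - 1, down to a single result.
    """
    return prod(comb(k, 2) for k in range(2, n + 1))


for n in (4, 6, 8, 10):
    print(n, num_pairwise_orders(n))
# 4 18
# 6 2700
# 8 1587600
# 10 2571912000
```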