Sub-optimal contraction path when using broadcasting
pimdh commented
Hi,
I'm not sure whether an immediate solution is possible, but opt_einsum appears to resolve broadcasting first and only then optimize the contraction path. This leads to sub-optimal paths:
import opt_einsum
print(opt_einsum.__version__)
print(opt_einsum.contract_path("ijk,bj,bk->bi", (32, 32, 32), (10000, 32), (1, 32), optimize="optimal", shapes=True))
print(opt_einsum.contract_path("ijk,bj,k->bi", (32, 32, 32), (10000, 32), (32,), optimize="optimal", shapes=True))
This gives:
v3.3.0+24.g1a984b7
([(1, 2), (0, 1)], Complete contraction: ijk,bj,bk->bi
Naive scaling: 4
Optimized scaling: 4
Naive FLOP count: 9.830e+8
Optimized FLOP count: 6.656e+8
Theoretical speedup: 1.477e+0
Largest intermediate: 1.024e+7 elements
--------------------------------------------------------------------------------
scaling   BLAS   current        remaining
--------------------------------------------------------------------------------
3         0      bk,bj->bkj     ijk,bkj->bi
4         TDOT   bkj,ijk->bi    bi->bi)
([(0, 2), (0, 1)], Complete contraction: ijk,bj,k->bi
Naive scaling: 4
Optimized scaling: 3
Naive FLOP count: 9.830e+8
Optimized FLOP count: 2.055e+7
Theoretical speedup: 4.785e+1
Largest intermediate: 3.200e+5 elements
--------------------------------------------------------------------------------
scaling   BLAS   current        remaining
--------------------------------------------------------------------------------
3         GEMM   k,ijk->ij      bj,ij->bi
3         GEMM   ij,bj->bi      bi->bi)
In the first case, the third tensor is broadcast to shape (b, 32), and the optimizer then decides it is best to contract the last two tensors first, which creates a 1.024e+7-element intermediate. Ideally, we'd strip the size-1 dimension that would be broadcast from the third tensor before searching for a path. That allows a much cheaper contraction (2.055e+7 instead of 6.656e+8 FLOPs), as the second case shows.
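To make the stripping step concrete, here is a minimal sketch in plain Python. The helper `squeeze_broadcast_dims` is hypothetical (it is not part of opt_einsum), and it assumes an explicit `->` output, no ellipsis, and shapes rather than arrays as inputs:

```python
def squeeze_broadcast_dims(subscripts, *shapes):
    """Drop size-1 dimensions that would only be broadcast.

    Hypothetical helper, not an opt_einsum API. Assumes an explicit
    "->" in the equation and no ellipsis.
    """
    inputs, output = subscripts.split("->")
    terms = inputs.split(",")
    if len(terms) != len(shapes):
        raise ValueError("number of terms and shapes differ")

    # Largest extent seen for each index across all operands.
    extent = {}
    for term, shape in zip(terms, shapes):
        if len(term) != len(shape):
            raise ValueError(f"term {term!r} does not match shape {shape}")
        for idx, dim in zip(term, shape):
            extent[idx] = max(extent.get(idx, 1), dim)

    new_terms, new_shapes = [], []
    for term, shape in zip(terms, shapes):
        # Keep a dimension unless it is size 1 here but larger elsewhere,
        # i.e. unless it would be broadcast.
        kept = [
            (idx, dim)
            for idx, dim in zip(term, shape)
            if not (dim == 1 and extent[idx] > 1)
        ]
        new_terms.append("".join(idx for idx, _ in kept))
        new_shapes.append(tuple(dim for _, dim in kept))

    return ",".join(new_terms) + "->" + output, new_shapes


eq, shapes = squeeze_broadcast_dims(
    "ijk,bj,bk->bi", (32, 32, 32), (10000, 32), (1, 32)
)
print(eq)      # ijk,bj,k->bi
print(shapes)  # [(32, 32, 32), (10000, 32), (32,)]
```

The rewritten equation and shapes are exactly the inputs of the second `contract_path` call above. With real arrays, the corresponding size-1 axes would also have to be squeezed out of the operands before contracting.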
Any ideas on how this could be addressed? I understand that this involves more than just choosing a contraction path, so it might not be solvable within this library.
Thanks!