`ot.emd2()` does not work as expected with empty weights if the JAX backend is used
Closed this issue · 4 comments
Describe the bug
According to the documentation of `ot.emd2()`, uniform weights are used when empty lists are passed as the weight arguments `a` and `b`. However, doing so with the JAX backend causes a broadcasting error.
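The documented behaviour amounts to a fallback like the one below. This is a minimal sketch, not POT's actual code. `resolve_weights` is a hypothetical helper, and `xp` stands for whichever array namespace (`numpy` or `jax.numpy`) the cost matrix comes from. Building the uniform weights with that same namespace keeps all arrays on one backend.

```python
import numpy as np


def resolve_weights(a, n, xp=np):
    """Return `a` as an array, or uniform weights of length `n` if `a` is empty.

    Hypothetical helper (not part of POT) mirroring the documented
    behaviour of ot.emd2(): empty weights mean "uniform".
    `xp` is the array namespace, e.g. numpy or jax.numpy.
    """
    a = xp.asarray(a)
    if a.size == 0:
        # Empty input: fall back to 1/n for every sample, created with the
        # same array namespace so numpy and JAX arrays are never mixed.
        return xp.full((n,), 1.0 / n)
    return a


M = np.random.default_rng(0).random((4, 3))
a = resolve_weights([], M.shape[0])           # plain empty Python list
b = resolve_weights(np.empty(0), M.shape[1])  # empty array
print(a, b)
```

With JAX one would call `resolve_weights(emt, M.shape[0], xp=jnp)` so that the weights come out as JAX arrays like `M`.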
To Reproduce
Simulate some data first:
```python
import jax
from jax import numpy as jnp
import numpy as np
import ot

key = jax.random.PRNGKey(1)
x = jax.random.normal(key, (100, 2))
y = jax.random.normal(key, (100, 2))
```
With the NumPy backend, the following works without issues:
```python
from opt_einsum import contract

M = contract('mi,ni->mn', x, y, backend='numpy') ** 2.
emt = np.empty((0))
Wass_dis = ot.emd2(emt, emt, M=M)
Wass_dis
```
However, an error occurs once we switch to `jnp`:
```python
M = contract('mi,ni->mn', x, y, backend='jax') ** 2.
emt = jnp.empty((0))
Wass_dis = ot.emd2(emt, emt, M=M)
Wass_dis
```
Partial error message:
```
File c:\ProgramData\anaconda3\Lib\site-packages\ot\lp\__init__.py:567, in emd2.<locals>.f(b)
559 warnings.warn(
560 "Input histogram consists of integer. The transport plan will be "
561 "casted accordingly, possibly resulting in a loss of precision. "
(...)
564 stacklevel=2
565 )
566 G = nx.from_numpy(G, type_as=type_as)
--> 567 cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
568 (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as),
569 nx.from_numpy(v - np.mean(v), type_as=type_as), G))
571 check_result(result_code)
572 return cost
File c:\ProgramData\anaconda3\Lib\site-packages\ot\backend.py:1392, in JaxBackend.set_gradients(self, val, inputs, grads)
1389 ravelled_inputs, _ = ravel_pytree(inputs)
1390 ravelled_grads, _ = ravel_pytree(grads)
-> 1392 aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2
1393 aux = aux - jax.lax.stop_gradient(aux)
1395 val, = jax.tree_map(lambda z: z + aux, (val,))
File c:\ProgramData\anaconda3\Lib\site-packages\jax\_src\numpy\array_methods.py:256, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
254 args = (other, self) if swap else (self, other)
255 if isinstance(other, _accepted_binop_types):
--> 256 return binary_op(*args)
257 # Note: don't use isinstance here, because we don't want to raise for
258 # subclasses, e.g. NamedTuple objects that may override operators.
259 if type(other) in _rejected_binop_types:
[... skipping hidden 12 frame]
File c:\ProgramData\anaconda3\Lib\site-packages\jax\_src\numpy\ufuncs.py:97, in _maybe_bool_binop.<locals>.fn(x1, x2)
95 def fn(x1, x2, /):
96 x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
---> 97 return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
[... skipping hidden 7 frame]
File c:\ProgramData\anaconda3\Lib\site-packages\jax\_src\lax\lax.py:1591, in broadcasting_shape_rule(name, *avals)
1589 result_shape.append(non_1s[0])
1590 else:
-> 1591 raise TypeError(f'{name} got incompatible shapes for broadcasting: '
1592 f'{", ".join(map(str, map(tuple, shapes)))}.')
1594 return tuple(result_shape)
TypeError: mul got incompatible shapes for broadcasting: (10000,), (10200,).
```
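The two sizes in the error can be accounted for from the traceback alone. `set_gradients` ravels the original inputs `(a0, b0, M0)` and the gradients `(u - mean(u), v - mean(v), G)` into flat vectors and multiplies them elementwise. Assuming `a0` and `b0` are still the empty arrays passed by the caller while the dual potentials `u` and `v` have one entry per sample (100 each), the lengths are 0 + 0 + 10000 = 10000 and 100 + 100 + 10000 = 10200. The sketch below mimics `ravel_pytree` with NumPy concatenation so it runs without JAX; the zero-filled arrays are stand-ins whose values do not matter, only their shapes.

```python
import numpy as np

n, m = 100, 100

# Stand-ins for the arrays seen in JaxBackend.set_gradients (shapes only).
a0 = np.empty(0)        # empty source weights, as passed by the caller
b0 = np.empty(0)        # empty target weights, as passed by the caller
M0 = np.zeros((n, m))   # 100 x 100 cost matrix
u = np.zeros(n)         # dual potential for the (uniform) source weights
v = np.zeros(m)         # dual potential for the (uniform) target weights
G = np.zeros((n, m))    # transport plan


def ravel(arrays):
    """Flatten each array and concatenate, like ravel_pytree on a tuple of arrays."""
    return np.concatenate([np.ravel(arr) for arr in arrays])


ravelled_inputs = ravel((a0, b0, M0))
ravelled_grads = ravel((u - u.mean(), v - v.mean(), G))
print(ravelled_inputs.shape, ravelled_grads.shape)  # (10000,) (10200,)
```

Multiplying these two vectors elementwise is exactly the operation that fails in the traceback above.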
Possible solution:
This problem can be avoided by constructing the uniform weights explicitly:
```python
M = contract('mi,ni->mn', x, y, backend='jax') ** 2.
emt0 = jnp.ones((M.shape[0],)) / M.shape[0]
emt1 = jnp.ones((M.shape[1],)) / M.shape[1]
Wass_dis = ot.emd2(emt0, emt1, M=M)
Wass_dis  # correct result
```
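One way to sanity-check such a result independently of POT is to use a known property of uniform weights. With a square cost matrix and a = b = 1/n, an optimal plan can be taken to be a permutation matrix divided by n (Birkhoff-von Neumann), so the EMD cost equals the smallest mean matched cost over all permutations. The brute-force sketch below is only practical for tiny n. For the 100 x 100 case above, `scipy.optimize.linear_sum_assignment` solves the same assignment problem, giving the cost `M[row_ind, col_ind].mean()`. `emd2_uniform_bruteforce` is a hypothetical name, not a POT function.

```python
import itertools

import numpy as np


def emd2_uniform_bruteforce(M):
    """Exact OT cost for uniform weights and a square cost matrix M.

    With a = b = 1/n, an optimal plan is a permutation matrix divided by n,
    so the cost is the minimum over permutations of the mean matched cost.
    Enumerates all n! permutations: only use for very small n.
    """
    M = np.asarray(M, dtype=float)
    n = M.shape[0]
    if M.shape != (n, n):
        raise ValueError("cost matrix must be square")
    rows = np.arange(n)
    return min(M[rows, list(perm)].mean() for perm in itertools.permutations(range(n)))


M_small = np.array([[4.0, 1.0],
                    [2.0, 3.0]])
print(emd2_uniform_bruteforce(M_small))  # 1.5 (swap assignment: (1 + 2) / 2)
```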
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Windows
- Python version: 3.11.4
- How was POT installed (source, `pip`, `conda`): pip
Output of the following code snippet:
```python
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
```
```
Windows-10-10.0.22621-SP0
Python 3.11.4 | packaged by Anaconda, Inc. | (main, Jul 5 2023, 13:38:37) [MSC v.1916 64 bit (AMD64)]
NumPy 1.24.3
SciPy 1.10.1
POT 0.9.1
```
Hello @Francis-Hsu, and thanks for the feedback. Could you quickly check whether the bug also occurs when you provide an actual empty Python list (`a=[]`) instead of empty JAX arrays? Unless I'm mistaken, the documentation states "empty list", and the function should handle that correctly for any backend.
Also note that in the new API the weights are optional, so there is no need for empty lists:

```python
Wass_dis = ot.solve(M).value
```
Hi @rflamary, thank you for the feedback. If I use `ot.emd2([], [], M=M)`, I get a type-checking error:

```
ValueError: All array should be from the same type/backend. Current types are : [<class 'jaxlib.xla_extension.ArrayImpl'>, <class 'numpy.ndarray'>, <class 'numpy.ndarray'>]
```
But indeed, the `ot.solve(M)` interface is much more convenient. I didn't know about it until now :P