UnbalancedSinkhornTransport fails to transform due to "nx.array_equal"
smestern opened this issue · 0 comments
Describe the bug
UnbalancedSinkhornTransport appears to fail on .transform() calls. The failure occurs at the nx.array_equal(self.xs_, Xs)
call, which suggests the transporter is not setting nx or self.nx.
To Reproduce
- Initialize an UnbalancedSinkhornTransport object
- Call fit on the samples
- Call transform on the source samples in a separate call
Code sample
import ot
from ot.datasets import make_2D_samples_gauss

OT = ot.da.UnbalancedSinkhornTransport()

# Sample source and target points from 2D Gaussians
Xs = make_2D_samples_gauss(n=1000, m=10, sigma=[[2, 1], [1, 2]], random_state=42)
Xt = make_2D_samples_gauss(n=1000, m=5, sigma=[[2, 1], [1, 2]], random_state=42)
Xs = Xs.astype('float32')
# Note: this overwrites the Xt sampled above with a shifted copy of Xs
Xt = Xs + 0.5
Xt = Xt.astype('float32')

OT.fit(Xs, Xt)
# This call raises the error
OT.transform(Xs)
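One thing that may be worth ruling out, offered as an assumption and not as a diagnosis: the fit methods in ot.da are documented with the argument order (Xs, ys, Xt, yt). If that order applies here, the positional call OT.fit(Xs, Xt) would bind the second array to ys and leave Xt as None. The sketch below uses a hypothetical stand-in function with that assumed signature (it is not POT code) to show how the two positional arguments would be bound:

```python
import inspect


def fit(Xs=None, ys=None, Xt=None, yt=None):
    """Hypothetical stand-in mirroring the argument order assumed for ot.da fit()."""


# Bind two positional arguments the same way OT.fit(Xs, Xt) would
bound = inspect.signature(fit).bind("Xs_array", "Xt_array")
bound.apply_defaults()

# Show which parameter each positional argument ended up in
for name, value in bound.arguments.items():
    print(f"{name} = {value!r}")
```

If the assumption holds, calling OT.fit(Xs=Xs, Xt=Xt) followed by OT.transform(Xs=Xs) with keyword arguments would show whether the error still occurs once the target samples are actually passed as Xt.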
Expected behavior
Transform should return the transported Xs sample.
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Windows/Linux
- Python version: 3.11
- How was POT installed (source, pip, conda): pip
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Linux-5.15.0-1061-realtime-x86_64-with-glibc2.35
Python 3.11.3 (main, May 15 2023, 15:45:52) [GCC 11.2.0]
NumPy 1.26.4
SciPy 1.10.1
POT 0.9.3
Additional context
Tested on POT 0.9.3 and 0.9.4.