PythonOT/POT

UnbalancedSinkhornTransport.transform fails at "nx.array_equal"

Describe the bug

UnbalancedSinkhornTransport appears to fail on .transform() calls. The error is raised at the nx.array_equal(self.xs_, Xs) call, which suggests that the transporter never sets nx (or self.nx).
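
For context, the nx.array_equal(self.xs_, Xs) check presumably decides whether transform was called on the same source samples that fit saw. The NumPy-only sketch below illustrates that kind of dispatch, assuming a barycentric mapping for the fitted samples and a nearest-neighbour displacement for new samples. transform_sketch and sqeuclidean are hypothetical helpers, and coupling stands in for whatever plan fit stores. This is an illustration, not POT's code.

```python
import numpy as np


def sqeuclidean(X, Y):
    """Pairwise squared Euclidean distances between the rows of X and Y."""
    return ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)


def transform_sketch(coupling, xs_fit, xt_fit, X):
    """Hypothetical helper (not POT's API): transport X through a fitted plan.

    coupling: (ns, nt) transport plan computed during fit
    xs_fit:   (ns, d) source samples passed to fit
    xt_fit:   (nt, d) target samples passed to fit
    X:        (n, d) samples to transport
    """
    # Barycentric mapping: send each fitted source point to the
    # coupling-weighted mean of the target points.
    with np.errstate(divide="ignore", invalid="ignore"):
        transp = coupling / coupling.sum(axis=1, keepdims=True)
    transp[~np.isfinite(transp)] = 0.0  # rows with zero mass map to the origin
    mapped_fit = transp @ xt_fit

    if np.array_equal(xs_fit, X):
        # Same samples as in fit: return their barycentric images directly.
        return mapped_fit

    # New samples: reuse the displacement of the nearest fitted source point.
    idx = np.argmin(sqeuclidean(X, xs_fit), axis=1)
    return mapped_fit[idx] + X - xs_fit[idx]


# Toy check: a diagonal plan sends each source point to its shifted copy.
xs = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
xt = xs + 0.5
plan = np.eye(4) / 4
print(np.allclose(transform_sketch(plan, xs, xt, xs), xt))              # True
print(np.allclose(transform_sketch(plan, xs, xt, xs + 0.1), xs + 0.6))  # True
```

Testing array equality rather than object identity means that a copy of the fitted samples still takes the barycentric branch in this sketch.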

To Reproduce

  1. Initialize an UnbalancedSinkhornTransport object
  2. Call fit on the source and target samples
  3. Call transform on the source samples in a separate call
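
Step 2 (fit) is presumably where the unbalanced coupling is computed. For readers unfamiliar with the method, here is a minimal NumPy sketch of entropic OT with KL-relaxed marginals, solved with the generalized Sinkhorn scaling updates u = (a / Kv)^fi and v = (b / K^T u)^fi, where fi = reg_m / (reg_m + reg). The function name, the fixed iteration count and the parameter values are illustrative assumptions; this is not POT's implementation, and its defaults may differ.

```python
import numpy as np


def unbalanced_sinkhorn_sketch(a, b, M, reg, reg_m, n_iter=1000):
    """Illustrative sketch (not POT's implementation) of entropic OT with
    KL-relaxed marginals, solved with generalized Sinkhorn scaling updates.

    a, b:  (ns,), (nt,) source/target weights; totals need not match
    M:     (ns, nt) cost matrix
    reg:   entropic regularization strength
    reg_m: strength of the KL penalty on each marginal
    """
    K = np.exp(-M / reg)
    fi = reg_m / (reg_m + reg)  # fi < 1 relaxes the marginal constraints
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):  # fixed iteration count; no convergence test
        u = (a / (K @ v)) ** fi
        v = (b / (K.T @ u)) ** fi
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
xs = rng.uniform(size=(5, 2))
xt = rng.uniform(size=(5, 2))
M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean cost
a = np.full(5, 1 / 5)
b = np.full(5, 1 / 5)
G = unbalanced_sinkhorn_sketch(a, b, M, reg=1.0, reg_m=0.1)
print(G.shape)        # (5, 5)
print(G.sum(axis=1))  # relaxed row marginals, not forced to equal a
```

As reg_m grows, fi approaches 1 and the updates reduce to balanced Sinkhorn, so the marginals of the plan approach a and b.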

Code sample

    import ot
    from ot.datasets import make_2D_samples_gauss

    OT = ot.da.UnbalancedSinkhornTransport()

    # 1000 2-D Gaussian source samples, cast to float32
    Xs = make_2D_samples_gauss(n=1000, m=10, sigma=[[2, 1], [1, 2]], random_state=42)
    Xs = Xs.astype('float32')
    # Target samples: the source samples shifted by 0.5
    Xt = (Xs + 0.5).astype('float32')

    # fit takes (Xs, ys, Xt, yt) in that order, so this positional call passes Xt as ys
    OT.fit(Xs, Xt)
    OT.transform(Xs)  # fails at nx.array_equal(self.xs_, Xs)

Expected behavior

transform should return the transported source samples Xs.

Environment

  • OS: Windows/Linux
  • Python version: 3.11
  • POT installed via: pip

Version information was collected with:

    import platform; print(platform.platform())
    import sys; print("Python", sys.version)
    import numpy; print("NumPy", numpy.__version__)
    import scipy; print("SciPy", scipy.__version__)
    import ot; print("POT", ot.__version__)

Output:

    Linux-5.15.0-1061-realtime-x86_64-with-glibc2.35
    Python 3.11.3 (main, May 15 2023, 15:45:52) [GCC 11.2.0]
    NumPy 1.26.4
    SciPy 1.10.1
    POT 0.9.3

Additional context

Reproduced on POT 0.9.3 and 0.9.4.