Incorrect computation of `cost_correction` matrix in `ot.da.EMDTransport`
Describe the bug
It seems that the `cost_correction` matrix is computed incorrectly. This is the current code that can be found here:
# labels_match is a (ns, nt) matrix of {True, False} such that
# the cells (i, j) has False if ys[i] != yt[i]
label_match = (ys[:, None] - yt[None, :]) != 0
# cost correction is a (ns, nt) matrix of {-Inf, float, Inf} such
# that he cells (i, j) has -Inf where there's no correction necessary
# by 'correction' we mean setting cost to a large value when
# labels do not match
# ...
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    cost_correction = label_match * missing_labels * self.limit_max
The issues:
- First, the comment says that `label_match` is False if `ys[i] != yt[i]`. However, if the labels differ, i.e. `ys[i] != yt[j]` for the cell (i, j), then `(ys[:, None] - yt[None, :]) != 0` will be True, hence `label_match` will be True although the labels do not match (the naming is confusing in this case). Therefore, either
  - the variable should be named `label_mismatch` and the comment should be fixed, OR
  - we check for equality, `label_match = (ys[:, None] - yt[None, :]) == 0`, and flip the value in `cost_correction`, i.e. `cost_correction = (1 - label_match) * ...`
- Second, `cost_correction = label_match * missing_labels * self.limit_max` will apply a cost correction only if `missing_labels` is True. However, it must not correct if `missing_labels` is True. Hence, we need to flip it to `... * (1 - missing_labels) * ...`
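Both points can be checked on a tiny example. The following is a minimal NumPy sketch, not the POT code itself: `limit_max` is a finite stand-in for `self.limit_max`, and `missing_labels` is built here under the assumption (taken from the reading above) that it is nonzero where a label is missing, so it is all zeros when every label is given.

```python
import numpy as np

ys = np.array([1, 1, 2])  # source labels, all given
yt = np.array([1, 2])     # target labels, all given
limit_max = 10.0          # finite stand-in for self.limit_max (assumption)

# Despite its name, this is True exactly where the labels differ.
label_match = (ys[:, None] - yt[None, :]) != 0

# Assumption: missing_labels[i, j] is nonzero when a label is missing,
# so with all labels given it is zero everywhere.
missing_labels = np.zeros((ys.size, yt.size))

# Current formula: the zero mask wipes out every correction.
current = label_match * missing_labels * limit_max

# Formula with the mask flipped: mismatched pairs get the large cost.
proposed = label_match * (1 - missing_labels) * limit_max

print(label_match)
print(current)
print(proposed)
```

Under these assumptions `current` is zero everywhere, while `proposed` is `limit_max` on exactly the cells where the source and target labels differ.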
Therefore, I'd propose the following change:
# label_mismatch is a (ns, nt) matrix of {True, False} such that
# the cell (i, j) is True if ys[i] != yt[j]
label_mismatch = (ys[:, None] - yt[None, :]) != 0
# cost correction is a (ns, nt) matrix of {-Inf, float, Inf} such
# that the cell (i, j) has -Inf where there's no correction necessary
# by 'correction' we mean setting cost to a large value when
# labels do not match
# ...
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    cost_correction = label_mismatch * (1 - missing_labels) * self.limit_max
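To illustrate why the snippet suppresses `RuntimeWarning` and how a `-Inf` entry can mean "no correction", here is a self-contained sketch with an infinite limit. Several pieces are assumptions that are not shown in the issue: missing labels are marked with `-1`, `missing_labels` is built from that marker, NaN entries (from `0 * inf`) are mapped to `-inf`, and the correction is combined with the cost by an elementwise maximum.

```python
import warnings
import numpy as np

ys = np.array([1, 2, -1])  # -1 marks a missing source label (assumed convention)
yt = np.array([1, 2])
limit_max = np.inf

label_mismatch = (ys[:, None] - yt[None, :]) != 0
# Assumption: 1.0 where either label of the pair (i, j) is missing.
missing_labels = ((ys[:, None] == -1) | (yt[None, :] == -1)).astype(float)

with warnings.catch_warnings():
    # 0 * inf produces NaN and would otherwise emit a RuntimeWarning.
    warnings.simplefilter('ignore', category=RuntimeWarning)
    cost_correction = label_mismatch * (1 - missing_labels) * limit_max

# Map NaN to -inf so that uncorrected cells never override the cost.
cost_correction = np.where(np.isnan(cost_correction), -np.inf, cost_correction)

cost = np.array([[0.1, 0.9],
                 [0.8, 0.2],
                 [0.5, 0.5]])
# Assumed combination step: keep the larger of cost and correction.
corrected = np.maximum(cost, cost_correction)
print(corrected)
```

In this sketch the pairs with known, differing labels end up with infinite cost, while the row with the missing label keeps its original costs.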
Happy to send the corresponding PR if you agree.
Screenshots
The following screenshots show the effect of flipping the `missing_labels` value. Here we map samples across multiple Gaussian distributions with 2 labels (p = 1 and p = 2). All labels are given. Without the fix, the transport plans are not computed correctly. With the fix, only samples from the same target class are linked.
Environment (please complete the following information):
Linux-4.18.0-372.75.1.el8_6.x86_64-x86_64-with-glibc2.28
Python 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:23:07) [GCC 12.3.0]
NumPy 2.0.0
SciPy 1.14.0
POT 0.9.4 (pip installed)
Hello @martinrohbeck, and thanks for the issue. I think this is indeed a bug, especially the `missing_labels` weight. I am curious about the input of @kachayev, who wrote the original code if I remember correctly, but I think we would welcome a PR with your two fixes.
Hi @martinrohbeck,
Thanks for the report! Your suggestion sounds reasonable to me. There are a couple of test cases in the test suite designed to verify that the vectorized version of the algorithm produces the same results as the previous version of the code. If you find that these tests don’t fail while working on the PR, it would indicate that the discrepancy was introduced during the vectorization process. Otherwise, it would be worth revisiting the logic in the older code. Either way, I’ll be glad to review the PR.
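A check of this kind could look roughly like the sketch below, which compares a plain double loop against a vectorized expression. The function names, the `-1` marker for missing labels, and the finite `limit_max` are all hypothetical; this is not the code from the POT test suite, only an illustration of the loop-versus-vectorized comparison.

```python
import numpy as np

def correction_loop(ys, yt, limit_max, missing=-1):
    """Reference: large cost only where both labels are known and differ."""
    out = np.zeros((len(ys), len(yt)))
    for i, s in enumerate(ys):
        for j, t in enumerate(yt):
            if s != missing and t != missing and s != t:
                out[i, j] = limit_max
    return out

def correction_vectorized(ys, yt, limit_max, missing=-1):
    """Vectorized form following the proposed formula."""
    label_mismatch = (ys[:, None] - yt[None, :]) != 0
    # Assumption: 1.0 where either label of the pair is missing.
    missing_labels = ((ys[:, None] == missing) | (yt[None, :] == missing)).astype(float)
    return label_mismatch * (1 - missing_labels) * limit_max

rng = np.random.default_rng(0)
ys = rng.integers(-1, 3, size=8)  # labels in {-1, 0, 1, 2}
yt = rng.integers(-1, 3, size=6)
same = np.array_equal(correction_loop(ys, yt, 10.0),
                      correction_vectorized(ys, yt, 10.0))
print(same)
```

If the two versions disagree on some input, the differing cells show which label combination is handled differently.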