PythonOT/POT

Typos in the formulas of gwggrad and solve_gromov_linesearch


Describe the bug

The formulas implemented in gwggrad and solve_gromov_linesearch contain typos and do not match the cited references [12] and [24]. I also derived the gradient by hand to confirm that POT's implementation has typos. (The decomposition from [12] that the corrected functions rely on is recalled after the citations below.)

For concreteness, I'm using the following versions of the cited papers:
[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
IN: https://proceedings.mlr.press/v48/peyre16.pdf

[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
IN: https://arxiv.org/pdf/1805.09114.pdf
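
For reference (my own transcription of [12], Proposition 1, so please double-check against the paper), the decomposition that init_matrix relies on writes the GW objective as

    E(T) = \sum_{i,j,k,l} L(C1_{ik}, C2_{jl}) \, T_{ij} T_{kl} = \langle \mathcal{L}(C1, C2) \otimes T, \; T \rangle,

with

    \mathcal{L}(C1, C2) \otimes T = c_{C1, C2} - h_1(C1) \, T \, h_2(C2)^\top,

and, for the square loss L(a, b) = (a - b)^2, the choice f_1(a) = a^2, f_2(b) = b^2, h_1(a) = a, h_2(b) = 2b.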

To Reproduce

  1. Run the code below.

Code sample

Notes:

  1. The code below is a modification of plot_gromov.py from the examples gallery. I computed the Gromov-Wasserstein distance twice: once with POT as-is and once with my corrections implemented in gwggrad_mod and solve_gromov_linesearch_mod.
  2. The typos did not affect the value of the Gromov-Wasserstein distance in this example, but I wonder whether sub-optimal choices in the line search could slow down convergence in more complicated calculations.

import scipy as sp
import numpy as np
import ot

# Import functions required in ot.gromov._gw
from ot.utils import list_to_array
from ot.optim import cg, solve_1d_linesearch_quad
from ot.backend import get_backend, NumpyBackend

from ot.gromov._utils import init_matrix, gwloss, gwggrad
from ot.gromov._gw import solve_gromov_linesearch

#############################################################################
#
# Sample two Gaussian distributions (2D and 3D)
# ---------------------------------------------
#############################################################################

n_samples = 30  # nb samples

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

mu_t = np.array([4, 4, 4])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

np.random.seed(0)
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t

C1 = sp.spatial.distance.cdist(xs, xs)
C2 = sp.spatial.distance.cdist(xt, xt)

C1 /= C1.max()
C2 /= C2.max()

#############################################################################
#
# Parameters for dGW
# ---------------------------------------------
#############################################################################

p = ot.unif(n_samples)
q = ot.unif(n_samples)
G0 = p[:, None] * q[None, :]

loss_fun='square_loss'
symmetric=None
log=True
armijo=False
max_iter=1e4
tol_rel=1e-9
tol_abs=1e-9

#############################################################################
# 
# gwggrad and solve_gromov_linesearch with typos corrected
# ---------------------------------------------
#############################################################################
def gwggrad_mod(constC, hC1, hC2, T, nx=None):
    if nx is None:
        constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T)
        nx = get_backend(constC, hC1, hC2, T)
    
    return constC - 2 * nx.dot( nx.dot(hC1, T), hC2.T )

def solve_gromov_linesearch_mod(G, deltaG, cost_G, constC, C1, C2, M, reg,
                            alpha_min=None, alpha_max=None, nx=None, **kwargs):
    if nx is None:
        G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)

        if isinstance(M, int) or isinstance(M, float):
            nx = get_backend(G, deltaG, C1, C2)
        else:
            nx = get_backend(G, deltaG, C1, C2, M)
    
    dot_dG = nx.dot(nx.dot(C1, deltaG), C2.T)
    dot_G  = nx.dot(nx.dot(C1, G     ), C2.T)
    
    a = -2 * reg * nx.sum(dot_dG * deltaG)
    b = nx.sum(M * deltaG) + reg * (nx.sum(constC * deltaG) - 2 * nx.sum(dot_dG * G) - 2 * nx.sum(dot_G * deltaG))

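    # The pair (a, b) parameterizes the 1D quadratic model
    #   cost(G + alpha * deltaG) ~= cost_G + b * alpha + a * alpha**2,
    # which solve_1d_linesearch_quad minimizes over alpha in [0, 1]
    # (my reading of ot.optim.solve_1d_linesearch_quad).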
    alpha = solve_1d_linesearch_quad(a, b)
    if alpha_min is not None or alpha_max is not None:
        alpha = np.clip(alpha, alpha_min, alpha_max)

    # the new cost is deduced from the line search quadratic function
    cost_G = cost_G + a * (alpha ** 2) + b * alpha

    return alpha, 1, cost_G

#############################################################################
#
# Compute Gromov-Wasserstein with modified functions
# ---------------------------------------------
#############################################################################
# cg for GW is implemented using numpy on CPU
np_ = NumpyBackend()

nx = get_backend(C1, C2, p, q)
p0, q0, C10, C20 = p, q, C1, C2
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_)
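# For 'square_loss', init_matrix should return hC1 = C1, hC2 = 2 * C2 and the
# constant term constC of the decomposition recalled above
# (my reading of ot.gromov._utils.init_matrix).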

######################################################################
# Define loss function, gradient and linesearch
# ---------------------------------------------
# NOTE: Using the modified gwggrad and line search
def f(G):
    return gwloss(constC, hC1, hC2, G, np_)

def df(G):
    return gwggrad_mod(constC, hC1, hC2, G, np_)

def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
    return solve_gromov_linesearch_mod(G, deltaG, cost_G, constC, C1, C2, M=0., reg=1., nx=np_, **kwargs)
######################################################################

res_mod, log_mod = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs)
log_mod['gw_dist'] = nx.from_numpy(log_mod['loss'][-1], type_as=C1)
log_mod['u'] = nx.from_numpy(log_mod['u'], type_as=C1)
log_mod['v'] = nx.from_numpy(log_mod['v'], type_as=C1)

gw_mod = nx.from_numpy(res_mod, type_as=C1)


# Compute GW with the original function
gw0, log0 = ot.gromov.gromov_wasserstein(
    C1, C2, p, q, 'square_loss', verbose=True, log=True)

#############################################################################
#
# Compare gwggrad and solve_gromov_linesearch with their modified versions
# ---------------------------------------------
#############################################################################
G = G0
deltaG = np.random.rand(*G.shape)
cost_G = 0

grad_mod = gwggrad_mod(constC, hC1, hC2, G, np_)
grad = gwggrad(constC, hC1, hC2, G, np_)

linesearch_mod = solve_gromov_linesearch_mod(G, deltaG, cost_G, constC, C1, C2, M=0., reg=1., nx=np_)
linesearch = solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_)

print()
print(f"dGW with func: {log0['gw_dist']}")
print(f"dGW with mods: {log_mod['gw_dist']}")
print("GW-distances agree:", log0['gw_dist'] == log_mod['gw_dist'])

print()
print('Gradients agree:', np.array_equal(grad_mod, grad))
print('Line-search results agree:', linesearch_mod == linesearch)

Expected behavior

The functions gwggrad and solve_gromov_linesearch should return the same results as gwggrad_mod and solve_gromov_linesearch_mod, respectively.
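
A minimal sketch of the checks I would expect to pass once the typos are fixed, reusing the variables from the code sample above and np.allclose/np.isclose to allow for floating-point round-off (these assertions are my own, not part of POT's test suite):

assert np.allclose(grad, grad_mod)
assert np.isclose(linesearch[0], linesearch_mod[0])  # step size alpha
assert np.isclose(linesearch[2], linesearch_mod[2])  # cost after the step
assert np.isclose(log0['gw_dist'], log_mod['gw_dist'])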

Environment

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Linux-6.5.7-100.fc37.x86_64-x86_64-with-glibc2.36
Python 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]
NumPy 1.24.3
SciPy 1.11.1
POT 0.9.1

Additional context