Typos in the formulas of gwggrad and solve_gromov_linesearch
Closed this issue · 0 comments
Describe the bug
The formulas in gwggrad and solve_gromov_linesearch have typos and do not match the cited references [12] and [24]. I also calculated the gradient by hand to confirm that POT has typos.
For concreteness, I'm using the following versions of the cited papers:
[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
IN: https://proceedings.mlr.press/v48/peyre16.pdf
[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
IN: https://arxiv.org/pdf/1805.09114.pdf
To Reproduce
- Run the code below.
Code sample
Notes:
- The code below is a modification of `plot_gromov.py` from the examples gallery. I computed the GW distance twice: once using POT, and once with my corrections implemented in `gwggrad_mod` and `solve_gromov_linesearch_mod`.
- The typos did not affect the value of the Gromov-Wasserstein distance in my example, but I wonder whether making sub-optimal choices in the line search will affect the speed of convergence in more complicated calculations.
import scipy as sp
import numpy as np
import ot
# Import functions required in ot.gromov._gw
from ot.utils import list_to_array
from ot.optim import cg, solve_1d_linesearch_quad
from ot.backend import get_backend, NumpyBackend
from ot.gromov._utils import init_matrix, gwloss, gwggrad
from ot.gromov._gw import solve_gromov_linesearch
#############################################################################
#
# Sample two Gaussian distributions (2D and 3D)
# ---------------------------------------------
#############################################################################
n_samples = 30 # number of points drawn for each of the two samples
# Mean and covariance of the 2D source sample.
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
# Mean and covariance of the 3D target sample.
mu_t = np.array([4, 4, 4])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
# Fix the global NumPy seed so the run is reproducible; every later
# np.random call in this script depends on the draw order from here on.
np.random.seed(0)
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
# Colour standard-normal draws with the matrix square root of cov_t, then shift by mu_t.
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
# Pairwise distance matrices within each sample (cdist with its default metric).
C1 = sp.spatial.distance.cdist(xs, xs)
C2 = sp.spatial.distance.cdist(xt, xt)
# Rescale each matrix in place so its largest entry is 1.
C1 /= C1.max()
C2 /= C2.max()
#############################################################################
#
# Parameters for dGW
# ---------------------------------------------
#############################################################################
# Sample weights for both point clouds, and the product coupling used as
# the starting point of the solver.
p = ot.unif(n_samples)
q = ot.unif(n_samples)
G0 = np.outer(p, q)

# Solver settings.
loss_fun = 'square_loss'
symmetric = None
log = True
armijo = False
max_iter = 1e4
tol_rel = 1e-9
tol_abs = 1e-9
#############################################################################
#
# gwggrad and solve_gromov_linesearch with typos corrected
# ---------------------------------------------
#############################################################################
def gwggrad_mod(constC, hC1, hC2, T, nx=None):
    """Evaluate the reporter's proposed Gromov-Wasserstein gradient at ``T``.

    Computes ``constC - 2 * hC1 @ T @ hC2.T``.

    Parameters
    ----------
    constC, hC1, hC2 : array-like
        Matrices entering the expression above; in this script they come
        from ``init_matrix``.
    T : array-like
        Coupling at which the expression is evaluated.
    nx : backend, optional
        Object exposing ``dot``. When omitted, the inputs are passed through
        ``list_to_array`` and the backend is inferred with ``get_backend``.

    Returns
    -------
    array-like
        The value of ``constC - 2 * hC1 @ T @ hC2.T``.
    """
    if nx is None:
        constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T)
        nx = get_backend(constC, hC1, hC2, T)
    return constC - 2 * nx.dot(nx.dot(hC1, T), hC2.T)
def solve_gromov_linesearch_mod(G, deltaG, cost_G, constC, C1, C2, M, reg,
                                alpha_min=None, alpha_max=None, nx=None, **kwargs):
    """Line search along ``deltaG`` using the reporter's corrected coefficients.

    Builds the coefficients ``a`` and ``b`` of the quadratic
    ``a * alpha**2 + b * alpha``, hands them to ``solve_1d_linesearch_quad``
    to obtain the step ``alpha``, and updates the cost from the same
    quadratic.

    Parameters
    ----------
    G : array-like
        Current coupling.
    deltaG : array-like
        Search direction.
    cost_G : float
        Cost at ``G``.
    constC : array-like
        Constant matrix entering the linear coefficient ``b``.
    C1, C2 : array-like
        Structure matrices of the two spaces.
    M : array-like or scalar
        Linear cost term; a plain ``int``/``float`` is accepted.
    reg : float
        Weight applied to the quadratic (Gromov) part.
    alpha_min, alpha_max : float, optional
        Bounds applied to ``alpha`` when at least one of them is given.
    nx : backend, optional
        Object exposing ``dot`` and ``sum``; inferred from the inputs when
        omitted.
    **kwargs
        Accepted for call compatibility and ignored.

    Returns
    -------
    alpha : float
        Selected step size.
    fc : int
        Always ``1``.
    cost_G : float
        ``cost_G + a * alpha**2 + b * alpha``.
    """
    if nx is None:
        G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)
        # A plain Python scalar M is left out of the backend inference.
        if isinstance(M, (int, float)):
            nx = get_backend(G, deltaG, C1, C2)
        else:
            nx = get_backend(G, deltaG, C1, C2, M)

    dot_dG = nx.dot(nx.dot(C1, deltaG), C2.T)
    dot_G = nx.dot(nx.dot(C1, G), C2.T)

    # Coefficients of the quadratic in alpha.
    a = -2 * reg * nx.sum(dot_dG * deltaG)
    b = nx.sum(M * deltaG) + reg * (
        nx.sum(constC * deltaG)
        - 2 * nx.sum(dot_dG * G)
        - 2 * nx.sum(dot_G * deltaG)
    )

    alpha = solve_1d_linesearch_quad(a, b)
    if alpha_min is not None or alpha_max is not None:
        alpha = np.clip(alpha, alpha_min, alpha_max)

    # the new cost is deduced from the line search quadratic function
    cost_G = cost_G + a * (alpha ** 2) + b * alpha

    return alpha, 1, cost_G
#############################################################################
#
# Compute Gromov-Wasserstein with modified functions
# ---------------------------------------------
#############################################################################
# cg for GW is implemented using numpy on CPU
np_ = NumpyBackend()
# Backend of the original inputs, used later to convert results back.
nx = get_backend(C1, C2, p, q)
# Keep references to the inputs as they were passed in.
p0 = p
q0 = q
C10 = C1
C20 = C2
# Matrices consumed by the loss and gradient functions.
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_)
######################################################################
# Define loss function, gradient and linesearch
# ---------------------------------------------
# NOTE: Using modified gwgrad and line_search
def f(G):
    """Return the Gromov-Wasserstein loss at coupling ``G`` via ``gwloss``."""
    return gwloss(constC, hC1, hC2, G, np_)
def df(G):
    """Return the modified gradient at coupling ``G`` via ``gwggrad_mod``."""
    return gwggrad_mod(constC, hC1, hC2, G, np_)
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
    """Line-search callback handed to ``cg``, using the modified line search.

    ``cost`` and ``Mi`` are accepted to match the callback signature but are
    not used: the linear term is fixed to ``M=0.`` and the weight to
    ``reg=1.``. Extra keyword arguments are forwarded unchanged.
    """
    return solve_gromov_linesearch_mod(
        G, deltaG, cost_G, constC, C1, C2, M=0., reg=1., nx=np_, **kwargs
    )
######################################################################
# Solve with the modified gradient and line search.
res_mod, log_mod = cg(
    p, q, 0., 1., f, df, G0, line_search,
    log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs,
)
# Store the final loss as the GW distance and convert the log entries
# and the coupling back to the type of C1.
log_mod.update({
    'gw_dist': nx.from_numpy(log_mod['loss'][-1], type_as=C1),
    'u': nx.from_numpy(log_mod['u'], type_as=C1),
    'v': nx.from_numpy(log_mod['v'], type_as=C1),
})
gw_mod = nx.from_numpy(res_mod, type_as=C1)
# Compute GW with the original function
gw0, log0 = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True, log=True)
#############################################################################
#
# Compare gwggrad and solve_gromov_linesearch with their modified versions
# ---------------------------------------------
#############################################################################
# Evaluate both implementations at the initial coupling, along a random
# direction, starting from a cost of 0.
G = G0
deltaG = np.random.rand(*G.shape)
cost_G = 0
# Gradient: modified version vs. the one shipped with POT.
grad_mod = gwggrad_mod(constC, hC1, hC2, G, np_)
grad = gwggrad(constC, hC1, hC2, G, np_)
# Line search: modified version vs. the one shipped with POT
# (the POT function takes no constC argument).
linesearch_mod = solve_gromov_linesearch_mod(G, deltaG, cost_G, constC, C1, C2, M=0., reg=1., nx=np_)
linesearch = solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_)
print()
print(f"dGW with func: {log0['gw_dist']}")
print(f"dGW with mods: {log_mod['gw_dist']}")
# All three comparisons below are exact (no numerical tolerance).
print("GW-distances agree:", log0['gw_dist'] == log_mod['gw_dist'])
print()
print('Gradients agree:', np.array_equal(grad_mod, grad))
print('Line-search results agree:', linesearch_mod == linesearch)
Expected behavior
The functions gwggrad and solve_gromov_linesearch should output the result of gwggrad_mod and solve_gromov_linesearch_mod, respectively.
Environment
Output of the following code snippet:
# Print the platform string and the versions of Python, NumPy, SciPy and POT.
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Linux-6.5.7-100.fc37.x86_64-x86_64-with-glibc2.36
Python 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]
NumPy 1.24.3
SciPy 1.11.1
POT 0.9.1