`converged` flag compatibility with `min_iterations` logic
zoepiran opened this issue · 3 comments
zoepiran commented
Describe the bug
- Given `min_iterations > inner_iterations`, the final `costs` array will contain `inf`s (710, 717-720, 731).
- The `inf`s imply `converged=False`, following the `all_is_finite` condition (201).
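To see why the leading `inf`s alone are enough to flip the flag, here is a minimal sketch that evaluates the `converged` condition quoted later in this thread on a mock `costs` array. The cost values are made up for illustration. The layout is an assumption: 200 slots (`max_iterations // inner_iterations`), the first 10 set to `inf`, a few finite costs, and the rest left at the `-1` placeholder after early stopping.

```python
import jax.numpy as jnp

# Hypothetical costs trace with max_iterations // inner_iterations = 200 slots.
# Assumed layout: the first min_iterations // inner_iterations = 10 slots hold
# inf (cost not computed yet), the next 15 hold made-up finite costs, and the
# remaining slots keep the -1 placeholder because the solver stopped early.
n_slots = 2_000 // 10
costs = jnp.full((n_slots,), -1.0)
costs = costs.at[:10].set(jnp.inf)
costs = costs.at[10:25].set(jnp.linspace(1.0, 0.5, 15))

# Same condition as the `converged` property quoted in this thread.
converged = jnp.logical_and(
    jnp.any(costs == -1), jnp.all(jnp.isfinite(costs))
)
print(bool(converged))  # False: the leading infs fail the isfinite check
```

Even though the (mock) run stopped early with finite costs, the check reports `False` because it does not distinguish "not computed yet" from "diverged".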
To Reproduce
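import jax
import numpy as np
from ott import utils
from ott.geometry.pointcloud import PointCloud
from ott.problems.quadratic.quadratic_problem import QuadraticProblem
from ott.solvers.quadratic.gromov_wasserstein_lr import LRGromovWasserstein
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m, d1, d2 = 5_000, 2_500, 1, 2
x = jax.random.uniform(rngs[0], (n, d1))
y = jax.random.uniform(rngs[1], (m, d2))
xx = jax.random.uniform(rngs[2], (n, d2))
yy = jax.random.uniform(rngs[3], (m, d2))
geom_x = PointCloud(x)
geom_y = PointCloud(y)
geom_xy = PointCloud(xx, yy)
prob = QuadraticProblem(geom_x, geom_y, geom_xy)
# min_iterations > inner_iterations triggers the leading infs in `costs`.
solver = jax.jit(LRGromovWasserstein(
rank=2,
min_iterations=100,
inner_iterations=10,
max_iterations=2_000,
progress_fn=utils.default_progress_fn(),
))
ot_gwlr = solver(prob)
print(f"converged? {ot_gwlr.converged}")
# Only print the cost slots that were written (-1 marks unused slots).
print(f"costs:{ot_gwlr.costs[np.where(ot_gwlr.costs != -1)[0]]}")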
Expected behavior
The convergence evaluation should ignore cost values that were not computed.
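One hypothetical way to get this behavior is to mask out, by position, the slots that are intentionally left as `inf` before `min_iterations`, and only require finiteness for slots that were actually evaluated. The helper name `converged_ignoring_skipped` is made up, and the sketch assumes that slot `k` of `costs` covers one block of `inner_iterations` iterations and that `min_iterations` is a multiple of `inner_iterations` (as in the reproduction); it is not the library's fix.

```python
import jax.numpy as jnp


def converged_ignoring_skipped(costs, min_iterations, inner_iterations):
    """Hypothetical `converged` check that skips never-evaluated cost slots.

    Assumes slot k of `costs` covers iterations
    [k * inner_iterations, (k + 1) * inner_iterations), so the first
    min_iterations // inner_iterations slots are the ones filled with inf
    on purpose.
    """
    n_skipped = min_iterations // inner_iterations
    evaluated = jnp.arange(costs.shape[0]) >= n_skipped
    # Early stopping leaves at least one slot at the -1 placeholder.
    stopped_early = jnp.any(costs == -1)
    # Only slots that were actually evaluated must be finite.
    all_finite = jnp.all(jnp.where(evaluated, jnp.isfinite(costs), True))
    return jnp.logical_and(stopped_early, all_finite)


# 200 slots: 10 skipped (inf), 15 computed, the rest untouched (-1).
costs = jnp.full((200,), -1.0).at[:10].set(jnp.inf).at[10:25].set(0.5)
print(bool(converged_ignoring_skipped(costs, 100, 10)))  # True

# A genuine blow-up after min_iterations is still reported as not converged.
diverged = costs.at[12].set(jnp.inf)
print(bool(converged_ignoring_skipped(diverged, 100, 10)))  # False
```

Masking by position, rather than simply dropping every `inf`, keeps real divergence detectable: an `inf` that appears after `min_iterations` still makes the check fail.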
michalk8 commented
Thanks for spotting this; it should indeed be fixed.
giovp commented
I tried the code above:
from ott.geometry.pointcloud import PointCloud
from ott.problems.quadratic.quadratic_problem import QuadraticProblem
from ott.solvers.quadratic.gromov_wasserstein_lr import LRGromovWasserstein
import jax
from ott import utils
import numpy as np
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m, d1, d2 = 5_000, 2_500, 1, 2
x = jax.random.uniform(rngs[0], (n, d1))
y = jax.random.uniform(rngs[1], (m, d2))
xx = jax.random.uniform(rngs[2], (n, d2))
yy = jax.random.uniform(rngs[3], (m, d2))
geom_x = PointCloud(x)
geom_y = PointCloud(y)
geom_xy = PointCloud(xx, yy)
prob = QuadraticProblem(geom_x, geom_y, geom_xy)
solver = jax.jit(LRGromovWasserstein(
rank=2,
min_iterations=100,
inner_iterations=10,
max_iterations=2_000,
progress_fn=utils.default_progress_fn()
))
ot_gwlr = solver(prob)
print(f"converged? {ot_gwlr.converged}")
print(f"costs:{ot_gwlr.costs[np.where(ot_gwlr.costs != -1)[0]]}")
I noticed that sometimes I get all `inf` costs and still `converged=True`:
converged? False
costs:[inf inf inf inf inf inf inf inf inf inf]
What would be the right way to set the convergence flag correctly?
EDIT: `converged` is actually `False`; I hadn't refreshed my interactive environment.
zoepiran commented
I can't figure out what actually happens here, as the `converged` flag is still set using the same condition (here):
def converged(self) -> bool: # noqa: D102
return jnp.logical_and(
jnp.any(self.costs == -1), jnp.all(jnp.isfinite(self.costs))
)
Other than that, the appearance of exactly 10 `inf`s (`min_iterations / inner_iterations = 100 / 10`) makes sense given the construction in `one_iteration()`: as long as `iteration < self.min_iterations`, an `inf` is set.
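The count of 10 can be checked with a small back-of-the-envelope sketch. It does not call the library; it assumes one cost slot per block of `inner_iterations` iterations and counts the slots whose evaluating iteration is still below `min_iterations`. Since it is not stated here whether the cost is evaluated at the first or the last iteration of each block, both variants are computed.

```python
# Settings from the reproduction above.
min_iterations, inner_iterations, max_iterations = 100, 10, 2_000
n_slots = max_iterations // inner_iterations  # 200 cost slots

# Assumption A: slot k is evaluated at the last iteration of its block.
last = [k * inner_iterations + inner_iterations - 1 for k in range(n_slots)]
n_inf = sum(it < min_iterations for it in last)

# Assumption B: slot k is evaluated at the first iteration of its block.
first = [k * inner_iterations for k in range(n_slots)]
n_inf_first = sum(it < min_iterations for it in first)

print(n_inf, n_inf_first)  # 10 10
```

With these settings both assumptions give `min_iterations // inner_iterations = 10` slots set to `inf`, matching the output above.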