ott-jax/ott

`converged` flag compatibility with `min_iterations` logic

zoepiran opened this issue · 3 comments

Describe the bug

  1. Given min_iterations > inner_iterations, the final costs array will contain infs (lines 710, 717-720, 731).
  2. These infs force converged=False, because of the all_is_finite condition (line 201).

To Reproduce

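import jax
import numpy as np
from ott import utils
from ott.geometry.pointcloud import PointCloud
from ott.problems.quadratic.quadratic_problem import QuadraticProblem
from ott.solvers.quadratic.gromov_wasserstein_lr import LRGromovWasserstein

rngs = jax.random.split(jax.random.PRNGKey(0), 4)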
n, m, d1, d2 = 5_000, 2_500, 1, 2
x = jax.random.uniform(rngs[0], (n, d1))
y = jax.random.uniform(rngs[1], (m, d2))
xx = jax.random.uniform(rngs[2], (n, d2))
yy = jax.random.uniform(rngs[3], (m, d2))

geom_x = PointCloud(x)
geom_y = PointCloud(y)
geom_xy = PointCloud(xx, yy)
prob = QuadraticProblem(geom_x, geom_y, geom_xy)
solver = jax.jit(LRGromovWasserstein(
    rank=2,
    min_iterations=100,
    inner_iterations=10,
    max_iterations=2_000,
    progress_fn=utils.default_progress_fn(),
))

ot_gwlr = solver(prob)

print(f"converged? {ot_gwlr.converged}")
print(f"costs:{ot_gwlr.costs[np.where(ot_gwlr.costs != -1)[0]]}")

Expected behavior
The convergence check should ignore cost values that were never computed.
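
One way this could look, as a minimal plain-Python sketch rather than a proposed patch: skip the slots that fall below min_iterations when checking finiteness. The function name and signature are hypothetical, and the sketch assumes one cost slot per inner_iterations iterations, with -1 marking slots that were never written.

```python
import math


def converged_skipping_warmup(costs, min_iterations, inner_iterations):
    """Hypothetical check that ignores the warm-up slots (not the ott API)."""
    # Assumption: slot k holds the cost at iteration k * inner_iterations,
    # so slots with k * inner_iterations < min_iterations are placeholders.
    first_real = math.ceil(min_iterations / inner_iterations)
    # Keep only slots that were actually evaluated (past warm-up, not -1).
    evaluated = [c for c in costs[first_real:] if c != -1]
    # A remaining -1 means the solver stopped before filling the buffer.
    stopped_early = any(c == -1 for c in costs)
    return stopped_early and all(math.isfinite(c) for c in evaluated)


# 10 warm-up placeholders, two evaluated costs, the rest unused.
costs = [math.inf] * 10 + [0.5, 0.4] + [-1.0] * 188
print(converged_skipping_warmup(costs, min_iterations=100, inner_iterations=10))

# A genuine inf after the warm-up slots still counts as a failure.
diverged = [math.inf] * 10 + [math.inf] + [-1.0] * 189
print(converged_skipping_warmup(diverged, min_iterations=100, inner_iterations=10))
```

Note that if the solver stops before any cost is evaluated, the finiteness check passes vacuously; a real fix might additionally require at least one evaluated cost.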

Thanks for spotting this; indeed, it should be fixed.

I tried with the code above

from ott.geometry.pointcloud import PointCloud
from ott.problems.quadratic.quadratic_problem import QuadraticProblem
from ott.solvers.quadratic.gromov_wasserstein_lr import LRGromovWasserstein
import jax
from ott import utils
import numpy as np

rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m, d1, d2 = 5_000, 2_500, 1, 2
x = jax.random.uniform(rngs[0], (n, d1))
y = jax.random.uniform(rngs[1], (m, d2))
xx = jax.random.uniform(rngs[2], (n, d2))
yy = jax.random.uniform(rngs[3], (m, d2))

geom_x = PointCloud(x)
geom_y = PointCloud(y)
geom_xy = PointCloud(xx, yy)
prob = QuadraticProblem(geom_x, geom_y, geom_xy)
solver = jax.jit(LRGromovWasserstein(
    rank=2,
    min_iterations=100,
    inner_iterations=10,
    max_iterations=2_000,
    progress_fn=utils.default_progress_fn(),
))

ot_gwlr = solver(prob)

print(f"converged? {ot_gwlr.converged}")
print(f"costs:{ot_gwlr.costs[np.where(ot_gwlr.costs != -1)[0]]}")

I noticed that sometimes I get all infs and still converged=True:

converged? False
costs:[inf inf inf inf inf inf inf inf inf inf]

What would be the right way to set the convergence flag correctly?

EDIT: converged is in fact False; I had not refreshed my interactive environment.

I can't figure out what actually happens here, as the converged flag is still set using the same condition:

def converged(self) -> bool:  # noqa: D102
    return jnp.logical_and(
        jnp.any(self.costs == -1), jnp.all(jnp.isfinite(self.costs))
    )

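Other than that, the appearance of exactly 10 infs makes sense given the construction in one_iteration(): as long as iteration < self.min_iterations, an inf is set. With min_iterations=100 and inner_iterations=10, that covers the first 100 / 10 = 10 cost entries.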
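
To make the count concrete, here is a toy model of the costs buffer in plain Python. It is not the library's code: the buffer length, the `stop_after` parameter and the dummy finite cost are assumptions made for illustration. Only the -1 sentinel, the inf placeholder below min_iterations, and the converged condition are taken from this thread.

```python
import math


def simulate_costs(min_iterations, inner_iterations, max_iterations, stop_after):
    """Toy model of the costs buffer (hypothetical, not the ott code)."""
    # Assumption: one slot per block of `inner_iterations`, pre-filled with -1.
    n_slots = math.ceil(max_iterations / inner_iterations)
    costs = [-1.0] * n_slots
    for slot in range(n_slots):
        iteration = slot * inner_iterations
        if iteration >= stop_after:  # hypothetical stopping point
            break
        if iteration < min_iterations:
            costs[slot] = math.inf  # cost not evaluated yet, placeholder stored
        else:
            costs[slot] = 1.0 / (slot + 1)  # dummy finite cost
    return costs


def converged(costs):
    """Plain-Python mirror of the condition quoted above."""
    return any(c == -1 for c in costs) and all(math.isfinite(c) for c in costs)


costs = simulate_costs(
    min_iterations=100, inner_iterations=10, max_iterations=2_000,
    stop_after=100,
)
written = [c for c in costs if c != -1]
print(len(written), all(math.isinf(c) for c in written))
print("converged?", converged(costs))
```

Under these assumptions the written part of the buffer is exactly 10 infs, and the finiteness half of the condition fails even though the buffer still contains -1 entries.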