ott-jax/ott

FGW get linear and quadratic term

gjhuizing opened this issue · 7 comments

Hello,

Given a GWOutput, am I right that out.reg_gw_cost can be split as linear + fused_penalty*quad? Is there a good way to get these linear and quadratic contributions?

I tried something but it involves instantiating out.matrix which is probably not very efficient...

Thanks a lot!

Best

GJ

Yes, at this point, those are not tracked.

The small difficulty is that out.reg_gw_cost also includes entropic regularization.

I am afraid the easiest approach could be to use the same pattern as here but feed the geom_xy to transport_at_geom instead. I would hope that if your geom_xy is an LRC (low-rank cost), everything should go smoothly.
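The pattern suggested above can be sketched in plain JAX, independently of OTT's own classes. It assumes the linear cost factorizes as C_xy = A Bᵀ (a low-rank cost) and that the coupling is only reachable through a function computing P @ v. The names `apply_transport`, `linear_term_low_rank`, `A`, `B`, and the dense `P` used for the check are hypothetical stand-ins, not OTT API.

```python
import jax.numpy as jnp


def linear_term_low_rank(apply_transport, A, B):
    """<C_xy, P> for C_xy = A @ B.T, using only products P @ v."""
    # <A B^T, P> = sum_ij sum_r A_ir B_jr P_ij = sum(A * (P @ B))
    return jnp.sum(A * apply_transport(B))


# Toy data (hypothetical): 4 source points, 3 target points, rank-2 cost.
A = jnp.array([[1.0, 0.0], [0.5, 1.0], [2.0, 1.5], [0.0, 3.0]])
B = jnp.array([[1.0, 2.0], [0.0, 1.0], [3.0, 0.5]])
P = jnp.full((4, 3), 1.0 / 12.0)  # dense coupling, used only for the check

linear = linear_term_low_rank(lambda v: P @ v, A, B)
dense_linear = jnp.sum((A @ B.T) * P)  # reference: materializes C_xy
```

In a real run, the lambda would be replaced by whatever matrix-free transport application the solver output offers, so neither the coupling nor C_xy needs to be instantiated.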

Thanks a lot, I'll try that!

By the way, is out.reg_gw_cost equal to

  • quadratic(pi^\star) + fused_penalty*linear(pi^\star) - eps*entropy(pi^\star)
  • or to quadratic(pi^\star) + fused_penalty*linear(pi^\star) - eps*(1 + fused_penalty)*entropy(pi^\star) ?

It should be the first. In this iteration,
costs = self.costs.at[iteration].set(linear_sol.reg_ot_cost)
the cost is set using the linear solution, whose ground cost is itself the sum of two matrices: the linearized GW cost (~C_XX P C_YY) plus fused_penalty times the linear term. Epsilon times the regularization is then added on top.
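To make the first decomposition concrete, here is a small sketch in plain JAX for the squared-loss GW objective on dense matrices. It follows the formula in the first bullet above. The function names, the toy data, and the exact form of the entropy term are assumptions for illustration; OTT's internal regularizer may follow a different convention (for example a KL term).

```python
import jax.numpy as jnp


def quadratic_term(Cx, Cy, P):
    """sum_ijkl (Cx_ik - Cy_jl)^2 P_ij P_kl without forming a 4-D tensor."""
    a, b = P.sum(axis=1), P.sum(axis=0)  # marginals of the coupling
    cross = jnp.sum(P * (Cx @ P @ Cy.T))
    return a @ (Cx**2) @ a + b @ (Cy**2) @ b - 2.0 * cross


def entropy(P):
    """Shannon entropy -sum P log P, with the convention 0 log 0 = 0."""
    safe = jnp.where(P > 0, P, 1.0)
    return -jnp.sum(jnp.where(P > 0, P * jnp.log(safe), 0.0))


def regularized_cost(Cx, Cy, Cxy, P, fused_penalty, eps):
    """Return (total, linear, quad) following the first bullet's formula."""
    linear = jnp.sum(Cxy * P)
    quad = quadratic_term(Cx, Cy, P)
    total = quad + fused_penalty * linear - eps * entropy(P)
    return total, linear, quad


# Toy data (hypothetical): 3 source points, 2 target points.
Cx = jnp.array([[0.0, 1.0, 9.0], [1.0, 0.0, 4.0], [9.0, 4.0, 0.0]])
Cy = jnp.array([[0.0, 4.0], [4.0, 0.0]])
Cxy = jnp.array([[1.0, 2.0], [0.0, 1.0], [3.0, 0.5]])
P = jnp.array([[0.2, 0.1], [0.1, 0.2], [0.2, 0.2]])

total, linear, quad = regularized_cost(Cx, Cy, Cxy, P, fused_penalty=2.0, eps=0.1)
```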

Okay, that makes sense, thank you :)

Hi, me again! This made me realize that for the GW solver, epsilon is not defined in the geometries (as in the Sinkhorn solver) but in the args of WassersteinSolver, right?

IMHO this is not super clear from the docs and it could be worth setting epsilon=1.0 directly in the args of GW (as done in the LRGW solver actually)

This question has been around for a while, namely whether epsilon should be a part of Geometry or part of solver.

My thinking is that, for a linear solver, the scale of epsilon does not make sense "on its own": a solver cannot really know how to set epsilon, because it should be relative to the geometry's points and cost function. This is why querying it from Geometry would make more sense.

On the other hand, for a quadratic solver, there is no such thing as a "stable" Geometry: the linearization means that the cost matrix is refreshed (and quite unstable) at every iteration. This is why one needs to move epsilon up, to the level of the solver.
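A toy illustration of this point, in plain JAX: for the squared loss, the linearized cost matrix depends on the current coupling, so any epsilon defined relative to its scale would change between iterations. The helper `linearized_cost` and the rule `0.05 * mean(cost)` are assumptions made for this sketch, not a description of what OTT does.

```python
import jax.numpy as jnp


def linearized_cost(Cx, Cy, P):
    """Squared-loss GW cost linearized at coupling P (dense sketch)."""
    a, b = P.sum(axis=1), P.sum(axis=0)
    return ((Cx**2) @ a)[:, None] + ((Cy**2) @ b)[None, :] - 2.0 * Cx @ P @ Cy.T


Cx = jnp.array([[0.0, 1.0, 9.0], [1.0, 0.0, 4.0], [9.0, 4.0, 0.0]])
Cy = jnp.array([[0.0, 4.0, 9.0], [4.0, 0.0, 1.0], [9.0, 1.0, 0.0]])

P_independent = jnp.full((3, 3), 1.0 / 9.0)  # product of uniform marginals
P_identity = jnp.eye(3) / 3.0                # same marginals, different plan

# Hypothetical relative rule: epsilon proportional to the mean cost.
eps_independent = 0.05 * jnp.mean(linearized_cost(Cx, Cy, P_independent))
eps_identity = 0.05 * jnp.mean(linearized_cost(Cx, Cy, P_identity))
```

Both couplings share the same marginals, yet the two epsilon values differ, which is the instability described above.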

As for the discrepancy between LR and GWLR API w.r.t. epsilon, I think this is a minor bug. Both inherit from WassersteinSolver, and epsilon is there. Therefore, we should either expose it in both (in list of arguments) or hide it from both. Maybe exposing it in both would be preferable. @michalk8 do you have an opinion?

closing this for now! unless @michalk8 has an additional comment