Removing line search as a step in GW descent
TL;DR: I would like the GW function to provide an option to skip the line search in favor of the naive update step, because in my experiments the line search is unambiguously worse.
In the gradient descent algorithm currently implemented for GW, each iteration proceeds roughly as follows. At time …
We write …
This cleverness with the line search goes beyond the theory developed in Peyré et al., "Gromov-Wasserstein Averaging of Kernel and Distance Matrices", ICML 2016.
I am concerned that this is a somewhat premature optimization that is not worth it. To do this line search, one must compute the coefficients of a quadratic polynomial in the step size, which by my count involves four matrix multiplications. Compare this to directly computing …
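A minimal numpy sketch of that computation, assuming the square loss and symmetric C1, C2 (the function names are hypothetical, not POT's API). For the square loss, f(G) = Σ (C1[i,k] − C2[j,l])² G[i,j] G[k,l] = p^T (C1∘C1) p + q^T (C2∘C2) q − 2⟨C1 G C2^T, G⟩, where p and q are the row and column sums of G. Along the segment G + αD with D = G̃ − G the marginals do not change, so f is a quadratic aα² + bα + c and its minimizer on [0, 1] has a closed form. In this sketch, symmetry lets both coefficients be read off the single product C1 D C2^T.

```python
import numpy as np


def gw_square_loss(C1, C2, G):
    """sum_{ijkl} (C1[i,k] - C2[j,l])**2 * G[i,j] * G[k,l], via marginals."""
    p, q = G.sum(axis=1), G.sum(axis=0)
    const = p @ (C1 ** 2) @ p + q @ (C2 ** 2) @ q
    return const - 2.0 * np.sum((C1 @ G @ C2.T) * G)


def exact_line_search(C1, C2, G, G_tilde):
    """Minimize f((1 - alpha) * G + alpha * G_tilde) over alpha in [0, 1].

    Assumes symmetric C1, C2 and that G, G_tilde share the same marginals,
    so the constant part of f is unchanged along the segment.
    """
    D = G_tilde - G
    C1DC2 = C1 @ D @ C2.T            # the only new matrix products needed here
    a = -2.0 * np.sum(C1DC2 * D)     # coefficient of alpha**2
    b = -4.0 * np.sum(C1DC2 * G)     # coefficient of alpha (uses symmetry)
    if a > 0:
        # Convex parabola: clip its vertex to [0, 1].
        return float(np.clip(-b / (2.0 * a), 0.0, 1.0))
    # Concave or linear: the minimum is at an endpoint; f(1) - f(0) = a + b.
    return 1.0 if a + b < 0 else 0.0
```

When the parabola is concave (a ≤ 0), the minimum over [0, 1] sits at an endpoint, which is why the sketch only compares a + b with 0.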
I have experimented with the behavior of this algorithm on a real-world dataset, computing a number of Gromov-Wasserstein distances between point clouds equipped with the uniform distribution. In this experiment, the line search never found a solution …
I suggest the devs do their own experiments and see if there is a dataset in which the line search performs better; I would be interested in seeing such a dataset, perhaps one not using the uniform distribution on points. I request that there be an option for the user to disable the line search in favor of the more naive update step.
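A minimal sketch of the requested option, under simplifying assumptions: square loss, symmetric C1 and C2, and two clouds of the same size n with uniform weights, so the linear subproblem at each step is solved exactly by a scaled permutation matrix via scipy.optimize.linear_sum_assignment. The function and its `step` argument are hypothetical, not POT parameters. The "naive update" is assumed here to mean the full step G_{t+1} = G̃_t (step="full"); step="exact" uses the closed-form quadratic line search. Returning the step sizes makes it easy to count, as in the experiment above, how often the line search picks α strictly between 0 and 1.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def gw_conditional_gradient(C1, C2, step="exact", n_iter=100):
    """Conditional gradient for square-loss GW between two n-point clouds.

    Assumes symmetric C1, C2 of the same size n and uniform weights 1/n.
    step="exact": closed-form quadratic line search on [0, 1].
    step="full":  naive update, always alpha = 1 (hypothetical option).
    Returns the final coupling and the list of step sizes used.
    """
    n = C1.shape[0]
    p = np.full(n, 1.0 / n)
    G = np.outer(p, p)  # start from the independent (product) coupling
    alphas = []
    for _ in range(n_iter):
        # Gradient of f(G) = sum (C1[i,k] - C2[j,l])**2 G[i,j] G[k,l]
        grad = 2.0 * ((C1 ** 2) @ p[:, None] + ((C2 ** 2) @ p)[None, :]
                      - 2.0 * C1 @ G @ C2.T)
        # Linear subproblem: with uniform weights and equal sizes, an optimal
        # vertex of the transport polytope is a permutation matrix scaled by 1/n.
        rows, cols = linear_sum_assignment(grad)
        G_tilde = np.zeros_like(G)
        G_tilde[rows, cols] = 1.0 / n
        D = G_tilde - G
        if np.allclose(D, 0.0):
            break  # already at the minimizing vertex
        if step == "full":
            alpha = 1.0
        else:
            C1DC2 = C1 @ D @ C2.T
            a = -2.0 * np.sum(C1DC2 * D)
            b = -4.0 * np.sum(C1DC2 * G)
            if a > 0:
                alpha = float(np.clip(-b / (2.0 * a), 0.0, 1.0))
            else:
                alpha = 1.0 if a + b < 0 else 0.0
        alphas.append(alpha)
        if alpha == 0.0:
            break  # no decrease along D
        G = G + alpha * D
    return G, alphas
```

Because α = 0 is always a candidate, the exact variant never increases the objective; the full-step variant has no such guarantee and can cycle between vertices, which is one reason the loop is capped by n_iter.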
First, hello.
The exact line-search step can be skipped in favor of the Armijo rule by setting 'armijo=True'. Other rules could be added as features, but they would most likely be more costly. However, for this solver to be a proper conditional gradient method, the step size …
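For comparison, a sketch of a generic backtracking Armijo rule along the conditional-gradient direction D = G̃ − G. This illustrates the rule itself, not POT's implementation of 'armijo=True'; the function name and the constants c1 = 1e-4 and shrink = 0.5 are assumptions. Starting from α = 1 and only shrinking keeps every trial point on the segment between G and G̃, so the iterate stays feasible, and the rule needs only evaluations of f and the directional derivative ⟨∇f(G), D⟩.

```python
import numpy as np


def armijo_step(f, G, D, grad, c1=1e-4, shrink=0.5, alpha0=1.0, max_iter=30):
    """Backtracking Armijo rule along direction D (hypothetical helper).

    Returns the first alpha in {alpha0, alpha0*shrink, alpha0*shrink**2, ...}
    with f(G + alpha*D) <= f(G) + c1 * alpha * <grad, D>, or 0.0 if none is found.
    """
    f0 = f(G)
    slope = float(np.sum(grad * D))  # directional derivative at alpha = 0
    if slope >= 0:
        return 0.0                   # D is not a descent direction
    alpha = alpha0
    for _ in range(max_iter):
        if f(G + alpha * D) <= f0 + c1 * alpha * slope:
            return alpha             # sufficient decrease reached
        alpha *= shrink
    return 0.0
```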
Hello @patrick-nicodemus,
If you find a better line search (at least on some type of data), we encourage you to implement it and add it as an alternative in the GW solvers with a PR. @cedricvincentcuaz recently did a big revamp of the GW and CG solvers that makes this easier to do (basically, you just give your line-search implementation to the solver). As @cedricvincentcuaz said, you can use the Armijo line search or implement your own (for instance, we did not implement the traditional CG line search in 1/sqrt(k)).
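A sketch of what plugging in your own step rule can look like, with a hypothetical interface (not the signature POT's solvers expect): a generic conditional gradient (Frank-Wolfe) loop that takes the step rule as a callable, together with the diminishing schedule in 1/sqrt(k) mentioned above (indexed here as 1/sqrt(k+1) for k = 0, 1, …) and, for comparison, the textbook Frank-Wolfe schedule 2/(k+2). Both start at α = 1 and stay in [0, 1], so every iterate remains a convex combination of feasible points.

```python
import numpy as np


def frank_wolfe(grad_f, lmo, x0, step_rule, n_iter=100):
    """Generic conditional gradient loop with a pluggable step rule.

    grad_f(x)          -> gradient at x
    lmo(g)             -> feasible vertex minimizing <g, s> (linear minimization oracle)
    step_rule(k, x, d) -> step size in [0, 1] at iteration k = 0, 1, ...
    """
    x = np.asarray(x0, dtype=float)
    for k in range(n_iter):
        d = lmo(grad_f(x)) - x       # direction towards the oracle's vertex
        alpha = step_rule(k, x, d)
        x = x + alpha * d            # convex combination stays feasible
    return x


def inv_sqrt_step(k, x, d):
    """Diminishing schedule 1 / sqrt(k + 1); ignores x and d."""
    return 1.0 / np.sqrt(k + 1)


def classic_fw_step(k, x, d):
    """Textbook Frank-Wolfe schedule 2 / (k + 2); ignores x and d."""
    return 2.0 / (k + 2)
```

For instance, to minimize f(x) = ‖x − y‖² over the probability simplex, grad_f can be x ↦ 2(x − y) and lmo can return the vertex e_i with i the index of the smallest gradient entry. An exact or Armijo rule fits the same slot because it receives the current point and direction.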
Still, note that in our experience, claims that one line search is better than another are very data-dependent, because the problem remains highly non-convex; that is also why we use an exact line search by default.