PythonOT/POT

Improvements to pointwise and sampled GW variants


POT currently provides several algorithms to compute or estimate the Gromov-Wasserstein distance, so users have a lot of freedom to pick the one best suited to their distribution size, accuracy requirements, loss function, and so on.

However, the pointwise_gromov_wasserstein and sampled_gromov_wasserstein functions are substantially slower than gromov_wasserstein in comparable settings. Our lab works with distributions of size N=100. On a 20-core machine, gromov_wasserstein takes about 17-20 milliseconds, whereas pointwise_gromov_wasserstein with log=False and max_iter=5 takes between 40 and 80 milliseconds.

Granted, the original Sampled Gromov-Wasserstein paper points out that its advantage is largest for distributions with N >> 100 and for losses other than the square loss. However, I do not think this explains the whole performance gap. I suspect that a large share of it comes from the user-supplied loss function being called as an ordinary Python function inside a list comprehension at every iteration of the loop, instead of being evaluated in vectorized form.
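To make the suspected bottleneck concrete, here is a minimal NumPy sketch (not POT's actual code) contrasting two ways of computing, for a batch of sampled index pairs (i, j), the term sum over k, l of L(C1[i, k], C2[j, l]) * T[k, l]. The function names per_sample_loop and batched are hypothetical, and the sketch assumes the loss is an element-wise function that broadcasts over NumPy arrays.

```python
import numpy as np


def per_sample_loop(C1, C2, T, idx_i, idx_j, loss_fun):
    # One Python-level call of loss_fun per sampled pair (list-comprehension style);
    # each call builds an (n, m) loss matrix that is weighted by T and summed.
    return np.array([
        np.sum(loss_fun(C1[i, :][:, None], C2[j, :][None, :]) * T)
        for i, j in zip(idx_i, idx_j)
    ])


def batched(C1, C2, T, idx_i, idx_j, loss_fun):
    # A single call of loss_fun on an (s, n, m) array built by broadcasting,
    # followed by one weighted reduction over the k and l axes.
    L = loss_fun(C1[idx_i][:, :, None], C2[idx_j][:, None, :])
    return np.sum(L * T[None, :, :], axis=(1, 2))
```

Both functions return the same values. The batched form trades memory (an s x n x m temporary) for far fewer round-trips through the Python interpreter, which is the kind of overhead the issue suspects.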

I propose that pointwise_gromov_wasserstein, sampled_gromov_wasserstein and GW_distance_estimation let users select from a fixed list of loss functions, including the square loss and the absolute-value loss, which would be implemented internally in vectorized form using a performant backend.
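A fixed list of named losses also allows loss-specific shortcuts. The sketch below is a hypothetical helper, not part of POT: the name sampled_loss_terms and the "abs_loss" string are assumptions, while "square_loss" mirrors the string gromov_wasserstein already accepts. For the square loss it uses the expansion (a - b)^2 = a^2 + b^2 - 2ab, so the sum over k, l reduces to matrix products against the marginals p = T.sum(axis=1) and q = T.sum(axis=0); the absolute loss has no such split and falls back to broadcasting.

```python
import numpy as np


def sampled_loss_terms(C1, C2, T, idx_i, idx_j, loss="square_loss"):
    """For each sampled pair (i_s, j_s), return sum_{k,l} L(C1[i_s,k], C2[j_s,l]) * T[k,l].

    Hypothetical helper; the "abs_loss" name is an assumption.
    """
    A = C1[idx_i]  # (s, n): rows of C1 for the sampled i's
    B = C2[idx_j]  # (s, m): rows of C2 for the sampled j's
    if loss == "square_loss":
        # (a - b)^2 = a^2 + b^2 - 2ab, so the double sum splits into products
        p = T.sum(axis=1)  # (n,) row marginal of T
        q = T.sum(axis=0)  # (m,) column marginal of T
        return (A ** 2) @ p + (B ** 2) @ q - 2.0 * np.einsum("sk,kl,sl->s", A, T, B)
    if loss == "abs_loss":
        # No such decomposition for |a - b|: broadcast to (s, n, m) and reduce
        diff = np.abs(A[:, :, None] - B[:, None, :])
        return np.einsum("snm,nm->s", diff, T)
    raise ValueError(f"unknown loss {loss!r}")
```

The square-loss branch never materializes an s x n x m array, which is one reason a string-selected loss can be faster than a generic callable.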

Hello @patrick-nicodemus, this makes sense. You could accept the loss either as a string for precomputed losses or as a function for more general ones. Feel free to propose a PR, and try to respect the API for GW.
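The string-or-function suggestion could be handled by a small dispatch step at the top of each estimator. This is a sketch under stated assumptions: resolve_loss and the registry below are hypothetical, "abs_loss" is an assumed name, and callables are assumed to be element-wise and to broadcast over NumPy arrays.

```python
from typing import Callable, Union

import numpy as np

# Hypothetical registry of built-in element-wise losses. "square_loss" mirrors
# the string gromov_wasserstein already accepts; "abs_loss" is an assumption.
_BUILTIN_LOSSES = {
    "square_loss": lambda a, b: (a - b) ** 2,
    "abs_loss": lambda a, b: np.abs(a - b),
}


def resolve_loss(loss: Union[str, Callable]) -> Callable:
    """Return an element-wise, broadcastable loss function.

    Strings select a built-in vectorized loss; callables are passed through
    unchanged and are assumed to accept NumPy arrays and broadcast.
    """
    if isinstance(loss, str):
        try:
            return _BUILTIN_LOSSES[loss]
        except KeyError:
            raise ValueError(
                f"unknown loss {loss!r}; expected one of {sorted(_BUILTIN_LOSSES)}"
            ) from None
    if callable(loss):
        return loss
    raise TypeError("loss must be a string or a callable")
```

Keeping the string lookup separate from the numerical code lets each built-in loss later get its own fast path without changing how custom callables are handled.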