ott-jax/ott

Add an OTT `check_grads` function

Daniel-Packer opened this issue · 0 comments

Is your feature request related to a problem? Please describe.
When implementing new features, we add tests of their gradients to show that each new feature is differentiable with JAX's automatic differentiation. Testing all the gradients of a function means duplicating many near-identical code blocks. Writing a helper function to check the gradients only moves the problem: that helper's definition then gets duplicated across feature implementations. Instead, we could group this logic into a single function in `ott.tools` that every test can reuse.

Describe the solution you'd like
We could implement a version of `check_grads` in `ott.tools`, modeled on the one in JAX's internal test utilities: google/jax#2648.
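As a rough illustration of what a shared helper could look like, here is a minimal sketch that compares `jax.grad` against central finite differences for every positional argument of a scalar-valued function. The name `check_grads`, its signature, and the default `eps`/`rtol`/`atol` values are hypothetical choices for this sketch, not an existing `ott.tools` API, and it is much simpler than JAX's own version (reverse mode only, first order only, one coordinate at a time).

```python
import jax
import jax.numpy as jnp
import numpy as np


def check_grads(fun, args, eps=1e-3, rtol=5e-3, atol=5e-3):
  """Hypothetical helper: compare jax.grad of `fun` with finite differences.

  `fun` must return a scalar and `args` must be a tuple of float arrays.
  Raises AssertionError if any gradient entry is out of tolerance.
  """
  # Reverse-mode gradients w.r.t. every positional argument at once.
  grads = jax.grad(fun, argnums=tuple(range(len(args))))(*args)
  for i, (arg, grad) in enumerate(zip(args, grads)):
    arg = np.asarray(arg)
    fd = np.zeros(arg.shape)
    # Central finite difference, perturbing one coordinate at a time.
    for idx in np.ndindex(arg.shape):
      delta = np.zeros_like(arg)
      delta[idx] = eps
      plus, minus = list(args), list(args)
      plus[i], minus[i] = arg + delta, arg - delta
      fd[idx] = (float(fun(*plus)) - float(fun(*minus))) / (2 * eps)
    np.testing.assert_allclose(
        np.asarray(grad), fd, rtol=rtol, atol=atol,
        err_msg=f"gradient mismatch for argument {i}")


def f(x, y):
  return jnp.sum(jnp.sin(x) * y)


x = jnp.array([0.1, 0.5, 1.0])
y = jnp.array([2.0, -1.0, 0.5])
check_grads(f, (x, y))  # passes silently


@jax.custom_jvp
def bad_square(x):
  return jnp.sum(x ** 2)


@bad_square.defjvp
def bad_square_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  # Deliberately wrong derivative: the correct tangent is jnp.sum(2 * x * t).
  return bad_square(x), jnp.sum(x * t)
```

Calling `check_grads(bad_square, (x,))` raises an `AssertionError`, which is the kind of regression a single shared helper would catch in every feature's tests without re-implementing the loop each time. The loose default tolerances are chosen because JAX computes in float32 by default.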

Describe alternatives you've considered
Alternatively, we could use JAX's internal `check_grads` function directly, though I don't know whether that would cause compatibility issues.
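For comparison, a sketch of the "use JAX's version directly" option. This assumes a JAX release in which `check_grads` is importable from `jax.test_util` with the signature `check_grads(f, args, order, modes=("fwd", "rev"), atol=None, rtol=None, eps=None)`; whether that held for the JAX versions OTT supported when this issue was opened is exactly the open compatibility question. The function `f` and the tolerances are illustrative only.

```python
import jax.numpy as jnp
from jax.test_util import check_grads  # assumes a JAX version exposing this


def f(x, y):
  return jnp.sum(jnp.sin(x) * y)


x = jnp.array([0.1, 0.5, 1.0])
y = jnp.array([2.0, -1.0, 0.5])

# Compare forward- and reverse-mode derivatives, up to second order, against
# numerical differences; raises AssertionError on mismatch. The loose
# tolerances allow for float32 finite-difference error.
check_grads(f, (x, y), order=2, modes=("fwd", "rev"), atol=1e-2, rtol=1e-2)
```

Unlike a hand-rolled helper, JAX's function also checks forward mode and higher-order derivatives, which is part of why reusing it (if stable) may be preferable to maintaining a copy in `ott.tools`.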

Additional context
A docstring for `check_grads` has also been added, which suggests it is meant to be usable directly, but I'm not sure: google/jax#2656.