Add an OTT `check_grads` function
Daniel-Packer opened this issue · 0 comments
Is your feature request related to a problem? Please describe.
When implementing new features, we provide gradient tests to demonstrate that each new feature is differentiable with autodiff. To test all the gradients of a function, I end up duplicating a lot of code. Even if I write a helper function to check the gradients and avoid that duplication, I still need to duplicate that helper across feature implementations, which is essentially the same problem. We could instead collect this logic into a single function in `ott.tools` that can be used universally.
Describe the solution you'd like
We could implement a version of `check_grads` in `ott.tools`, as done in the internal JAX utilities here: google/jax#2648.
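A minimal sketch of what such a helper could look like, assuming float64 is enabled in JAX. The name `check_grads_fd` and its parameters are hypothetical (this is not an existing `ott.tools` API). It compares the reverse-mode gradient from `jax.grad` with central finite differences along a few random unit directions $v$, using $\langle \nabla f(x), v\rangle \approx \frac{f(x + \varepsilon v) - f(x - \varepsilon v)}{2\varepsilon}$, and only handles functions of a single array argument that return a scalar.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Finite differences need float64 to be accurate enough to compare against autodiff.
jax.config.update("jax_enable_x64", True)


def check_grads_fd(fun, x, eps=1e-6, rtol=1e-5, atol=1e-8, num_directions=5, seed=0):
  """Hypothetical helper: check jax.grad(fun) against central finite differences.

  `fun` maps a single array to a scalar. For several random unit directions v,
  asserts that <grad fun(x), v> matches (fun(x + eps v) - fun(x - eps v)) / (2 eps).
  Raises AssertionError on a mismatch.
  """
  x = jnp.asarray(x, dtype=jnp.float64)
  grad_x = jax.grad(fun)(x)
  for key in jax.random.split(jax.random.PRNGKey(seed), num_directions):
    v = jax.random.normal(key, x.shape, dtype=x.dtype)
    v = v / jnp.linalg.norm(v)  # unit direction keeps the step size meaningful
    fd = (fun(x + eps * v) - fun(x - eps * v)) / (2 * eps)  # numerical directional derivative
    ad = jnp.vdot(grad_x, v)  # autodiff directional derivative
    np.testing.assert_allclose(ad, fd, rtol=rtol, atol=atol)


# Example: passes silently for a correctly differentiable function.
check_grads_fd(lambda x: jnp.sum(jnp.sin(x) * x ** 2), jnp.array([0.5, -1.0, 2.0]))
```

Checking along random directions keeps the number of function evaluations independent of the input size, which matters for inputs such as large cost matrices; a full check would instead perturb every coordinate separately.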
Describe alternatives you've considered
We could also use JAX's internal `check_grads` function directly, though I don't know whether that would cause compatibility issues.
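If JAX's own utility turns out to be usable, a test could call it directly. A minimal sketch, assuming `jax.test_util.check_grads(f, args, order, modes=...)` is importable in the installed JAX version; `soft_min` is a hypothetical stand-in for an OTT function under test. `check_grads` compares forward- and reverse-mode derivatives up to `order` against numerical finite differences and raises `AssertionError` on a mismatch.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

# Use float64 so the finite-difference reference is accurate.
jax.config.update("jax_enable_x64", True)


def soft_min(x, eps):
  # Hypothetical example: smooth log-sum-exp "soft minimum" of the entries of x.
  return -eps * jax.nn.logsumexp(-x / eps)


x = jnp.array([0.3, 1.2, 0.7])
# Check first- and second-order derivatives in both forward and reverse mode.
check_grads(lambda y: soft_min(y, 0.5), (x,), order=2, modes=("fwd", "rev"))
```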
Additional context
A docstring for `check_grads` has also been added, which suggests we could simply use it, but I'm not sure: google/jax#2656.