A unified interface, dispatching on functions, instead of the current naming convention.
As an improvement over #17, we should introduce a better interface for the gradient, differential, Jacobian, Hessian, and proximal map (also partly discussed in #37).
Proposal of an interface
Each of the interface functions should take the manifold as its first argument and the function it differentiates as its second. If no further argument is provided, a function is returned; if a third argument (a point) is provided, the prox/gradient/Jacobian is evaluated there. For the differential, its adjoint, or the Hessian, one further argument (a tangent vector) would be necessary.
The same holds for the in-place variants.
The implementation should be changed to use these interfaces instead of the current (a little clumsy) naming scheme. Maybe this is best explained with examples.
Examples
gradient(M, distance; parameters=(q,)) # gradient of distance(M, p, q) as a function p -> X ∈ TpM
gradient(M, distance, p; parameters=(q,)) # the previous evaluated at p
gradient!(M, distance; parameters=(q,)) # the in-place variant (p, X) -> X that works in place of X
gradient!(M, distance, p, X; parameters=(q,)) # the previous evaluated in place of X
where one open point would be to discuss a good interface for parameters like the fixed value q (or geodesic parameters in the derivative, and such).
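To make the proposed dispatch concrete, here is a minimal, hypothetical sketch of the first variant. The toy `Euclidean` type and the closed-form gradient of the distance stand in for the real Manifolds.jl types; none of these names or signatures are final.

```julia
# A toy "manifold" standing in for a Manifolds.jl type (illustration only).
struct Euclidean end

distance(::Euclidean, p, q) = sqrt(sum(abs2, p .- q))

# No point given: return the gradient of p -> distance(M, p, q) as a function.
# On Euclidean space this gradient is (p - q) / distance(M, p, q).
function gradient(M::Euclidean, ::typeof(distance); parameters=())
    q = parameters[1]
    return p -> (p .- q) ./ distance(M, p, q)
end

# Point given: evaluate the gradient at p.
gradient(M::Euclidean, f::typeof(distance), p; parameters=()) =
    gradient(M, f; parameters=parameters)(p)

# In-place variant working in place of X.
function gradient!(M::Euclidean, f::typeof(distance), p, X; parameters=())
    X .= gradient(M, f, p; parameters=parameters)
    return X
end
```

For example, `gradient(Euclidean(), distance, [3.0, 0.0]; parameters=([0.0, 0.0],))` returns `[1.0, 0.0]`. New gradients would be added by defining further methods dispatching on `typeof(f)`.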
Good idea. I would prefer, though, a design more similar to https://github.com/JuliaFirstOrder/ProximalOperators.jl :
gradient(M, FixArg2(distance, q), p) # gradient of distance(M, p, q) as a function p -> X ∈ TpM, evaluated at p
gradient!(M, FixArg2(distance, q), X, p) # the previous evaluated in place of X
and also
differential(M, FixArg1(exp, q), X, Y) # replacement for differential_exp_argument
differential(M, FixArg2(exp, X), p, Y) # replacement for differential_exp_basepoint
Maybe just with better names than FixArg1 or FixArg2, but that would be my general idea.
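The FixArg idea could be sketched with callable structs, analogous to Base.Fix1/Base.Fix2 but keeping the manifold slot free in front. The names and the decision not to count M are assumptions here, not settled API:

```julia
# Hypothetical: fixes the first non-manifold argument of f(M, x, y).
struct FixArg1{F,T}
    f::F
    x::T
end
(g::FixArg1)(M, y) = g.f(M, g.x, y)

# Hypothetical: fixes the second non-manifold argument of f(M, x, y).
struct FixArg2{F,T}
    f::F
    x::T
end
(g::FixArg2)(M, y) = g.f(M, y, g.x)
```

With this, `FixArg2(distance, q)` is the function `(M, p) -> distance(M, p, q)`, and `gradient` could dispatch on `FixArg2{typeof(distance)}` to return the known closed-form gradient.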
Note also https://juliadiff.org/ChainRulesCore.jl/stable/, though they are interested in computing differentials with respect to all inputs at once.
That sounds good, yeah. I think we should think about the naming a bit, but I like the idea; mine was more ad hoc, and the parameters kwarg is also really not a good idea ;)
Could we also have FixArgument as a generic thing, without the number in the name? Then I also like the name.
Cool. Sure, we can have a generic FixArgument, but now I think that maybe it should be a "fix all arguments except the nth one" kind of thing.
That might be better; I was already thinking of that as well, for cases where you have a retraction method or such involved.
To use that with grad/prox/diff one could go for WithRespectToArgument (of course properly documented, because it only makes sense in context):
gradient(M, WithRespectToArgument(2, distance, q))
or gradient(M, WithRespectToArgument(2, distance, q), p) to evaluate the gradient at p.
Of course the 2 is a bit relative if we do not count M as an argument, for example. And maybe providing the others (like q) afterwards might also be a bit strange, since one has to provide all the others except the second (the nth in general).
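One way the generic variant could look, as a sketch only: the struct stores the position of the free argument and the remaining fixed arguments in order. Here n does not count M, which is exactly the counting question left open above; the name and constructor are placeholders.

```julia
# Hypothetical: all arguments of f (after the manifold M) are fixed
# except the nth, which stays free when the struct is called.
struct WithRespectToArgument{F,A<:Tuple}
    n::Int    # position of the free argument, not counting M
    f::F
    args::A   # the fixed arguments, in order, with position n left out
end

# Convenience constructor so fixed arguments need not be passed as a tuple.
WithRespectToArgument(n::Int, f, args...) =
    WithRespectToArgument{typeof(f),typeof(args)}(n, f, args)

function (g::WithRespectToArgument)(M, x)
    # splice the free argument x into position n among the fixed ones
    full = (g.args[1:g.n-1]..., x, g.args[g.n:end]...)
    return g.f(M, full...)
end
```

With f(M, p, q), `WithRespectToArgument(1, f, q)` is the function `(M, p) -> f(M, p, q)`, so `gradient` could again dispatch on the wrapped function type.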
Thinking about the naming a bit, I think I like the last one proposed here indeed. I am just not sure how to realise the struct (WithRespectToArgument) or how to actually implement the gradient / prox / diff / adjoint diff. Do we have other packages we can use as an inspiration here?