JuliaManifolds/ManifoldDiff.jl

Taking AD seriously


I think I would like to get into a proper way of handling AD on manifolds. I know we have quite a few issues open here (#17, JuliaManifolds/Manifolds.jl#42, #27, JuliaManifolds/Manifolds.jl#88, #29) and we also have some support already, but we have not made many changes recently. Maybe this would be a good topic to tackle next.

In general I see two ways to go

  1. Embedded manifolds and classical AD
  2. “Intrinsic, geometric AD”

The first is what I think pymanopt does. The second is to some extent done in manopt with finite differences, if I understand correctly, and I tried to do something like that for the Hessian here but have not yet found the time to continue.

The main challenge with the first point, I think, is the conversion of a Euclidean gradient in the embedding to the Riemannian gradient. Sometimes that is just a projection; sometimes it requires an additional computation to adapt to the metric. For example, I have not yet understood all the implications of what we would need for something like ehess2rhess for the Hessian transform.

The main challenge for the second point is that from the building blocks (several are available in Manopt.jl, i.e. basic differentials, adjoint differentials, and gradients) we have to provide ChainRules.jl? The good thing is that we already have tangent and cotangent vectors.

For me the main questions / open points are

  • what do the current backends support?
  • What is left to do until we have a generic form of egrad2rgrad/ ehess2rhess to complete approach 1?
  • Which backends would we support?
  • for the second approach: can we use backends for that, or would we mainly need to define new ChainRules?
  • ...and of course we might have to look more into mutating and non-mutating variants

Finally, would we want to do it here (we can start here for sure) or do we want to create a ManifoldsAD.jl package?

Let's keep this issue as an overview of the topic for now. Feel free to add more points to the list of things to do.

Good idea, incidentally this is quite high on my priority list as well (though I expect to mostly need Jacobians).

The first is what I think pymanopt does. The second is to some extent done in manopt with finite differences, if I understand correctly, and I tried to do something like that for the Hessian here but have not yet found the time to continue.

Yes, Pymanopt does the first thing. I think geomstats also has some good ideas in this area, see for example the way they solved the problem of exp/log on the sphere being very hard to AD through: https://github.com/geomstats/geomstats/blob/048a99dc9dff3f86e42025264ac0811ed6888f7c/geomstats/geometry/hypersphere.py#L533-L588 . I think this may be one of the bigger challenges: many functions we use need to be special-cased around zero, and while that is not a problem for finite differences, it is for AD. Using Taylor expansions to some degree may be necessary.
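To make the "special-cased around zero" point concrete, here is a minimal sketch of the exponential map on the unit sphere with a Taylor fallback for sin(θ)/θ. The names `sinc_safe` and `sphere_exp` and the threshold `1e-4` are assumptions for illustration, not the geomstats or Manifolds.jl code. The sketch only guards the value against 0/0; an AD-safe version would probably also have to avoid differentiating `norm` at zero, which is not attempted here.

```julia
using LinearAlgebra

# sin(θ)/θ with a Taylor fallback near zero, so the value never hits 0/0.
# The cutoff 1e-4 is an arbitrary choice for this sketch.
function sinc_safe(θ::Real)
    if abs(θ) < 1e-4
        θ2 = θ^2
        return 1 - θ2 / 6 + θ2^2 / 120
    else
        return sin(θ) / θ
    end
end

# Exponential map on the unit sphere: exp_p(X) = cos(‖X‖) p + (sin(‖X‖)/‖X‖) X.
function sphere_exp(p, X)
    θ = norm(X)
    return cos(θ) .* p .+ sinc_safe(θ) .* X
end

p = [0.0, 0.0, 1.0]
q0 = sphere_exp(p, zeros(3))            # zero tangent vector: stays at p
q1 = sphere_exp(p, [1e-9, 0.0, 0.0])    # tiny step, Taylor branch
q2 = sphere_exp(p, [π / 2, 0.0, 0.0])   # quarter great circle, closed-form branch
```

Without the fallback branch, the zero tangent vector would produce `NaN` through `sin(0)/0`.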

The main challenge with the first point, I think, is the conversion of a Euclidean gradient in the embedding to the Riemannian gradient. Sometimes that is just a projection; sometimes it requires an additional computation to adapt to the metric. For example, I have not yet understood all the implications of what we would need for something like ehess2rhess for the Hessian transform.

Gradient projections and egrad2rgrad are, in general, insufficient for ehess2rhess. The Weingarten map is the usual tool for getting Riemannian Hessians from the embedded ones.
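As an illustration of why a projection alone is not enough, here is a small sketch for the unit sphere, where the Weingarten correction is the well-known term -⟨p, ∇f(p)⟩ X. The helper name `sphere_ehess2rhess` is hypothetical and only mirrors the ehess2rhess naming used in this thread; it is not an existing Manifolds.jl function.

```julia
using LinearAlgebra

# Orthogonal projection onto the tangent space T_p S^{n-1} = {X : ⟨p, X⟩ = 0}.
proj(p, X) = X - dot(p, X) * p

# Hypothetical helper: egrad is the Euclidean gradient at p, ehessX the
# Euclidean Hessian applied to the tangent vector X. The second term is the
# Weingarten correction for the sphere.
sphere_ehess2rhess(p, egrad, ehessX, X) = proj(p, ehessX) - dot(p, egrad) * X

# Example: f(x) = x' A x with Euclidean gradient 2Ax and Euclidean Hessian 2A.
A = Diagonal([1.0, 2.0, 5.0])
p = [1.0, 0.0, 0.0]          # eigenvector for the eigenvalue 1
X = [0.0, 1.0, 0.0]          # tangent at p, eigenvector for the eigenvalue 2

H = sphere_ehess2rhess(p, 2A * p, 2A * X, X)   # Riemannian Hessian applied to X
H_proj_only = proj(p, 2A * X)                  # projection without the correction
```

Here `H` is 2(2 - 1) X, while `H_proj_only` is 4 X, so dropping the correction gives a different (wrong) Hessian even in this simple isometric case.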

The main challenge for the second point is that from the building blocks (several are available in Manopt.jl, i.e. basic differentials, adjoint differentials, and gradients) we have to provide ChainRules.jl? The good thing is that we already have tangent and cotangent vectors.

I don't quite understand at the moment what the intrinsic geometric AD is. Probably working in charts so I think relying on ONBs of tangent spaces is not the greatest idea? I've seen this paper: https://arxiv.org/abs/1812.11592 but I don't understand why they don't have to care about the metric.

One thing I know is that ChainRules.jl has only experimental support for forward mode, and reverse mode has only very recently achieved a stable API. We can't rely on only supporting ChainRules.jl AD.

One thing I know is that ChainRules.jl has only experimental support for forward mode, and reverse mode has only very recently achieved a stable API.

Why would you classify ChainRules.jl's forward mode support as experimental? ChainRules just hit v1.0, which signals that its API is very stable. By comparison, ForwardDiff has not yet hit v1.0.

We can't rely on only supporting ChainRules.jl AD.

This is true if we want to support ADs like ForwardDiff and ReverseDiff, which are not likely to use ChainRules in the near future, but IMO we should start with ChainRules support, since that brings compatibility with 6 AD engines, including Diffractor.

The real issue is that none of the reverse-mode ADs that use ChainRules support mutating functions (I think), and we've built our entire interface around mutating functions. So that might mean we only support forward-mode AD for now.

It's not clear to me how/why AD on manifolds must differ from normal AD. If we're working in a chart, shouldn't there be no difference? And how do we plan to use the AD?

Note that timing-wise, it might make sense to hold off on an implementation until AbstractDifferentiation.jl is released. It aims to unify the APIs of the different AD engines, which replaces the machinery many packages like ours have to allow the user to select an AD backend.

Good idea, incidentally this is quite high on my priority list as well (though I expect to mostly need Jacobians).

I would mainly be interested in gradients and (approximate) Hessians

I think geomstats also has some good ideas in this area, see for example the way they solved the problem of exp/log on the sphere being very hard to AD through: https://github.com/geomstats/geomstats/blob/048a99dc9dff3f86e42025264ac0811ed6888f7c/geomstats/geometry/hypersphere.py#L533-L588 . I think this may be one of the bigger challenges: many functions we use need to be special-cased around zero, and while that is not a problem for finite differences, it is for AD. Using Taylor expansions to some degree may be necessary.

That indeed looks quite technical. Maybe one could instead provide the differentials manually for those two, as I do in Manopt.jl (though there using the framework of Jacobi fields, which might be nice to use in general anyway).

Gradient projections and egrad2rgrad are, in general, insufficient for ehess2rhess. The Weingarten map is the usual tool for getting Riemannian Hessians from the embedded ones.

I know, I just do not yet completely know how to treat both functions in the generality we usually have here in our package.

I don't quite understand at the moment what the intrinsic geometric AD is. Probably working in charts so I think relying on ONBs of tangent spaces is not the greatest idea? I've seen this paper: https://arxiv.org/abs/1812.11592 but I don't understand why they don't have to care about the metric.

By intrinsic I mean a method that neither relies on the specific representation at hand (e.g. for hyperbolic space it should work the same for all 3 point types / representations) nor on an embedding (since that might increase the dimension), but really just works on (co-)tangent vectors, bases of (co-)tangent spaces and generic tools like vector transport. This would yield methods that never require an embedding and, in the best case, only seldom a chart.
Intrinsic here in the sense that it does not use the embedding but stays within the manifold / tangent spaces.

IMO we should start with ChainRules support, since that brings compatibility with 6 AD engines, including Diffractor.

That was my take-home-message from what I learned about ChainRules at JuliaCon.

The real issue is that none of the reverse-mode ADs that use ChainRules support mutating functions (I think), and we've built our entire interface around mutating functions. So that might mean we only support forward-mode AD for now.

Yeah, I am not sure about this yet either; we need the mutating functions for speed here, sure. But I think my first issue would actually be the embedding vs. intrinsic approaches (see last comment).

It's not clear to me how/why AD on manifolds must differ from normal AD. If we're working in a chart, shouldn't there be no difference? And how do we plan to use the AD?

As I wrote, there would be two – maybe then 3 – approaches

  1. embedded, and project / transform back (for isometrically embedded it is just project, otherwise more, especially for higher order stuff like Hessians)
  2. intrinsic (see last comment)
  3. through charts (on the domain for real-valued functions, on both domain and range for manifold-valued functions); then we could do classic AD, but I fear we would have to do a lot for the charts and get overhead there.

Note that timing-wise, it might make sense to hold off on an implementation until AbstractDifferentiation.jl is released. It aims to unify the APIs of the different AD engines, which replaces the machinery many packages like ours have to allow the user to select an AD backend.

Sounds reasonable, sure. We could still sketch ideas and start some code, maybe. As I just wrote, the first would be to figure out which modes we have (I think now we are at 3 different ways to do it).

Why would you classify ChainRules.jl's forward mode support as experimental? ChainRules just hit v1.0, which signals that its API is very stable. By comparison, ForwardDiff has not yet hit v1.0.

Mostly because I don't know of any stable forward mode AD library that uses ChainRules.jl. Is there one? ForwardDiff.jl's last breaking release was in 2018, so I'd say it's stable.

The real issue is that none of the reverse-mode ADs that use ChainRules support mutating functions (I think), and we've built our entire interface around mutating functions. So that might mean we only support forward-mode AD for now.

I think we could just gradually introduce non-mutating variants as we need them.
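A minimal sketch of what such a pair of variants could look like, using the tangent space projection on the unit sphere as the example. The names `sphere_project!` and `sphere_project` are made up for this illustration and are not the package API; whether a given reverse-mode backend accepts the non-mutating form still has to be checked per backend.

```julia
using LinearAlgebra

# Mutating variant: writes the projection of X onto T_p S^{n-1} into Y.
function sphere_project!(Y, p, X)
    Y .= X .- dot(p, X) .* p
    return Y
end

# Non-mutating variant: builds a fresh result and never writes into an
# existing array, at the price of one allocation per call.
sphere_project(p, X) = X .- dot(p, X) .* p

p = [0.0, 1.0, 0.0]
X = [1.0, 2.0, 3.0]

Y = similar(X)
sphere_project!(Y, p, X)     # in place, reuses Y
Z = sphere_project(p, X)     # allocating
```

Both calls give the same tangent vector; only the memory behaviour differs.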

It's not clear to me how/why AD on manifolds must differ from normal AD. If we're working in a chart, shouldn't there be no difference? And how do we plan to use the AD?

The thing is, we don't want to be constantly working in a chart. Once we get a gradient (for example), we want to perform a retraction using a closed-form formula, and most of our formulas work in an embedding. Even if we never work in an embedding, we may need to switch charts. There was some recent work on chart-based optimization and normalizing flows on manifolds that attempts to address these issues: https://arxiv.org/abs/1909.09501 and https://arxiv.org/abs/2006.10254 .

AD seems to be mostly useful for gradients and Hessians in optimization. I'm going to experiment with Jacobians for continuous normalizing flows.

Note that timing-wise, it might make sense to hold off on an implementation until AbstractDifferentiation.jl is released. It aims to unify the APIs of the different AD engines, which replaces the machinery many packages like ours have to allow the user to select an AD backend.

I think there are many other issues we can solve without waiting for that.

That indeed looks quite technical. Maybe one could instead provide the differentials manually for those two, as I do in Manopt.jl (though there using the framework of Jacobi fields, which might be nice to use in general anyway).

OK but how does one manually provide such differentials to AD? I've done such things for ForwardDiff.jl but the method only works for specific AD libraries (though ChainRules.jl may help?) and is quite ugly.

By intrinsic I mean a method that neither relies on the specific representation at hand (e.g. for hyperbolic space it should work the same for all 3 point types / representations) nor on an embedding (since that might increase the dimension), but really just works on (co-)tangent vectors, bases of (co-)tangent spaces and generic tools like vector transport. This would yield methods that never require an embedding and, in the best case, only seldom a chart.
Intrinsic here in the sense that it does not use the embedding but stays within the manifold / tangent spaces.

That's a very ambitious goal, and I don't know what generic tools are sufficient to solve this without making it too slow.

I did not say intrinsic is easy, nor that there is much work done in this direction yet – but it would be really cool if something like that would work :)

It's not clear to me how/why AD on manifolds must differ from normal AD. If we're working in a chart, shouldn't there be no difference? And how do we plan to use the AD?

As I wrote, there would be two – maybe then 3 – approaches

1. embedded, and project / transform back (for isometrically embedded it is just project, otherwise more, especially for higher order stuff like Hessians)

2. intrinsic (see last comment)

3. through charts (on the domain for real-valued functions, on both domain and range for manifold-valued functions); then we could do classic AD, but I fear we would have to do a lot for the charts and get overhead there.

I'm still not clear on this. Is there a write-up somewhere of how AD on manifolds should/must differ from standard AD? Or put another way, what quantities are you hoping to get out of AD? Because standard AD will give you "tangents" (derivatives of the real values stored in your points wrt some upstream real scalar) and "cotangents" (derivatives of some downstream real scalar wrt the real values stored in your points). How does this differ from what you want from a manifold AD?

Oh, the result types will not differ, also inputs will not differ. We will work with (co)tangents for sure.

The question is more about the “computational path” to take.

Let's look at the first variant (embedding) and an example – maybe an easy one – isometrically embedded – the sphere.

If we have a function like the Rayleigh quotient on the sphere (f(x) = x'*A*x, no need to divide by the norm of x on the sphere :)), we can do classical AD (or just compute by hand) to derive, for example, the (Euclidean) gradient in the embedding. Projecting that onto the tangent space yields the Riemannian gradient.
Now if the embedding is not isometric (i.e. the embedded manifold does not just use the restriction of the metric in the embedding to a tangent space as its inner product), this is more involved, since after projection you also have to account for the “change of the metric”. That's what I paraphrased as “transform back”.
Also, if the embedding is quite high dimensional (not just n -> n+1 but an embedding n -> m dimensions with m >> n), this approach might introduce extra allocations/memory usage. Still, this is what is done in most cases today.

A further disadvantage is that a manifold might not have such an embedding...
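The isometric sphere case described above can be sketched in a few lines. The matrix `A` and the helper names `egrad` and `rgrad` are placeholders chosen for this example, and the Euclidean gradient is written by hand where an AD backend would normally be called.

```julia
using LinearAlgebra

A = Symmetric([2.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 4.0])

f(x) = dot(x, A * x)            # Rayleigh quotient restricted to the unit sphere
egrad(x) = 2 * (A * x)          # Euclidean gradient, computed by hand here

# Isometric embedding: the Riemannian gradient is the orthogonal projection
# of the Euclidean gradient onto the tangent space at x.
function rgrad(x)
    g = egrad(x)
    return g - dot(x, g) * x
end

x = normalize([1.0, 1.0, 1.0])
g = rgrad(x)                    # tangent vector at x, orthogonal to x

# At a unit eigenvector of A the Riemannian gradient vanishes, as expected
# for a critical point of the Rayleigh quotient.
v = eigen(A).vectors[:, 1]
gv = rgrad(v)
```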

Similarly for the third approach, let's take the Rayleigh quotient again. We could define it as a function in a chart, i.e. on R^n (for S^n) instead of on R^{n+1} in the embedding, as F = f(g^{-1}(c)), where g is a chart and c = g(x) are the coordinates of x in the chart.
For a function between manifolds we would additionally take a chart h in the range and use F = h(f(g^{-1}(c))).
The same could be written with parametrisations (inverses of charts); then the roles of the maps and their inverses swap.
For these F we can play the classical AD game again, but we also have to use the differential (or adjoint differential) of the charts/parametrisations to get the same results as before, so there is a post-processing step. We also have to find suitable charts, i.e. charts around x and around f(x), respectively. These two “surrounding steps”, as well as the additional applications of the chain rule, might make this approach more involved/complicated.
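To illustrate the chart route and its post-processing, here is a sketch for the Rayleigh quotient on S^2 using the inverse stereographic projection as the parametrisation. The names `φ`, `Dφ`, `gradF` and the point `c` are assumptions of this example; the Jacobian is written out by hand where an AD backend would be used in practice.

```julia
using LinearAlgebra

A = Symmetric([2.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 4.0])
egrad(x) = 2 * (A * x)                      # Euclidean gradient of f(x) = x'Ax

# Parametrisation of S^2 without the north pole (inverse stereographic chart).
function φ(c)
    s = dot(c, c)
    return vcat(2 .* c, s - 1) ./ (s + 1)
end

# Jacobian of φ, a 3×2 matrix whose columns span the tangent space at φ(c).
function Dφ(c)
    s = dot(c, c)
    top = (2 / (s + 1)) * I(2) - (4 / (s + 1)^2) * (c * c')
    bottom = (4 / (s + 1)^2) * c'
    return vcat(top, bottom)
end

c = [0.3, -0.7]
x = φ(c)
J = Dφ(c)

# Classical chain rule in coordinates: gradient of F = f ∘ φ on R^2.
gradF = J' * egrad(x)

# Post-processing: undo the chart's metric G = J'J, then push forward to T_x S^2.
G = J' * J
rgrad_chart = J * (G \ gradF)

# Reference: projection of the Euclidean gradient (embedded approach).
rgrad_embedded = egrad(x) - dot(x, egrad(x)) * x
```

Both routes give the same tangent vector, but the chart route needs the Jacobian of the parametrisation and a solve with the metric in coordinates on top of the plain derivative.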

For the last one, the intrinsic approach, we would have Df, D*f or grad f for simple building blocks/functions on a manifold, define chain rules for these functions and derive an intrinsic AD from them (this might involve the Riemannian curvature tensor and such, I am not sure yet). Then we would maybe not have those overheads (up to curvature). There is this one paper from 2018 describing this a little, but I have not seen it in source code.

So in the end all 3 hopefully do the same, and we might do some of them (just maybe not the chart one? I am not sure). They all have the same input and output, but their inner mechanisms are different.

Does this help?

Here is the quick sketch of an rrule for distance I made on Slack. I think it's worth saving from the Slack hole:

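using ChainRulesCore
using ManifoldsBase: AbstractManifold, distance

# Reverse rule for the Riemannian distance d(p, q) on a manifold M.
function ChainRulesCore.rrule(::typeof(distance), M::AbstractManifold, p, q)
    d = distance(M, p, q)
    function distance_pullback(Ȳ)
        # No tangent for `distance` itself or for M; the gradients with respect
        # to p and q are -log_p(q)/d and -log_q(p)/d (undefined for d == 0).
        return NoTangent(), NoTangent(), -(Ȳ / d) * log(M, p, q), -(Ȳ / d) * log(M, q, p)
    end
    return d, distance_pullback
end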
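The gradient formula used in the pullback, -log_p(q)/d, can be sanity-checked on the unit sphere with a finite difference along a geodesic. This is only a sketch: `sphere_distance`, `sphere_log` and `sphere_exp` are hand-written stand-ins so that the check needs no packages, and the points and step size are arbitrary choices.

```julia
using LinearAlgebra

# Unit sphere, written out by hand so the check needs no packages.
sphere_distance(p, q) = acos(clamp(dot(p, q), -1.0, 1.0))
function sphere_log(p, q)
    v = q - dot(p, q) * p               # component of q tangent to the sphere at p
    return sphere_distance(p, q) * v / norm(v)
end
sphere_exp(p, X) = cos(norm(X)) * p + sin(norm(X)) * X / norm(X)

p = normalize([1.0, 0.2, -0.3])
q = normalize([0.1, 1.0, 0.4])
d = sphere_distance(p, q)

# Gradient of p ↦ distance(p, q) as used in the pullback: -log_p(q) / d.
grad_p = -sphere_log(p, q) / d

# Compare ⟨grad_p, X⟩ with a central finite difference along the geodesic
# t ↦ exp_p(tX) for a unit tangent direction X at p.
X = normalize(cross(p, [0.0, 0.0, 1.0]))
h = 1e-6
fd = (sphere_distance(sphere_exp(p, h * X), q) -
      sphere_distance(sphere_exp(p, -h * X), q)) / (2h)
```

The inner product `dot(grad_p, X)` and the finite difference `fd` should agree up to the accuracy of the difference quotient.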

And note that for the extrinsic approach, the PRs JuliaManifolds/Manifolds.jl#423 and JuliaManifolds/Manifolds.jl#427 will bring the desired functionality, at least for gradients (Hessians are the next thing I want to get into). In particular, they provide a more general approach than the egrad2rgrad of Manopt (Matlab), in the sense that it is a little more structured (with respect to the metric conversion) and can be coupled with any AD in the embedding (we are working on tidying that up a little, too).

This is quite a step towards the first approach mentioned.
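As an example of a case where the metric matters and a projection alone is not enough, here is a sketch for symmetric positive definite matrices with the affine-invariant metric, where the Riemannian gradient is P sym(∇f(P)) P. The helpers `sym`, `inner_spd` and `spd_egrad2rgrad` are hypothetical names for this illustration and are not the API introduced by the PRs mentioned above.

```julia
using LinearAlgebra

sym(M) = (M + M') / 2

# Affine-invariant metric on symmetric positive definite matrices:
# ⟨X, Y⟩_P = tr(P⁻¹ X P⁻¹ Y). This is not the metric induced by the embedding.
inner_spd(P, X, Y) = tr((P \ X) * (P \ Y))

# Converting a Euclidean gradient to the Riemannian one needs the metric,
# not just a projection onto symmetric matrices.
spd_egrad2rgrad(P, egrad) = P * sym(egrad) * P

# Example: f(P) = tr(C' P), whose Euclidean gradient is C.
C = [1.0 2.0 0.0; 0.0 3.0 1.0; 1.0 0.0 2.0]
P = [2.0 0.5 0.0; 0.5 1.0 0.2; 0.0 0.2 3.0]            # symmetric positive definite
X = sym([0.3 -1.0 0.4; 0.2 0.5 0.0; 1.0 0.7 -0.2])     # a symmetric tangent direction

rg = spd_egrad2rgrad(P, C)
directional = tr(C' * X)            # Df(P)[X], computed in the embedding
riemannian = inner_spd(P, rg, X)    # ⟨grad f(P), X⟩_P, should match Df(P)[X]
```

The defining property of the Riemannian gradient, ⟨grad f(P), X⟩_P = Df(P)[X], holds here only because the metric enters the conversion.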