Manifold-aware pytorch.optim.
Unofficial implementation of “Riemannian Adaptive Optimization Methods” (ICLR 2019) and more.
Work is in progress, but you can already use it. Note that the API might change in future releases.
geoopt.ManifoldTensor – just like torch.Tensor, with an additional manifold keyword argument.
geoopt.ManifoldParameter – same as above; recognized by torch.nn.Module.parameters as a correctly subclassed parameter.
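The container idea can be illustrated without PyTorch. The sketch below is plain Python and purely hypothetical: `ManifoldVector` stands in for geoopt.ManifoldTensor and `Sphere.projx` for a manifold's projection; neither is geoopt's actual class or method. It shows the design choice of subclassing the base container (here `list`, in geoopt torch.Tensor), so that code expecting the base type still accepts the object, while a `manifold` attribute supplies the geometry.

```python
import math


class Sphere:
    """Hypothetical minimal manifold: unit vectors in R^n stored as plain lists."""

    def projx(self, x):
        # Normalize x so that ||x|| = 1.
        norm = math.sqrt(sum(xi * xi for xi in x))
        return [xi / norm for xi in x]


class ManifoldVector(list):
    """Hypothetical stand-in for a ManifoldTensor: a list that remembers its manifold."""

    def __init__(self, data, manifold):
        super().__init__(data)
        self.manifold = manifold

    def proj_(self):
        # In-place projection: overwrite the stored values with the projected point.
        self[:] = self.manifold.projx(list(self))
        return self


x = ManifoldVector([3.0, 4.0], manifold=Sphere())
x.proj_()
print(list(x), isinstance(x, list))  # [0.6, 0.8] True
```

Because `ManifoldVector` is still a `list`, any code that only checks for the base type keeps working; the same reasoning is why a correctly subclassed parameter is picked up by torch.nn.Module.parameters.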
All the containers above have special methods for working with them as points on a certain manifold:
.proj_() – in-place projection onto the manifold.
.proju(u) – project vector u onto the tangent space. You need to project all vectors before passing them to any of the methods below.
.inner(u, v=None) – inner product of two tangent vectors at this point. The passed vectors are not projected; they are assumed to be already projected.
.retr(u, t) – retraction map following vector u for time t.
.transp(u, t, v, *more) – transport vector v (and possibly more vectors) along direction u for time t.
.retr_transp(u, t, v, *more) – retract self and transport vector v (and possibly more vectors) along direction u for time t (the returns are plain tensors).
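For the unit sphere, every method in the list above has a short closed form. The following plain-Python sketch assumes the induced Euclidean metric, the projection retraction x ↦ (x + t u)/||x + t u||, and vector transport by projection onto the new tangent space; these are common textbook choices and not necessarily what geoopt implements. Points and vectors are plain lists, and the point is passed explicitly as `x` instead of being `self`; all function names are hypothetical.

```python
import math


def dot(u, v):
    # Euclidean inner product in the ambient space R^n.
    return sum(a * b for a, b in zip(u, v))


def proj(x):
    # Project an ambient point onto the unit sphere: x / ||x||.
    n = math.sqrt(dot(x, x))
    return [a / n for a in x]


def proju(x, u):
    # The tangent space at x is {u : <x, u> = 0}; remove the normal component of u.
    c = dot(x, u)
    return [ui - c * xi for ui, xi in zip(u, x)]


def inner(x, u, v=None):
    # Induced metric; u and v are assumed to be tangent already.
    # In this sketch, v=None means <u, u>.
    return dot(u, u if v is None else v)


def retr(x, u, t):
    # Projection retraction: move along t * u in R^n, then renormalize.
    return proj([xi + t * ui for xi, ui in zip(x, u)])


def transp(x, u, t, v, *more):
    # Vector transport by projection onto the tangent space at the retracted point.
    y = retr(x, u, t)
    moved = [proju(y, w) for w in (v,) + more]
    return moved[0] if not more else tuple(moved)


def retr_transp(x, u, t, v, *more):
    # Retract the point and transport the vectors in one call, reusing y.
    y = retr(x, u, t)
    return (y,) + tuple(proju(y, w) for w in (v,) + more)


x = [1.0, 0.0, 0.0]
u = proju(x, [0.5, 1.0, 0.0])      # tangent vector at x: [0.0, 1.0, 0.0]
y, v = retr_transp(x, u, 1.0, u)   # new point and u transported to it
print(dot(y, y), dot(y, v))        # approximately 1.0 and 0.0
```

Combining retraction and transport in one call avoids computing the new point twice, which is the motivation for a method like .retr_transp.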
geoopt.Euclidean – unconstrained manifold in R with the Euclidean metric.
geoopt.Stiefel – Stiefel manifold of matrices A in R^{n x p} with A^T A = I, n >= p.
geoopt.Sphere – sphere manifold ||x|| = 1.
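The Stiefel constraint and its tangent space can be checked numerically. This hypothetical plain-Python sketch (matrices as lists of rows, helper names invented for illustration) tests A^T A = I and projects an ambient matrix U onto the tangent space at A with U - A sym(A^T U), where sym(M) = (M + M^T)/2; this is the standard projection for the metric inherited from R^{n x p}, not necessarily geoopt's code path. With p = 1 the Stiefel manifold reduces to the unit sphere.

```python
def matmul(A, B):
    # Plain matrix product; matrices are lists of rows.
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def transpose(A):
    return [list(row) for row in zip(*A)]


def is_stiefel(A, tol=1e-10):
    # A in R^{n x p} is on the Stiefel manifold iff n >= p and A^T A = I_p.
    n, p = len(A), len(A[0])
    G = matmul(transpose(A), A)
    return n >= p and all(abs(G[i][j] - (1.0 if i == j else 0.0)) <= tol
                          for i in range(p) for j in range(p))


def stiefel_proju(A, U):
    # Tangent projection for the embedded metric: U - A sym(A^T U).
    M = matmul(transpose(A), U)
    p = len(M)
    S = [[(M[i][j] + M[j][i]) / 2 for j in range(p)] for i in range(p)]
    AS = matmul(A, S)
    return [[U[i][j] - AS[i][j] for j in range(len(U[0]))] for i in range(len(U))]


A = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]   # a point on Stiefel(3, 2)
U = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # arbitrary ambient matrix
Z = stiefel_proju(A, U)
T = matmul(transpose(A), Z)
# Tangent vectors Z at A satisfy A^T Z + Z^T A = 0, i.e. A^T Z is skew-symmetric.
print(is_stiefel(A), Z, T)
```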
geoopt.optim.RiemannianSGD – a subclass of torch.optim.SGD with the same API.
geoopt.optim.RiemannianAdam – a subclass of torch.optim.Adam.
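The core Riemannian SGD step, x ← retr(x, -lr · grad_R f(x)), can be sketched in a few lines: project the Euclidean gradient onto the tangent space, then retract. The plain-Python function below is a hypothetical illustration on the unit sphere, assuming the projection retraction and leaving out momentum, weight decay, and the other torch.optim.SGD options; it is not geoopt's implementation. It minimizes the quadratic form x^T A x over ||x|| = 1, whose minimizer is an eigenvector for the smallest eigenvalue of A.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def normalize(x):
    n = math.sqrt(dot(x, x))
    return [a / n for a in x]


def riemannian_sgd_sphere(egrad_fn, x0, lr=0.1, steps=200):
    """Hypothetical helper: plain Riemannian SGD on the unit sphere (no momentum)."""
    x = normalize(x0)
    for _ in range(steps):
        g = egrad_fn(x)                                   # Euclidean gradient at x
        c = dot(x, g)
        rgrad = [gi - c * xi for gi, xi in zip(g, x)]     # project onto tangent space
        x = normalize([xi - lr * ri for xi, ri in zip(x, rgrad)])  # retract
    return x


# f(x) = x^T A x with A = diag(3, 2, 1); its minimizer on the sphere is +/- e3.
diag = [3.0, 2.0, 1.0]
x = riemannian_sgd_sphere(lambda p: [2.0 * d * pi for d, pi in zip(diag, p)], [1.0, 1.0, 1.0])
f = sum(d * xi * xi for d, xi in zip(diag, x))
print(x, f)
```

Starting from (1, 1, 1)/√3, the iterate should approach e3 = (0, 0, 1) and f should approach 1, the smallest diagonal entry.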
geoopt.samplers.RSGLD – Riemannian Stochastic Gradient Langevin Dynamics
geoopt.samplers.RHMC – Riemannian Hamiltonian Monte-Carlo
geoopt.samplers.SGRHMC – Stochastic Gradient Riemannian Hamiltonian Monte-Carlo
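The samplers build on the same primitives. As a simplified, hypothetical illustration (not geoopt's RSGLD, RHMC, or SGRHMC code), the sketch below runs Langevin dynamics on the unit sphere: each step moves by -eps times the Riemannian gradient of a potential U plus tangent Gaussian noise scaled by sqrt(2 eps), then retracts by normalization. It uses exact gradients; a stochastic-gradient variant would replace `grad_U` with a minibatch estimate. The target is a von Mises-Fisher density with concentration kappa = 10, for which the mean of <mu, x> on the 2-sphere is coth(kappa) - 1/kappa, about 0.9.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def normalize(x):
    n = math.sqrt(dot(x, x))
    return [a / n for a in x]


def proju(x, u):
    # Remove the component of u along x (tangent projection on the sphere).
    c = dot(x, u)
    return [ui - c * xi for ui, xi in zip(u, x)]


def langevin_sphere(grad_U, x0, eps=0.01, steps=5000, burn_in=1000, seed=0):
    """Hypothetical Riemannian Langevin sampler on the unit sphere (exact gradients)."""
    rng = random.Random(seed)
    x = normalize(x0)
    samples = []
    for k in range(steps):
        g = proju(x, grad_U(x))                              # Riemannian gradient of U
        noise = proju(x, [rng.gauss(0.0, 1.0) for _ in x])   # tangent Gaussian noise
        step = [-eps * gi + math.sqrt(2 * eps) * ni for gi, ni in zip(g, noise)]
        x = normalize([xi + si for xi, si in zip(x, step)])  # projection retraction
        if k >= burn_in:
            samples.append(x)
    return samples


# Target p(x) proportional to exp(kappa * <mu, x>), i.e. U(x) = -kappa * <mu, x>.
mu, kappa = [0.0, 0.0, 1.0], 10.0
samples = langevin_sphere(lambda x: [-kappa * m for m in mu], [1.0, 0.0, 0.0])
mean_cos = sum(dot(s, mu) for s in samples) / len(samples)
print(len(samples), round(mean_cos, 3))
```

With this small step size the estimate should land near 0.9; the remaining gap reflects discretization error and Monte Carlo noise.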