geoopt

Manifold-aware `pytorch.optim`.

Unofficial implementation of “Riemannian Adaptive Optimization Methods” (ICLR 2019), and more.

Installation

Make sure you have `pytorch>=1.9.0` installed.
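The version requirement can also be checked from Python. The sketch below uses a hypothetical helper, `meets_minimum`, which is not part of geoopt or PyTorch. It compares a version string such as `torch.__version__` against 1.9.0, ignores local build tags like `+cu117`, and, as a simplification, also ignores pre-release suffixes.

```python
import re


def meets_minimum(version: str, minimum: tuple = (1, 9, 0)) -> bool:
    """Return True if a version string such as '1.13.1+cu117' is >= minimum.

    Hypothetical helper (not part of geoopt or PyTorch). It drops the local
    build tag after '+', keeps the first three dot-separated fields, and reads
    only the leading digits of each, so pre-release suffixes are ignored.
    """
    public = version.split("+", 1)[0]
    fields = []
    for piece in public.split(".")[:3]:
        match = re.match(r"\d+", piece)
        fields.append(int(match.group()) if match else 0)
    fields += [0] * (3 - len(fields))  # pad short versions like '2.0'
    return tuple(fields) >= tuple(minimum)


# Usage, with PyTorch installed:
#   import torch
#   print(meets_minimum(torch.__version__))
```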

There are two ways to install geoopt:

- GitHub (preferred for now, due to active development)

`pip install git+https://github.com/geoopt/geoopt.git`

- PyPI (this may be significantly behind the master branch)

`pip install geoopt`

The preferred way to install geoopt will change once the project reaches a stable stage. For now, PyPI lags behind master because we are actively developing and implementing new features.

PyTorch Support

Geoopt officially supports the **2 latest stable versions** of PyTorch upstream (1.9.0 so far), or the latest major release. We also test against the nightly build (TODO: there were complications with GitHub workflows, help is needed), but compatibility with it is not guaranteed. Older PyTorch versions may be used at your own risk (do not forget to run the tests).

What is done so far

Work is in progress, but you can already use geoopt. Note that the API might change in future releases.

Tensors

- `geoopt.ManifoldTensor` – just like `torch.Tensor`, with an additional `manifold` keyword argument.
- `geoopt.ManifoldParameter` – same as above, and recognized by `torch.nn.Module.parameters` because it is correctly subclassed.

All of the above containers have special methods for working with them as points on a given manifold:

- `.proj_()` – in-place projection onto the manifold.
- `.proju(u)` – project vector `u` onto the tangent space. You need to project all vectors for all methods below.
- `.egrad2rgrad(u)` – convert the Euclidean gradient `u` into a Riemannian gradient.
- `.inner(u, v=None)` – inner product at this point of two **tangent** vectors at this point. The passed vectors are not projected; they are assumed to be already projected.
- `.retr(u)` – retraction map following vector `u`.
- `.expmap(u)` – exponential map following vector `u` (if expmap is not available in closed form, the best approximation is used).
- `.transp(v, u)` – transport vector `v` along direction `u`.
- `.retr_transp(v, u)` – move `self` and transport vector `v` (and possibly more vectors) along direction `u` (returns plain tensors).
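To make these operations concrete, here is a dependency-free sketch of what they compute on the unit sphere `||x||=1`, with vectors as plain Python lists. It is an illustration of the math only, not geoopt's code: the point is passed explicitly as `x` instead of being `self`, `projx` is an out-of-place stand-in for `.proj_()`, and the projection-based vector transport is one common choice that this sketch assumes.

```python
import math

# Sphere versions of the methods listed above; x is the base point.


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def axpy(alpha, a, b):
    """Return alpha * a + b."""
    return [alpha * ai + bi for ai, bi in zip(a, b)]


def projx(x):
    """Project a point onto the sphere (out-of-place analogue of .proj_())."""
    n = math.sqrt(dot(x, x))
    return [xi / n for xi in x]


def proju(x, u):
    """Project u onto the tangent space at x by removing its component along x."""
    return axpy(-dot(x, u), x, u)


def egrad2rgrad(x, g):
    """With the metric induced from R^n, the Riemannian gradient is the
    tangent projection of the Euclidean gradient g."""
    return proju(x, g)


def inner(x, u, v=None):
    """Inner product of tangent vectors at x; v=None means <u, u>."""
    return dot(u, u if v is None else v)


def retr(x, u):
    """Retraction: take the step in R^n, then normalize back onto the sphere."""
    return projx(axpy(1.0, u, x))


def expmap(x, u):
    """Exponential map: follow the great circle leaving x in direction u."""
    t = math.sqrt(dot(u, u))
    if t == 0.0:
        return list(x)
    return axpy(math.sin(t) / t, u, [math.cos(t) * xi for xi in x])


def transp(x, v, u):
    """Projection-based vector transport of v from x to retr(x, u)."""
    return proju(retr(x, u), v)


def retr_transp(x, v, u):
    """Retract x along u and transport v to the new point; returns both."""
    y = retr(x, u)
    return y, proju(y, v)
```

On the sphere `egrad2rgrad` and `proju` coincide because the metric is the one induced from `R^n`; for manifolds with a different metric the two generally differ.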

Manifolds

- `geoopt.Euclidean` – unconstrained manifold in `R` with the Euclidean metric
- `geoopt.Stiefel` – Stiefel manifold of matrices `A in R^{n x p} : A^t A=I`, `n >= p`
- `geoopt.Sphere` – sphere manifold `||x||=1`
- `geoopt.BirkhoffPolytope` – manifold of doubly stochastic matrices
- `geoopt.Stereographic` – constant curvature stereographic projection model
- `geoopt.SphereProjection` – sphere stereographic projection model
- `geoopt.PoincareBall` – Poincaré ball model
- `geoopt.Lorentz` – hyperboloid model
- `geoopt.ProductManifold` – product manifold constructor
- `geoopt.Scaled` – scaled version of a manifold. Combined with `ProductManifold`, this is similar to “Learning Mixed-Curvature Representations in Product Spaces”.
- `geoopt.SymmetricPositiveDefinite` – SPD matrix manifold
- `geoopt.UpperHalf` – Siegel upper half-space manifold. Supports Riemannian and Finsler metrics, as in “Symmetric Spaces for Graph Embeddings: A Finsler-Riemannian Approach”.
- `geoopt.BoundedDomain` – Siegel bounded domain manifold. Supports Riemannian and Finsler metrics.
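Several of the manifolds above are defined by simple constraints. The sketch below checks membership for the sphere, the Stiefel manifold and doubly stochastic matrices in plain Python, with matrices as lists of rows. The helper names are hypothetical (not geoopt API), and requiring strictly positive entries for doubly stochastic matrices is an assumption of this sketch.

```python
import math

# Membership checks for three constraint-defined manifolds listed above.


def on_sphere(x, tol=1e-9):
    """Sphere: ||x|| = 1."""
    return abs(math.sqrt(sum(v * v for v in x)) - 1.0) < tol


def on_stiefel(a, tol=1e-9):
    """Stiefel: A in R^{n x p} with n >= p and A^T A = I."""
    n, p = len(a), len(a[0])
    if n < p:
        return False
    for i in range(p):
        for j in range(p):
            gram = sum(a[k][i] * a[k][j] for k in range(n))  # (A^T A)_{ij}
            if abs(gram - (1.0 if i == j else 0.0)) > tol:
                return False
    return True


def is_doubly_stochastic(m, tol=1e-9):
    """Square matrix with positive entries whose rows and columns sum to 1.

    Strict positivity is an assumption of this sketch (the smooth manifold of
    doubly stochastic matrices is usually taken over positive entries).
    """
    n = len(m)
    if any(len(row) != n for row in m):
        return False
    if any(v <= 0.0 for row in m for v in row):
        return False
    rows_ok = all(abs(sum(row) - 1.0) < tol for row in m)
    cols_ok = all(abs(sum(m[i][j] for i in range(n)) - 1.0) < tol for j in range(n))
    return rows_ok and cols_ok
```

For example, the first two columns of a 3x3 rotation matrix form a point on the Stiefel manifold with `n = 3`, `p = 2`.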

All manifolds implement the methods needed to manipulate points on the manifold and tangent vectors, for general-purpose use. See the documentation for more details.
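As an illustration of the closed-form operations a manifold can provide, here is a plain-Python sketch of the exponential map at the origin and the geodesic distance on the Poincaré ball with curvature -1, using the standard formulas. geoopt's `PoincareBall` takes a curvature parameter and its method signatures may differ, so the function names `expmap0` and `dist` here are assumptions of this sketch.

```python
import math

# Poincare ball with curvature -1; points are lists inside the open unit ball.


def _norm(v):
    return math.sqrt(sum(t * t for t in v))


def expmap0(u):
    """Exponential map at the origin: tanh(||u||) * u / ||u||."""
    n = _norm(u)
    if n == 0.0:
        return [0.0 for _ in u]
    return [math.tanh(n) * t / n for t in u]


def dist(x, y):
    """Geodesic distance arcosh(1 + 2||x-y||^2 / ((1-||x||^2)(1-||y||^2)))."""
    diff2 = sum((a - b) ** 2 for a, b in zip(x, y))
    denom = (1.0 - _norm(x) ** 2) * (1.0 - _norm(y) ** 2)
    return math.acosh(1.0 + 2.0 * diff2 / denom)
```

A quick consistency check: `dist([0, 0], expmap0(u))` equals `2 * ||u||`, where the factor 2 is the conformal factor of the ball metric at the origin.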

Optimizers

- `geoopt.optim.RiemannianSGD` – a subclass of `torch.optim.SGD` with the same API
- `geoopt.optim.RiemannianAdam` – a subclass of `torch.optim.Adam`
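Conceptually, a Riemannian SGD step converts the Euclidean gradient into a Riemannian gradient and then moves along the manifold with a retraction. The plain-Python sketch below shows that update on the unit sphere for the toy objective `f(x) = -<a, x>`, whose minimizer is `a / ||a||`. The objective, step size and iteration count are arbitrary choices for illustration; this is not geoopt's optimizer code.

```python
import math

# Riemannian SGD on the unit sphere:
# Euclidean gradient -> Riemannian gradient (tangent projection) -> retraction.


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def riemannian_sgd_step(x, egrad, lr):
    # egrad2rgrad for the sphere: drop the component of egrad normal to x
    c = dot(x, egrad)
    rgrad = [g - c * xi for g, xi in zip(egrad, x)]
    # move against the Riemannian gradient, then retract by normalizing
    y = [xi - lr * g for xi, g in zip(x, rgrad)]
    n = math.sqrt(dot(y, y))
    return [yi / n for yi in y]


# Minimize f(x) = -<a, x> over the sphere; the minimizer is a / ||a||.
a = [3.0, 0.0, 4.0]
x = [1.0 / math.sqrt(3.0)] * 3
for _ in range(200):
    egrad = [-ai for ai in a]  # Euclidean gradient of f
    x = riemannian_sgd_step(x, egrad, lr=0.1)
```

With geoopt you would not write this loop by hand: wrap the point in a `geoopt.ManifoldParameter` with `manifold=geoopt.Sphere()` and pass it to `geoopt.optim.RiemannianSGD`, which keeps the `torch.optim.SGD` API.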

Samplers

- `geoopt.samplers.RSGLD` – Riemannian Stochastic Gradient Langevin Dynamics
- `geoopt.samplers.RHMC` – Riemannian Hamiltonian Monte-Carlo
- `geoopt.samplers.SGRHMC` – Stochastic Gradient Riemannian Hamiltonian Monte-Carlo
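Samplers draw from a distribution on the manifold instead of optimizing. The plain-Python sketch below shows a simplified projected Langevin step on the unit sphere: the drift `-eps/2 * grad U` and Gaussian noise of variance `eps` are projected onto the tangent space, and the result is retracted by normalization. It is an illustration in the spirit of `RSGLD`, not geoopt's implementation; the target (a von Mises-Fisher density with concentration `kappa` around `mu`), the step size and the burn-in length are arbitrary assumptions.

```python
import math
import random

# Simplified projected Langevin dynamics on the unit sphere (illustration only).


def langevin_step(x, grad_u, eps, rng):
    noise = [rng.gauss(0.0, 1.0) for _ in x]
    v = [-0.5 * eps * g + math.sqrt(eps) * z for g, z in zip(grad_u, noise)]
    c = sum(vi * xi for vi, xi in zip(v, x))
    v = [vi - c * xi for vi, xi in zip(v, x)]  # project onto the tangent space at x
    y = [xi + vi for xi, vi in zip(x, v)]
    n = math.sqrt(sum(t * t for t in y))
    return [t / n for t in y]  # retraction by normalization


# Target density proportional to exp(kappa * <mu, x>), i.e. U(x) = -kappa * <mu, x>.
kappa, mu = 20.0, [0.0, 0.0, 1.0]
rng = random.Random(0)
x = [1.0, 0.0, 0.0]
samples = []
for t in range(3000):
    x = langevin_step(x, [-kappa * m for m in mu], eps=0.01, rng=rng)
    if t >= 500:  # discard burn-in
        samples.append(x)
mean_cos = sum(s[2] for s in samples) / len(samples)  # average of <mu, x>
```

With a concentration this large, the kept samples cluster around `mu`, so `mean_cos` is close to 1.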

Citing Geoopt

If you find this project useful in your research, please add this BibTeX entry to your references and cite it.

```
@misc{geoopt2020kochurov,
title={Geoopt: Riemannian Optimization in PyTorch},
author={Max Kochurov and Rasul Karimov and Serge Kozlukov},
year={2020},
eprint={2005.02819},
archivePrefix={arXiv},
primaryClass={cs.CG}
}
```