PyTorch-SSO: Scalable Second-Order methods in PyTorch

Primary language: Python. License: MIT.

PyTorch-SSO (alpha release)

Scalable Second-Order methods in PyTorch.

  • Open-source library for second-order optimization and Bayesian inference.

  • An earlier iteration of this library (chainerkfac) holds the world record for large-batch training of ResNet-50 on ImageNet by Kronecker-Factored Approximate Curvature (K-FAC), scaling to batch sizes of 131K.

    • Kazuki Osawa et al., “Large-Scale Distributed Second-Order Optimization Using Kronecker-Factored Approximate Curvature for Deep Convolutional Neural Networks”, IEEE/CVF CVPR 2019.
    • [paper] [poster]
  • This library is the basis for the Natural Gradient for Bayesian inference (Variational Inference) on ImageNet.

    • Kazuki Osawa et al., “Practical Deep Learning with Bayesian Principles”, NeurIPS 2019.
    • [paper (preprint)]

Scalable Second-Order Optimization

Optimizers

PyTorch-SSO provides the following optimizers.

  • Second-Order Optimization
    • torchsso.optim.SecondOrderOptimizer [source]
    • updates the parameters with the gradients pre-conditioned by the curvature of the loss function (torch.nn.functional.cross_entropy) for each param_group.
  • Variational Inference (VI)
    • torchsso.optim.VIOptimizer [source]
    • updates the posterior distribution (mean, covariance) of the parameters by using the curvature for each param_group.
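
As a rough illustration of what "gradients pre-conditioned by the curvature" means, the following NumPy sketch applies one preconditioned step to a toy quadratic loss. It does not use the torchsso API: the function name `preconditioned_step` and the `lr` and `damping` parameters are assumptions made only for this example.

```python
import numpy as np

def preconditioned_step(theta, grad, curvature, lr=1.0, damping=0.0):
    """Hypothetical helper: theta - lr * (curvature + damping * I)^{-1} grad."""
    damped = curvature + damping * np.eye(curvature.shape[0])
    # Solve a linear system instead of forming the inverse explicitly.
    return theta - lr * np.linalg.solve(damped, grad)

# Toy quadratic loss 0.5 * theta^T H theta - b^T theta, minimised at H^{-1} b.
H = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([1.0, -1.0])

theta0 = np.zeros(2)
grad0 = H @ theta0 - b          # gradient of the quadratic at theta0
theta1 = preconditioned_step(theta0, grad0, H)
```

With the exact curvature, no damping, and a unit step size, a single step lands on the minimiser of the quadratic. A plain gradient step with the same step size would not.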

Curvatures

You can choose which information matrix is used as the curvature from the following:

  • Hessian [WIP]

  • Fisher information matrix

  • Covariance matrix (empirical Fisher)

Refer to Information matrices and generalization by Valentin Thomas et al. (2019) for the definitions and properties of these information matrices.

Refer to Section 6 of Optimization Methods for Large-Scale Machine Learning by Léon Bottou et al. (2018) for a clear explanation of second-order optimization using these matrices as the curvature.
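
To make the difference between the Fisher information matrix and the empirical Fisher concrete, here is a minimal NumPy sketch for logistic regression. It is independent of torchsso. The data, the variable names, and the choice to compute the empirical Fisher as the uncentered second moment of the per-example gradients are assumptions of this example.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))                    # inputs, one row per example
y = rng.integers(0, 2, size=200).astype(float)   # observed binary labels
w = rng.normal(size=3)                           # current parameters

p = 1.0 / (1.0 + np.exp(-X @ w))                 # model probabilities p(y=1 | x)

# Fisher: labels are drawn from the model itself, y ~ Bernoulli(p),
# so E[(p - y)^2] = p * (1 - p) and no observed label is needed.
fisher = (X * (p * (1.0 - p))[:, None]).T @ X / len(X)

# Empirical Fisher: second moment of the per-example gradients of the
# negative log-likelihood, evaluated at the observed labels.
per_example_grads = (p - y)[:, None] * X
emp_fisher = per_example_grads.T @ per_example_grads / len(X)
```

Both matrices are symmetric and positive semi-definite. For this particular model the Fisher coincides with the Hessian of the negative log-likelihood, while the empirical Fisher generally differs from both.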

Approximation Methods

PyTorch-SSO calculates the curvature as a layer-wise block-diagonal matrix.

You can specify the approximation method for the curvature in each layer from the following:

  1. Full (No approximation)
  2. Diagonal approximation
  3. Kronecker-Factored Approximate Curvature (K-FAC)
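
The three approximations can be compared on a single linear layer with a short NumPy sketch. This is an illustration under stated assumptions, not the library's implementation: the inputs `a` and output gradients `g` are random stand-ins, the curvature is an empirical second moment of per-example gradients, and the small damping `eps` added to each Kronecker factor is a choice made for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
n, n_in, n_out = 64, 4, 3
a = rng.normal(size=(n, n_in))    # layer inputs, one row per example
g = rng.normal(size=(n, n_out))   # gradients w.r.t. the layer outputs

# Per-example weight gradients g a^T, flattened row-major to length n_out * n_in.
per_example = np.einsum("bi,bj->bij", g, a).reshape(n, -1)
grad_w = per_example.mean(axis=0).reshape(n_out, n_in)

# 1. Full: dense (n_out * n_in) x (n_out * n_in) block for this layer.
full = per_example.T @ per_example / n

# 2. Diagonal: keep only the diagonal entries of the full block.
diag = np.mean(per_example ** 2, axis=0)

# 3. K-FAC: approximate the block by a Kronecker product of two small factors.
eps = 1e-3                                   # damping (assumption of this sketch)
A = a.T @ a / n + eps * np.eye(n_in)         # input covariance factor
G = g.T @ g / n + eps * np.eye(n_out)        # output-gradient covariance factor
kfac = np.kron(G, A)                         # matches the row-major flattening

# Preconditioning with K-FAC never needs the large matrix:
# (G kron A)^{-1} vec(grad_w) equals vec(G^{-1} grad_w A^{-1}).
kfac_direct = np.linalg.solve(kfac, grad_w.reshape(-1)).reshape(n_out, n_in)
kfac_factored = np.linalg.solve(G, grad_w) @ np.linalg.inv(A)
```

The factored form only inverts an `n_out x n_out` and an `n_in x n_in` matrix, which is what makes K-FAC practical for wide layers where the full block would be too large to store or invert.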

PyTorch-SSO currently supports the following layers (Modules) in PyTorch:

Layer (Module) | Full | Diagonal | K-FAC
torch.nn.Linear | ✔️ | ✔️ | ✔️
torch.nn.Conv2d | - | ✔️ | ✔️
torch.nn.BatchNorm1d/2d | - | ✔️ | -

To apply PyTorch-SSO:

  • Set requires_grad to True for each Module.
  • The network you define cannot contain any modules other than the supported ones listed above.
  • E.g., you need to use torch.nn.functional.relu/max_pool2d instead of torch.nn.ReLU/MaxPool2d to define a ConvNet.
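
A network that respects these constraints might look like the following sketch. The class name `ConvNet`, the layer sizes, and the 32x32 RGB input shape are assumptions chosen for illustration; only standard PyTorch calls are used, and nothing from torchsso is invoked.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvNet(nn.Module):
    """Hypothetical small ConvNet using only Conv2d, BatchNorm2d and Linear modules."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(16)
        self.fc = nn.Linear(16 * 8 * 8, num_classes)

    def forward(self, x):
        # Activations and pooling are applied functionally, so they do not
        # register as extra submodules of the network.
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2)
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        return self.fc(torch.flatten(x, 1))

model = ConvNet()

# Sanity check: list any submodule that is not one of the supported layer types.
supported = (nn.Linear, nn.Conv2d, nn.BatchNorm1d, nn.BatchNorm2d)
unsupported = [name for name, m in model.named_modules()
               if name and not isinstance(m, supported)]
```

Here `unsupported` is empty, and every parameter has `requires_grad` set to True, which is the PyTorch default for module parameters.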

Distributed Training

PyTorch-SSO supports data parallelism and MC samples parallelism (for VI) for distributed training among multiple processes (GPUs).
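
The idea behind data parallelism can be sketched in a single process with NumPy. Each simulated worker computes a gradient and a curvature estimate on its own shard, and the results are averaged. The function `all_reduce_mean` is a stand-in written for this example and performs no real communication. It says nothing about how the library itself exchanges data between processes.

```python
import numpy as np

rng = np.random.default_rng(0)
num_workers, per_worker, dim = 4, 16, 5
grads = rng.normal(size=(num_workers * per_worker, dim))  # per-example gradients

# Each simulated worker sees an equally sized shard of the mini-batch.
shards = np.split(grads, num_workers)
local_grads = [s.mean(axis=0) for s in shards]
local_curvs = [s.T @ s / len(s) for s in shards]

def all_reduce_mean(values):
    """Stand-in for a collective average across workers (no real communication)."""
    return sum(values) / len(values)

global_grad = all_reduce_mean(local_grads)
global_curv = all_reduce_mean(local_curvs)
```

Because the shards have equal size, the averaged quantities match what a single process would compute on the whole mini-batch.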

Installation

To build PyTorch-SSO, run the following in a Python 3 environment:

git clone git@github.com:cybertronai/pytorch-sso.git
cd pytorch-sso
python setup.py install

To use the library:

import torchsso

Additional requirements

PyTorch-SSO depends on CuPy for fast GPU computation and ChainerMN for communication. To use GPUs, you need to install the following requirements before installing PyTorch-SSO.

Running environment | Requirements
single GPU | CuPy
multiple GPUs | CuPy with NCCL, mpi4py

Refer to the CuPy installation guide and the ChainerMN installation guide for details.

Examples

Authors

Kazuki Osawa (@kazukiosawa) and Yaroslav Bulatov (@yaroslavvb)