ivon-optax

An Optax-based JAX implementation of the IVON optimizer for large-scale variational inference (VI) training of neural networks (ICML'24 spotlight)

Primary language: Python. License: GNU General Public License v3.0 (GPL-3.0).

A JAX implementation of IVON

This repo contains a JAX implementation of the IVON optimizer from our paper:

Variational Learning is Effective for Large Deep Networks
Y. Shen*, N. Daheim*, B. Cong, P. Nickl, G.M. Marconi, C. Bazan, R. Yokota, I. Gurevych, D. Cremers, M.E. Khan, T. Möllenhoff
International Conference on Machine Learning (ICML), 2024 (spotlight)

ArXiv: https://arxiv.org/abs/2402.17641
Blog: https://team-approx-bayes.github.io/blog/ivon/
Tutorial: https://ysngshn.github.io/research/why-ivon/

The JAX implementation of IVON is self-contained in a single file, ivon.py. It builds on components from the optax library and can be jax.jit-compiled for maximum efficiency.

We also provide an official PyTorch implementation of the IVON optimizer. The experiments in our paper were obtained with the PyTorch implementation, and its source code can be found here.

Quickstart

Definition

The IVON optimizer can be constructed via ivon.ivon() and initialized with its .init() method.

optimizer = ivon.ivon(
    lr,            # learning rate
    ess,           # effective sample size of the variational posterior
    hess_init,     # initial value of the Hessian estimate
    beta1,         # momentum for the gradient estimate
    beta2,         # momentum for the Hessian estimate
    weight_decay,  # weight decay coefficient
    # ...
)
optstate = optimizer.init(params)

Appendix A of our paper provides practical guidelines for choosing IVON hyperparameters.
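
For illustration, here is a hypothetical instantiation with placeholder values; these are not tuned recommendations, so consult Appendix A before choosing your own.

import jax.numpy as jnp
import ivon

params = {"w": jnp.zeros((10,)), "b": jnp.zeros(())}  # toy parameter pytree

optimizer = ivon.ivon(
    0.1,       # lr (illustrative value)
    50_000.0,  # ess, often set on the order of the training-set size
    0.1,       # hess_init (illustrative value)
    0.9,       # beta1
    0.99999,   # beta2
    2e-4,      # weight_decay (illustrative value)
)
optstate = optimizer.init(params)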

Usage

Similar to the usual Optax interface, the ivon.IVON optimizer instance returned by the constructor ivon.ivon() has an .init() method for initializing the optimizer state and a .step() method for computing the parameter updates.

Additionally, ivon.IVON provides two methods that support its stochastic variational inference updates: .sampled_params() draws a sample from the weight posterior and also returns the generated noise; this noise is then consumed by .accumulate(), which estimates and accumulates the expected gradient and Hessian for the current step.

A typical training step can be carried out as follows:

train_mcsamples = 1  # 1 sample already works well; more samples are even better
rngkey, *mc_keys = jax.random.split(rngkey, train_mcsamples+1)
# Stochastic natural gradient VI with IVON
for key in mc_keys:
    # draw IVON weight posterior sample
    psample, noise = optimizer.sampled_params(
        key, params, optstate
    )
    # get gradient for this MC sample from feed-forward + backprop
    mc_updates = ff_backprop(psample, inputs, target, ...)
    # accumulate for current sample
    optstate = optimizer.accumulate(mc_updates, optstate, noise)
# compute updates with accumulated estimations
updates, optstate = optimizer.step(optstate, params)
params = optax.apply_updates(params, updates)
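
As noted above, the whole training step is pure and can be wrapped in jax.jit. Below is a minimal single-sample sketch under the assumption of a hypothetical loss_fn(params, inputs, targets); everything else uses only the methods documented above.

import jax
import optax

def make_train_step(optimizer, loss_fn):
    @jax.jit
    def train_step(rngkey, params, optstate, inputs, targets):
        # draw one IVON weight posterior sample and its generating noise
        psample, noise = optimizer.sampled_params(rngkey, params, optstate)
        # gradient of the loss evaluated at the sampled weights
        mc_updates = jax.grad(loss_fn)(psample, inputs, targets)
        # accumulate gradient/Hessian estimates for this sample
        optstate = optimizer.accumulate(mc_updates, optstate, noise)
        # compute and apply the IVON update
        updates, optstate = optimizer.step(optstate, params)
        params = optax.apply_updates(params, updates)
        return params, optstate
    return train_step

Remember to split rngkey outside the step and pass a fresh subkey to each call.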

Installation and dependencies

Simply copy the ivon.py file to your project folder and import it.

This file only depends on jax and optax; see their documentation for installation instructions.
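
For a CPU-only setup, something like the following usually suffices (see the JAX documentation for GPU/TPU builds):

pip install -U jax optax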

We also provide an example that has additional dependencies, see below.

Example: Training ResNet-18 on CIFAR-10

The folder resnet-cifar contains an example showing how the IVON optimizer can tackle image classification tasks. This code base is adapted from the Bayesian-SAM repo with permission from the author. Check out its package dependencies here.

Run training

Go to the resnet-cifar folder and run the following commands:

SGD

python train.py --alpha 0.03 --beta1 0.9 --priorprec 25 --optim sgd --dataset cifar10

This should train to around 94.8% test accuracy.

IVON

python train.py --alpha 0.5 --beta1 0.9 --beta2 0.99999 --priorprec 25 --optim ivon --hess-init 0.5 --dataset cifar10

This should train to around 95.2% test accuracy.

Run test

In resnet-cifar, run the following commands, adapting --resultsfolder to the folder that stores the training results.

SGD

python test.py --resultsfolder /results/cifar10_resnet18/sgd/run_0/

IVON

python test.py --resultsfolder /results/cifar10_resnet18/ivon/run_0/ --testmc 64

The tests should show that IVON with 64-sample Bayesian prediction achieves significantly better uncertainty metrics than the SGD baseline.
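
For reference, multi-sample Bayesian prediction simply averages the predictive distribution over weight posterior samples. A minimal sketch, assuming a hypothetical apply_fn(params, inputs) that returns logits:

import jax
import jax.numpy as jnp

def bayesian_predict(key, optimizer, apply_fn, params, optstate, inputs, n_samples=64):
    # average predictive class probabilities over posterior samples
    probs = 0.0
    for k in jax.random.split(key, n_samples):
        # the noise output is only needed during training
        psample, _ = optimizer.sampled_params(k, params, optstate)
        probs = probs + jax.nn.softmax(apply_fn(psample, inputs), axis=-1)
    return probs / n_samples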