S-FSVI

Code for the paper 'Continual Learning via Sequential Function-Space Variational Inference'

Primary language: Jupyter Notebook. License: MIT.

Continual Learning via Sequential Function-Space Variational Inference (S-FSVI)

This repository contains the official implementation for

Continual Learning via Sequential Function-Space Variational Inference; Tim G. J. Rudner, Freddie Bickford Smith, Qixuan Feng, Yee Whye Teh, Yarin Gal. ICML 2022.

Abstract: Sequential Bayesian inference over predictive functions is a natural framework for continual learning from streams of data. However, applying it to neural networks has proved challenging in practice. Addressing the drawbacks of existing techniques, we propose an optimization objective derived by formulating continual learning as sequential function-space variational inference. In contrast to existing methods that regularize neural network parameters directly, this objective allows parameters to vary widely during training, enabling better adaptation to new tasks. Compared to objectives that directly regularize neural network predictions, the proposed objective allows for more flexible variational distributions and more effective regularization. We demonstrate that, across a range of task sequences, neural networks trained via sequential function-space variational inference achieve better predictive accuracy than networks trained with related methods while depending less on maintaining a set of representative points from previous tasks.
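In symbols, here is a hedged sketch in our own notation rather than a transcription of the paper. For task t with data (X_t, y_t), S-FSVI fits a variational distribution q_t(θ) by maximizing the expected log-likelihood of the new data minus a KL divergence between the distributions over function values that q_t and the previous approximate posterior q_{t-1} induce at a set of context points X_C:

```latex
\mathcal{F}(q_t) = \mathbb{E}_{q_t(\theta)}\left[\log p\big(\mathbf{y}_t \mid f(\mathbf{X}_t; \theta)\big)\right]
  - \mathrm{KL}\left(q_t\big(f(\mathbf{X}_C)\big) \,\middle\|\, q_{t-1}\big(f(\mathbf{X}_C)\big)\right)
```

The KL term compares predictions rather than parameters, so θ can move freely as long as the predictions at X_C stay close to those of the previous posterior. This is the contrast with parameter-space regularization that the abstract draws.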


In particular, this codebase includes:

  • An implementation of the sequential function-space variational objective [1];
  • Notebooks that reproduce the results in the paper;
  • A general, easy-to-extend continual-learning training and evaluation protocol; and
  • A set of framework-agnostic data loaders for widely used continual-learning tasks.

[1] The implementation is based on the approximation proposed in "Tractable Function-Space Variational Inference in Bayesian Neural Networks" (Rudner et al., 2022).
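The approximation referenced in [1] makes the function-space KL tractable by linearizing the network around its mean parameters. Under a Gaussian q(θ) = N(μ, diag(s)), the outputs at the context points are then approximately Gaussian with mean f(X_C; μ) and covariance J diag(s) Jᵀ, where J is the Jacobian of the outputs with respect to θ at μ. Two Gaussians like this have a closed-form KL. The NumPy sketch below illustrates both steps. The function names, the mean-field (diagonal) covariance, and the jitter value are our own assumptions for illustration and are not this repository's API.

```python
import numpy as np


def linearized_function_covariance(jacobian, param_variance, jitter=1e-6):
    """Approximate Cov[f(X_C)] when q(theta) = N(mu, diag(param_variance)) and f is
    linearized around mu: Cov[f(X_C)] ~= J diag(param_variance) J^T.

    jacobian:       (n_values, n_params) Jacobian of the flattened outputs at X_C, taken at mu.
    param_variance: (n_params,) diagonal of the parameter covariance (mean-field assumption).
    """
    cov = (jacobian * param_variance) @ jacobian.T  # scales column j of J by variance j
    # A small diagonal jitter keeps the matrix numerically positive definite.
    return cov + jitter * np.eye(cov.shape[0])


def gaussian_kl(mean_q, cov_q, mean_p, cov_p):
    """Closed-form KL(N(mean_q, cov_q) || N(mean_p, cov_p))."""
    k = mean_q.shape[0]
    trace_term = np.trace(np.linalg.solve(cov_p, cov_q))  # tr(cov_p^{-1} cov_q)
    diff = mean_p - mean_q
    mahalanobis = diff @ np.linalg.solve(cov_p, diff)
    _, logdet_p = np.linalg.slogdet(cov_p)
    _, logdet_q = np.linalg.slogdet(cov_q)
    return 0.5 * (trace_term + mahalanobis - k + logdet_p - logdet_q)


# Hypothetical usage: 4 function values (e.g. 2 context points x 2 outputs), 10 parameters.
rng = np.random.default_rng(0)
J_t, J_prev = rng.normal(size=(4, 10)), rng.normal(size=(4, 10))  # Jacobians at mu_t, mu_{t-1}
f_t, f_prev = rng.normal(size=4), rng.normal(size=4)  # stand-ins for f(X_C; mu_t), f(X_C; mu_{t-1})
kl = gaussian_kl(
    f_t, linearized_function_covariance(J_t, np.full(10, 0.1)),
    f_prev, linearized_function_covariance(J_prev, np.full(10, 0.2)),
)
```

In a real model the Jacobians would come from automatic differentiation of the network, and the covariance is only as good as the linearization around μ.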



Figure 1. Schematic of sequential function-space variational inference.

Installation

To install requirements:

$ conda env update -f environment.yml
$ conda activate fsvi

This environment includes all necessary dependencies.

To install the package and create an fsvi executable for running experiments, run:

$ pip install -e .

Reproducing results

Split MNIST, Permuted MNIST, and Split FashionMNIST

| Method | Split MNIST (MH) | Split FashionMNIST (MH) | Permuted MNIST (SH) | Split MNIST (SH) |
| --- | --- | --- | --- | --- |
| S-FSVI (ours) | 99.54% ± 0.04 | 99.05% ± 0.03 | 95.76% ± 0.02 | 92.87% ± 0.14 |
| S-FSVI (larger networks) | 99.76% ± 0.00 | 98.50% ± 0.11 | 97.50% ± 0.01 | 93.38% ± 0.10 |
| S-FSVI (no coreset) | 99.62% ± 0.01 | 99.17% ± 0.06 | 84.06% ± 0.46 | 20.15% ± 0.52 |
| S-FSVI (minimal coreset [2]) | NA [3] | NA [3] | 89.59% ± 0.30 | 51.44% ± 1.22 |

MH = multi-head; SH = single-head. Each benchmark has an accompanying Colab notebook.

[2] "Minimal coresets" are constructed by randomly selecting one data point per class for a given task.

[3] Since S-FSVI already performs well in the multi-head settings without any coreset, the minimal-coreset variant was not evaluated there.
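Footnote [2] defines a minimal coreset as one randomly selected data point per class for a given task. The plain-Python sketch below shows that selection rule and how such points could be accumulated as tasks arrive. The function name, the seeding scheme, and the per-task accumulation are illustrative assumptions. The README does not specify how the repository implements the sampling.

```python
import random
from collections import defaultdict


def minimal_coreset(labels, seed=0):
    """Return {class_label: index}: one randomly chosen training index per class.

    labels: integer class labels of one task's training set (hypothetical helper).
    """
    rng = random.Random(seed)
    indices_by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        indices_by_class[label].append(idx)
    # Draw exactly one index per class present in this task.
    return {label: rng.choice(idxs) for label, idxs in sorted(indices_by_class.items())}


# Hypothetical usage: two tasks with two classes each. After each task, one point per
# class is added, so the coreset holds one point per (task, class) pair.
task_labels = [[0, 1, 1, 0, 1], [2, 3, 3, 2, 2]]
coreset = {}
for task_id, labels in enumerate(task_labels):
    for label, idx in minimal_coreset(labels, seed=task_id).items():
        coreset[(task_id, label)] = idx
```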

Split CIFAR

| Method | Split CIFAR (MH) |
| --- | --- |
| S-FSVI [4] | 77.57% ± 0.84 |

This benchmark has an accompanying Colab notebook.

Sequential Omniglot

| Method | Sequential Omniglot (MH) |
| --- | --- |
| S-FSVI [4] | 83.29% ± 1.2 |

This benchmark has an accompanying Colab notebook.

[4] To speed up training and reduce the memory requirements, only the variance parameters in the final layer of the network are learned variationally and the linearization is computed on the final layer only.
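Footnote [4] describes why final-layer-only inference is cheap. If the last layer computes f(x) = φ(x)ᵀW + b from deterministic features φ(x), and only the entries of W are Gaussian, then f is already linear in the random parameters. Linearizing "on the final layer only" is therefore exact for those parameters, and each output's covariance over a batch of points is Φ diag(s_k) Φᵀ, computed from the features alone. No Jacobian through the full network is needed. The NumPy sketch below assumes a dense final layer, independent Gaussian weights, and a deterministic bias. These assumptions and the function name are ours and not taken from the repository.

```python
import numpy as np


def final_layer_function_distribution(features, weight_mean, weight_var, bias):
    """Mean and per-output covariance of f(X) = features @ W + bias, where the entries
    of W are independent Gaussians N(weight_mean, weight_var) and everything else is fixed.

    features:    (n_points, n_features) outputs of the deterministic feature extractor.
    weight_mean: (n_features, n_outputs) variational means of the final-layer weights.
    weight_var:  (n_features, n_outputs) variational variances of the final-layer weights.
    bias:        (n_outputs,) deterministic bias (an assumption of this sketch).
    Returns mean (n_points, n_outputs) and cov (n_outputs, n_points, n_points).
    """
    mean = features @ weight_mean + bias
    # Output k depends only on column k of W, so outputs are mutually independent and
    # cov[k] = Phi diag(weight_var[:, k]) Phi^T (exact, since f is linear in W).
    cov = np.einsum("nf,fk,mf->knm", features, weight_var, features)
    return mean, cov


# Hypothetical usage: 2 points, 2 features, 2 outputs.
features = np.array([[1.0, 0.0], [0.0, 2.0]])
mean, cov = final_layer_function_distribution(
    features, np.eye(2), np.array([[0.5, 1.0], [2.0, 3.0]]), np.zeros(2)
)
```

The cost now scales with the feature dimension of the last layer rather than with the total number of network parameters, which is consistent with the stated goal of saving time and memory.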

2D Visualization

This notebook, available in Colab, demonstrates continual learning via S-FSVI on a sequence of five binary-classification tasks in a 2D input space.

Figure 2. Predictive distributions of a model trained via S-FSVI on tasks 1-5.

Adding new methods or tasks

  • To implement a new method, create a file method_cl_methodname.py in /benchmarking. For reference, see /benchmarking/method_cl_template.py and /benchmarking/method_cl_fsvi.py.
  • To implement a new dataloader, add a new method to benchmarking/data_loaders.
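The README does not document the interface of the loaders in benchmarking/data_loaders. As a framework-agnostic illustration, the plain-Python sketch below shows how common task sequences are constructed:

- Split-style tasks, built from disjoint groups of classes.
- Permuted-style tasks, each using one fixed pixel permutation.
- Task-local labels of the kind commonly used with a per-task output head (MH).
- A final average accuracy over all tasks.

Every function name here is hypothetical and is not the repository's API.

```python
import random


def split_tasks(labels, classes_per_task=2):
    """Group example indices into tasks of consecutive classes,
    e.g. (0, 1), (2, 3), ... for a Split MNIST-style sequence.
    Returns a list of (task_classes, indices) pairs.
    """
    classes = sorted(set(labels))
    tasks = []
    for start in range(0, len(classes), classes_per_task):
        task_classes = classes[start:start + classes_per_task]
        indices = [i for i, y in enumerate(labels) if y in task_classes]
        tasks.append((task_classes, indices))
    return tasks


def to_task_local_labels(labels, indices, task_classes):
    """Remap global labels to 0..len(task_classes)-1 for a per-task output head."""
    local = {c: i for i, c in enumerate(task_classes)}
    return [local[labels[i]] for i in indices]


def permutation_tasks(n_pixels, n_tasks, seed=0):
    """One fixed random pixel permutation per task (Permuted MNIST-style)."""
    rng = random.Random(seed)
    perms = []
    for _ in range(n_tasks):
        perm = list(range(n_pixels))
        rng.shuffle(perm)
        perms.append(perm)
    return perms


def permute(flat_image, perm):
    """Apply a task's permutation to a flattened image."""
    return [flat_image[p] for p in perm]


def average_accuracy(accuracy_matrix):
    """accuracy_matrix[i][j]: accuracy on task j after training through task i.
    Returns the mean accuracy over all tasks after training on the final task.
    """
    final_row = accuracy_matrix[-1]
    return sum(final_row) / len(final_row)


# Hypothetical usage on ten examples, one per digit class.
labels = list(range(10))
tasks = split_tasks(labels)
local_labels = to_task_local_labels(labels, tasks[1][1], tasks[1][0])
perms = permutation_tasks(n_pixels=4, n_tasks=3)
```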

Citation

@InProceedings{rudner2022continual,
  author    = {Tim G. J. Rudner and Freddie Bickford Smith and Qixuan Feng and Yee Whye Teh and Yarin Gal},
  title     = {{C}ontinual {L}earning via {S}equential {F}unction-{S}pace {V}ariational {I}nference},
  booktitle = {Proceedings of the 39th International Conference on Machine Learning},
  year      = {2022},
  series    = {Proceedings of Machine Learning Research},
  publisher = {PMLR},
}

Please cite our paper if you use this code in your own work.