Pre-Train Your Loss! High-Performance Transfer Learning with Bayesian Neural Networks and Pre-Trained Priors

This repository contains an easy-to-use PyTorch implementation of methods described in Pre-Train Your Loss! High-Performance Transfer Learning with Bayesian Neural Networks and Pre-Trained Priors by Ravid Shwartz-Ziv, Micah Goldblum, Hossein Souri, Sanyam Kapoor, Chen Zhu, Yann LeCun, and Andrew Gordon Wilson.

Summary

Idea: We can transfer much more than an initialization. Knowledge of the source task should affect the locations and shape of optima on the downstream task.

Approach: Infer a posterior on the source task, then rescale it to serve as an informative prior on the downstream task.

Results: Significantly improved performance over standard transfer learning and fine-tuning, with minimal overhead.

Overview

Our Bayesian transfer learning framework transfers knowledge from pre-training to downstream tasks. To up-weight parameter settings consistent with the pre-training loss, we fit a probability distribution over the parameters of feature extractors to the pre-training loss function and rescale it to serve as a prior. A learned prior alters the downstream loss surface and the locations of its optima. By contrast, typical transfer learning methods use the pre-trained model only as an initialization.
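To make the effect on the loss surface concrete, here is a minimal, stdlib-only Python sketch; it is not the repository's implementation. It assumes a diagonal Gaussian prior N(mean, prior_scale * variance), a simplification of the learned prior, and the helpers `gaussian_prior_penalty` and `map_estimate_1d` are hypothetical. In one dimension with a quadratic downstream loss, the minimizer of loss plus prior penalty is a precision-weighted average of the downstream optimum and the prior mean, so moving the prior mean away from zero moves the optimum.

```python
def gaussian_prior_penalty(weights, prior_mean, prior_var, prior_scale=1.0):
    """Negative log-density, up to an additive constant, of a diagonal Gaussian
    prior N(prior_mean, prior_scale * prior_var), summed over all parameters.
    Adding this term to the downstream loss gives a MAP-style objective."""
    return sum(
        (w - m) ** 2 / (2.0 * prior_scale * v)
        for w, m, v in zip(weights, prior_mean, prior_var)
    )


def map_estimate_1d(data_optimum, data_precision, prior_mean, prior_var, prior_scale=1.0):
    """Closed-form minimizer of
        0.5 * data_precision * (w - data_optimum) ** 2          (quadratic downstream loss)
      + 0.5 * (w - prior_mean) ** 2 / (prior_scale * prior_var)  (prior penalty),
    i.e. the precision-weighted average of data_optimum and prior_mean."""
    prior_precision = 1.0 / (prior_scale * prior_var)
    return (data_precision * data_optimum + prior_precision * prior_mean) / (
        data_precision + prior_precision
    )


# Zero-mean isotropic prior (like the `normal` option): optimum pulled toward 0.
print(map_estimate_1d(1.0, 1.0, prior_mean=0.0, prior_var=1.0))  # 0.5
# Prior centered on a hypothetical pre-trained value of 3.0: optimum pulled toward 3.
print(map_estimate_1d(1.0, 1.0, prior_mean=3.0, prior_var=1.0))  # 2.0
```

In this sketch a larger `prior_scale` widens the prior and lets the downstream data dominate; whether the repository's PRIOR_SCALE is applied to the variance in the same way is an assumption.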

Our Bayesian transfer learning pipeline uses only easy-to-implement existing tools. In our experiments, Bayesian transfer learning outperforms both SGD-based transfer learning and Bayesian inference with non-learned priors. A schematic of our framework is shown below.
This repo contains the code for extracting your prior parameters and applying them to a downstream task using Bayesian inference. The downstream tasks include both image classification and image segmentation.
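The learned priors store mean, variance, and cov_factor fields, which matches the low-rank-plus-diagonal Gaussian used by SWAG (stochastic weight averaging Gaussian). Assuming that parameterization, the stdlib-only sketch below shows how such moments could be extracted from parameter snapshots collected while training on the source task. `fit_swag_moments` is a hypothetical helper, not part of this repository, and for simplicity it computes deviations from the final mean rather than from the running mean used in SWAG. Its `max_rank` argument is only loosely analogous to NUMBER_OF_SAMPLES_PRIOR; that mapping is also an assumption.

```python
def fit_swag_moments(snapshots, max_rank=None):
    """Fit a SWAG-style Gaussian to a list of parameter snapshots.

    snapshots: list of equal-length lists of floats, e.g. flattened network
    weights saved at several points along an SGD trajectory on the source task.
    max_rank: keep only the last `max_rank` deviation vectors (None keeps all).

    Returns (mean, variance, cov_factor), where cov_factor is a list of
    deviation vectors whose outer products form the low-rank covariance part.
    """
    n = len(snapshots)
    dim = len(snapshots[0])
    mean = [sum(s[i] for s in snapshots) / n for i in range(dim)]
    sq_mean = [sum(s[i] ** 2 for s in snapshots) / n for i in range(dim)]
    # Diagonal variance E[w^2] - E[w]^2, clamped so it is never negative or zero.
    variance = [max(sq - m ** 2, 1e-30) for sq, m in zip(sq_mean, mean)]
    start = 0 if max_rank is None else max(n - max_rank, 0)
    # One deviation vector per retained snapshot (simplification: uses the final mean).
    cov_factor = [[s[i] - mean[i] for i in range(dim)] for s in snapshots[start:]]
    return mean, variance, cov_factor


mean, variance, cov_factor = fit_swag_moments([[1.0, 2.0], [3.0, 2.0]])
print(mean)        # [2.0, 2.0]
print(cov_factor)  # [[-1.0, 0.0], [1.0, 0.0]]
```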

Dependencies:

  • torch >= 1.8.1
  • torchvision >= 0.9.1
  • pytorch-lightning >= 1.4.7

For the complete list of requirements see requirements.txt.

Prepare Datasets:

For your convenience, we have provided Python scripts for downloading and organizing the Oxford Flowers 102 and Oxford-IIIT Pet datasets. The scripts can be found here.

Usage:

Use prior_run_jobs.py both to learn priors from pre-trained checkpoints and to perform inference on downstream tasks.

python prior_run_jobs.py --job=<JOB> \
                         --prior_type=<PRIOR_TYPE> \
                         --prior_path=<PRIOR_PATH> \
                         --data_dir=<DATA_DIR> \
                         --train_dataset=<TRAIN_DATASET> \
                         --val_dataset=<VAL_DATASET> \
                         --pytorch_pretrain=<PYTORCH_PRETRAIN> \
                         --prior_scale=<PRIOR_SCALE> \
                         --num_of_train_examples=<NUM_OF_TRAIN_EXAMPLES> \
                         --weights_path=<WEIGHTS_PATH> \
                         --number_of_samples_prior=<NUMBER_OF_SAMPLES_PRIOR> \
                         --encoder=<ENCODER>

Parameters:

  • JOB - set to `prior` to learn a prior, or to `supervised_bayesian_learning` to perform inference on downstream tasks.

  • PRIOR_TYPE - type of prior used for inference on a downstream task:

      - `normal` - zero-mean isotropic Gaussian prior
      - `shifted_gaussian` - learned prior
    
  • PRIOR_PATH - path prefix of the learned prior. The prior consists of model weight, mean, variance, and cov_factor fields, stored in files that must follow the naming format prior_path_model.pt, prior_path_mean.pt, prior_path_variance.pt, and prior_path_covmat.pt, where prior_path is the value passed as PRIOR_PATH. You can download the pretrained priors here.

  • DATA_DIR - path to the directory containing the data

  • TRAIN_DATASET - dataset for training

  • VAL_DATASET - dataset for validation

  • PYTORCH_PRETRAIN - whether to load the weights of a torchvision pretrained model

  • PRIOR_SCALE - parameter for re-scaling the prior

  • NUM_OF_TRAIN_EXAMPLES - number of training examples used to train the model

  • WEIGHTS_PATH - path from which to load the pre-trained weights

  • NUMBER_OF_SAMPLES_PRIOR - number of samples used to fit the covariance of the prior

  • ENCODER - base network architecture. The options include most models supported by torchvision.
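Assuming the learned prior is a SWAG-style Gaussian (suggested by its mean, variance, and cov_factor fields, but an assumption here), a draw from it is the mean plus a diagonal term and a low-rank term built from the deviation vectors in cov_factor. The hypothetical, stdlib-only `sample_rescaled_prior` below illustrates that rule and multiplies the whole covariance by `prior_scale`, which is one plausible reading of PRIOR_SCALE rather than a description of the repository's code.

```python
import math
import random


def sample_rescaled_prior(mean, variance, cov_factor, prior_scale=1.0, rng=None):
    """Draw one sample from N(mean, prior_scale * Sigma), where
    Sigma = 0.5 * (diag(variance) + D^T D / (K - 1)) and D is the K x dim
    matrix whose rows are the deviation vectors in cov_factor (SWAG form)."""
    rng = rng or random.Random()
    k = len(cov_factor)
    if k < 2:
        raise ValueError("need at least two deviation vectors for the low-rank term")
    dim = len(mean)
    z_diag = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    z_low = [rng.gauss(0.0, 1.0) for _ in range(k)]
    sample = []
    for i in range(dim):
        # Diagonal part contributes variance[i] / 2.
        diag_term = math.sqrt(variance[i]) * z_diag[i] / math.sqrt(2.0)
        # Low-rank part contributes (D^T D)[i, i] / (2 * (K - 1)).
        low_term = sum(cov_factor[j][i] * z_low[j] for j in range(k)) / math.sqrt(2.0 * (k - 1))
        # Scaling the covariance by prior_scale scales the standard deviation by its square root.
        sample.append(mean[i] + math.sqrt(prior_scale) * (diag_term + low_term))
    return sample


rng = random.Random(0)
draw = sample_rescaled_prior([0.0, 1.0], [1.0, 1.0], [[0.5, -0.5], [-0.5, 0.5]], prior_scale=2.0, rng=rng)
print(len(draw))  # 2
```

Passing a seeded `random.Random` keeps draws reproducible, which is convenient when comparing prior scales.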

For the full list of arguments, see priorBox/options.py. Optional arguments for Bayesian learning are listed here, and optional arguments for learning a prior are listed here.

Our Pre-Trained Priors:

Our learned priors can be found here. They include torchvision ResNet-50 and ResNet-101 as well as SimCLR ResNet-50, all trained on ImageNet. To use one for a downstream task, pass --prior_path with the path of the prior when running prior_run_jobs.py. Note that the prior must contain model weight, mean, variance, and cov_factor fields, stored in files named "prior_path"_model.pt, "prior_path"_mean.pt, "prior_path"_variance.pt, and "prior_path"_covmat.pt.
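Because --prior_path is a prefix that expands to four files, it can help to check that all of them exist before launching a run. The helpers `prior_file_paths` and `missing_prior_files` below are hypothetical, stdlib-only sketches of that naming convention; they are not part of prior_run_jobs.py, and "priors/resnet50" is an illustrative prefix.

```python
import os

# Suffixes from the "prior_path"_<suffix>.pt naming convention.
PRIOR_COMPONENTS = ("model", "mean", "variance", "covmat")


def prior_file_paths(prior_path):
    """Map each prior component to the file name expected for this prefix."""
    return {name: f"{prior_path}_{name}.pt" for name in PRIOR_COMPONENTS}


def missing_prior_files(prior_path):
    """Return the expected prior files that are not present on disk."""
    return [p for p in prior_file_paths(prior_path).values() if not os.path.isfile(p)]


paths = prior_file_paths("priors/resnet50")  # hypothetical prefix
print(paths["mean"])  # priors/resnet50_mean.pt
```

Once the files are present, each one could be inspected with torch.load; the exact contents of each file are not specified here.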

How to Cite:

@article{shwartz2022pre,
  title={Pre-Train Your Loss: Easy Bayesian Transfer Learning with Informative Priors},
  author={Shwartz-Ziv, Ravid and Goldblum, Micah and Souri, Hossein and Kapoor, Sanyam and Zhu, Chen and LeCun, Yann and Wilson, Andrew Gordon},
  journal={arXiv preprint arXiv:2205.10279},
  year={2022}
}