This code release accompanies the paper
Sampling-free Variational Inference for Neural Networks with Multiplicative Activation Noise
Jannik Schmitt, Stefan Roth.
In GCPR 2021.
Contact: jannik.schmitt[at]visinf.tu-darmstadt.de
This code was tested on Python 3.6.9 with PyTorch 1.4.0 and CUDA 10.2.
All requirements are summarized in requirements.txt.
An easy way to install the required packages is to set up an Anaconda environment:
conda config --add channels pytorch
conda create -n <environment_name> --file requirements.txt
conda activate <environment_name>
For the UCI regression experiments, we store the data in CSV files with a header in the first row.
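As an illustration of the expected file layout, the following minimal sketch reads a CSV file with a header row using only the Python standard library. The function name `load_regression_csv`, the `target_column` parameter, and the default of treating the last column as the regression target are assumptions made for this example; they are not taken from the released code.

```python
import csv
import io


def load_regression_csv(file_obj, target_column=None):
    """Read a CSV with a header row into (feature_names, features, targets).

    Assumption for this sketch: the target is the last column unless
    `target_column` names a different one.
    """
    reader = csv.reader(file_obj)
    header = next(reader)  # the first row holds the column names
    rows = [[float(value) for value in row] for row in reader if row]

    # Index of the target column (hypothetical default: last column).
    t = header.index(target_column) if target_column else len(header) - 1

    feature_names = header[:t] + header[t + 1:]
    features = [row[:t] + row[t + 1:] for row in rows]
    targets = [row[t] for row in rows]
    return feature_names, features, targets


# Usage with an in-memory file; a real dataset would be opened with open(path, newline="").
sample = io.StringIO("x1,x2,y\n1.0,2.0,3.0\n4.0,5.0,9.0\n")
names, X, y = load_regression_csv(sample)
```

The resulting lists can be converted to tensors with `torch.tensor(X)` and `torch.tensor(y)` before training.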
Image Classification on ImageNet
We rescale all images so that the smaller image dimension equals 256px.
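The rescaling rule can be sketched as a small size computation. The helper `rescaled_size` below is hypothetical and only computes the output dimensions that preserve the aspect ratio; the actual resampling would be done with an image library, and rounding conventions may differ between libraries.

```python
def rescaled_size(width, height, target=256):
    """Return (new_width, new_height) so that the smaller side equals `target`.

    The aspect ratio is preserved; the longer side is rounded to the
    nearest integer (an assumption of this sketch).
    """
    if width <= 0 or height <= 0:
        raise ValueError("image dimensions must be positive")
    scale = target / min(width, height)
    if width <= height:
        return target, round(height * scale)
    return round(width * scale), target


# A 500x375 landscape image: the height (smaller side) becomes 256.
landscape = rescaled_size(500, 375)
# A 375x500 portrait image: the width (smaller side) becomes 256.
portrait = rescaled_size(375, 500)
```

In a PyTorch pipeline, `torchvision.transforms.Resize(256)` follows the same rule when given a single integer, matching the smaller edge to that size. Whether the released code preprocesses the images this way is not stated here.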
The commands required to train and evaluate the models in the paper can be found in the commands.md file.
Our models achieve the following performance:
Model | boston | concrete | energy | kin8 | power | wine | yacht |
---|---|---|---|---|---|---|---|
SMFVI | -3.51 | -3.42 | -1.11 | 1.17 | -2.88 | -2.01 | -0.37 |
MNVI | -2.43 | -3.05 | -1.33 | 1.15 | -2.86 | -0.96 | -0.37 |
Model | Misclass. | NLLH | ECE | AUMRC |
---|---|---|---|---|
LeNet MFVI | 0.57% | 0.017 | 0.0021 | 8.30e-5 |
LeNet SMFVI | 0.60% | 0.017 | 0.0020 | 8.27e-5 |
LeNet MNVI | 0.55% | 0.018 | 0.0019 | 8.33e-5 |
Model | Misclass. | NLLH | ECE | AUMRC |
---|---|---|---|---|
AllConvNet MFVI | 7.72% | 0.348 | 0.0495 | 0.00898 |
AllConvNet SMFVI | 8.39% | 0.482 | 0.0586 | 0.01034 |
AllConvNet MNVI | 7.62% | 0.352 | 0.0492 | 0.00895 |
ResNet18 MFVI | 5.63% | 0.256 | 0.0372 | 0.00564 |
ResNet18 SMFVI | 5.84% | 0.233 | 0.0304 | 0.00750 |
ResNet18 MNVI | 5.60% | 0.246 | 0.0346 | 0.00553 |
Model | Misclass. | NLLH | ECE | AUMRC |
---|---|---|---|---|
ResNet18 MFVI | 26.91% | 1.271 | 0.131 | 0.0787 |
ResNet18 SMFVI | 27.18% | 1.297 | 0.136 | 0.0803 |
ResNet18 MNVI | 25.30% | 1.085 | 0.105 | 0.0740 |
Model | Misclass. | NLLH | ECE | AUMRC |
---|---|---|---|---|
ResNet18 MNVI | 31.05% | 1.276 | 0.0388 | 0.1092 |
This code is based upon Jochen Gast's Lightweight Probabilistic Deep Networks implementation.
If you use our code, please cite our GCPR 2021 paper:
@inproceedings{Schmitt:2021:SFV,
title = {Sampling-free Variational Inference for Neural Networks with Multiplicative Activation Noise},
author = {Jannik Schmitt and Stefan Roth},
booktitle = {Pattern Recognition, 43rd DAGM German Conference, DAGM GCPR 2021},
year = {2021}}