On Calibration of Modern Neural Networks - tensorflow implementation

Temperature Scaling tensorflow

TensorFlow implementation of temperature scaling, the calibration method from the paper On Calibration of Modern Neural Networks.

What this repo can do:

  • Train ResNet_v1_110
  • Calibrate its output on CIFAR-10/100
  • Calibrate any of your own TensorFlow networks with the temp_scaling function

What this repo cannot do:

  • Calculate ECE (Expected Calibration Error)
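
Since ECE is not computed by this repo, readers who want to measure calibration themselves may find the following sketch useful. It is not part of the repository: the function name `expected_calibration_error`, its arguments, and the choice of 15 equal-width bins are assumptions made for illustration. It takes the top-class confidence of each prediction and a flag saying whether that prediction was correct.

```python
def expected_calibration_error(confidences, correct, n_bins=15):
    """Hypothetical helper (not in this repo): ECE with equal-width bins.

    confidences: max softmax probability per example, each in [0, 1].
    correct:     1 if the predicted class was right, else 0.
    """
    n = len(confidences)
    conf_sum = [0.0] * n_bins  # summed confidence per bin
    hit_sum = [0.0] * n_bins   # number of correct predictions per bin
    count = [0] * n_bins       # number of examples per bin

    for conf, hit in zip(confidences, correct):
        # Bin m covers [m / n_bins, (m + 1) / n_bins); conf == 1.0 goes in the last bin.
        idx = min(int(conf * n_bins), n_bins - 1)
        conf_sum[idx] += conf
        hit_sum[idx] += hit
        count[idx] += 1

    ece = 0.0
    for m in range(n_bins):
        if count[m] == 0:
            continue
        accuracy = hit_sum[m] / count[m]
        avg_conf = conf_sum[m] / count[m]
        # Weight each bin's |accuracy - confidence| gap by its share of the data.
        ece += (count[m] / n) * abs(accuracy - avg_conf)
    return ece
```

Comparing this value before and after temperature scaling on held-out data gives a rough idea of how much calibration improved.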

Official PyTorch implementation by @gpleiss here.

Prerequisites

Data

Preparation

  • Create a data/ folder, then download the Python version of the dataset from the CIFAR webpage and extract it there.
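
To check that the download extracted correctly, a batch file can be read directly with `pickle`. This is a minimal sketch, not code from this repo: `load_cifar_batch` is a hypothetical helper, and the paths in the usage note assume the archives were extracted under data/ with their default directory names, which may differ from the layout main.py expects.

```python
import pickle


def load_cifar_batch(path, label_key=b"labels"):
    """Hypothetical helper: read one pickled CIFAR batch file.

    Returns (images, labels). For CIFAR-100, pass label_key=b"fine_labels".
    """
    with open(path, "rb") as f:
        # The official batches were pickled under Python 2, so keys load as bytes.
        batch = pickle.load(f, encoding="bytes")
    return batch[b"data"], batch[label_key]
```

For example, `load_cifar_batch("data/cifar-10-batches-py/data_batch_1")` should return 10000 images of 3072 values each along with their labels. Unpickling the real files requires NumPy to be installed, since the image data is stored as a NumPy array.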

Train

First, train the model (ResNet 110 in this case) using default parameters:

python main.py

List the tunable hyperparameters:

python main.py --help

Temperature Scaling

Then run temperature scaling, which fits a single temperature on the validation set to calibrate your model.

python temp_scaling.py
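
For intuition, temperature scaling looks for the single scalar T that minimizes the negative log-likelihood (NLL) of softmax(logits / T) on the validation set. The sketch below illustrates that idea in plain Python. It is not the code in temp_scaling.py: the names `nll` and `fit_temperature`, the search bounds, and the use of golden-section search in place of a TensorFlow optimizer are all assumptions made to keep the example self-contained.

```python
import math


def nll(logits, labels, temperature):
    """Mean negative log-likelihood of softmax(logits / temperature)."""
    total = 0.0
    for row, label in zip(logits, labels):
        scaled = [z / temperature for z in row]
        peak = max(scaled)  # subtract the max for numerical stability
        log_norm = peak + math.log(sum(math.exp(s - peak) for s in scaled))
        total += log_norm - scaled[label]
    return total / len(labels)


def fit_temperature(logits, labels, low=0.05, high=20.0, iters=60):
    """Hypothetical helper: golden-section search for the NLL-minimizing T.

    Relies on the NLL being unimodal in T on [low, high] (assumed bounds).
    """
    ratio = (math.sqrt(5) - 1) / 2
    a, b = low, high
    for _ in range(iters):
        c = b - ratio * (b - a)
        d = a + ratio * (b - a)
        if nll(logits, labels, c) < nll(logits, labels, d):
            b = d  # minimum lies in [a, d]
        else:
            a = c  # minimum lies in [c, b]
    return (a + b) / 2


# Toy validation set: the model is about 98% confident but only 75% accurate.
val_logits = [[4.0, 0.0], [4.0, 0.0], [4.0, 0.0], [4.0, 0.0]]
val_labels = [0, 0, 0, 1]
temperature = fit_temperature(val_logits, val_labels)
```

On this toy data the fitted temperature is greater than 1, which softens the overconfident predictions.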

Use the temp_var returned by the temp_scaling function to rescale your model's logits (temperature scaling divides the logits by the temperature before the softmax) to get calibrated output.
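
As a rough illustration of that last step, the sketch below applies a fitted temperature to one example's logits in plain Python. The function name `calibrated_probs` and the temperature value are made up for the example; in practice you would use the value held by temp_var and perform the same division inside your TensorFlow graph.

```python
import math


def calibrated_probs(logits, temperature):
    """Hypothetical helper: softmax(logits / temperature) for one example."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


example_logits = [4.0, 0.0]
raw = calibrated_probs(example_logits, 1.0)         # uncalibrated softmax
calibrated = calibrated_probs(example_logits, 3.64)  # 3.64 is an illustrative value
```

Because every logit is divided by the same positive number, the predicted class does not change; only the confidence attached to it does, so accuracy is unaffected.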

Notes

  • ResNet_v1_110 is trained for 250 epochs; the other hyperparameters follow the defaults from the original ResNet paper.
  • The identity shortcut in ResNet_v1_110 is replaced with a projection shortcut, which adds two convolutional layers.
  • Validation accuracy and test accuracy on CIFAR-100 are both around 70%.
  • Issues are welcome!

Resources