TensorFlow implementation of "On Calibration of Modern Neural Networks".
What this repo can do:
- Train ResNet_v1_110
- Calibrate its output on CIFAR-10/100
- Use the temp_scaling function to calibrate any of your own networks in TensorFlow.
What this repo cannot do:
- Calculate ECE (Expected Calibration Error)
The official PyTorch implementation is by @gpleiss.
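This repo does not compute ECE. For reference, ECE as defined in the paper splits predictions into M equal-width confidence bins B_m and takes the sample-weighted average gap between accuracy and confidence: ECE = Σ_m (|B_m| / n) · |acc(B_m) - conf(B_m)|. The following is a minimal pure-Python sketch, assuming you already have each example's max-softmax confidence and predicted class. The function name and the default of 15 bins are illustrative choices, not part of this repo.

```python
import math
from typing import Sequence


def expected_calibration_error(confidences: Sequence[float],
                               predictions: Sequence[int],
                               labels: Sequence[int],
                               n_bins: int = 15) -> float:
    """ECE with equal-width confidence bins over (0, 1] (illustrative sketch)."""
    n = len(confidences)
    if n == 0 or not (n == len(predictions) == len(labels)):
        raise ValueError("inputs must be non-empty and of equal length")
    bin_count = [0] * n_bins
    bin_correct = [0] * n_bins
    bin_conf = [0.0] * n_bins
    for conf, pred, label in zip(confidences, predictions, labels):
        # Bin m covers (m/M, (m+1)/M]; a confidence of exactly 0 goes in bin 0.
        idx = min(max(math.ceil(conf * n_bins) - 1, 0), n_bins - 1)
        bin_count[idx] += 1
        bin_correct[idx] += int(pred == label)
        bin_conf[idx] += conf
    ece = 0.0
    for count, correct, conf_sum in zip(bin_count, bin_correct, bin_conf):
        if count:
            # Weight each bin's |accuracy - mean confidence| by its share of samples.
            ece += (count / n) * abs(correct / count - conf_sum / count)
    return ece
```

A perfectly calibrated model gives 0; larger values mean confidences drift further from observed accuracy.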
Requirements:
- Python 3.5
- NumPy
- TensorFlow 1.8
- Create a data/ folder, then download the Python version of the dataset from the CIFAR webpage and extract it there.
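The "python version" archives on the CIFAR webpage contain pickled dictionaries, written under Python 2, so Python 3 needs encoding="bytes" to read them. A minimal loader sketch follows. The function name and the example path are hypothetical (main.py may expect a different layout under data/), and the key names are assumed to be the ones described on the CIFAR webpage: b"labels" for CIFAR-10 batches, b"fine_labels" or b"coarse_labels" for CIFAR-100.

```python
import os
import pickle


def load_cifar_batch(path, label_key=b"labels"):
    """Return (data, labels) from one CIFAR python-version pickle file.

    Hypothetical helper: use label_key=b"fine_labels" for CIFAR-100.
    """
    with open(path, "rb") as f:
        # Keys are kept as bytes because the files were pickled under Python 2.
        batch = pickle.load(f, encoding="bytes")
    return batch[b"data"], batch[label_key]


if __name__ == "__main__":
    # Hypothetical path: adjust to wherever you extracted the archive.
    path = os.path.join("data", "cifar-10-batches-py", "data_batch_1")
    if os.path.exists(path):
        data, labels = load_cifar_batch(path)
        print(len(labels))
```

Each row of data holds 3072 uint8 values for a 32x32 image: 1024 red, then 1024 green, then 1024 blue.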
Usage:
First, train the model (ResNet-110 in this case) with the default parameters:
python main.py
To list the tunable hyper-parameters:
python main.py --help
Then run temperature scaling on the validation set to calibrate your model:
python temp_scaling.py
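Temperature scaling learns a single scalar T > 0 on the validation set. In the paper, T is chosen to minimise the validation negative log-likelihood of softmax(z / T). The sketch below illustrates that objective in pure Python. It deliberately swaps the gradient-based optimiser one would normally use for a plain grid search, and the function names and grid range are hypothetical, not this repo's API.

```python
import math
from typing import Optional, Sequence


def nll_at_temperature(logits: Sequence[Sequence[float]],
                       labels: Sequence[int], t: float) -> float:
    """Mean negative log-likelihood of softmax(logits / t)."""
    total = 0.0
    for z, y in zip(logits, labels):
        scaled = [v / t for v in z]
        m = max(scaled)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in scaled))
        total += log_sum_exp - scaled[y]
    return total / len(labels)


def fit_temperature(logits: Sequence[Sequence[float]],
                    labels: Sequence[int],
                    candidates: Optional[Sequence[float]] = None) -> float:
    """Pick the T that minimises validation NLL (grid search, for illustration)."""
    if candidates is None:
        candidates = [0.05 * k for k in range(1, 101)]  # T from 0.05 to 5.0
    return min(candidates, key=lambda t: nll_at_temperature(logits, labels, t))
```

Because T is one positive scalar shared by all classes, dividing the logits by it never changes the argmax, so calibration leaves accuracy unchanged. An overconfident network typically ends up with T > 1.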
Apply the temp_var returned by the temp_scaling function to your model's logits to get calibrated outputs.
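A minimal sketch of that last step, assuming temp_var holds the temperature T itself, so logits are divided by it as in the paper (check temp_scaling.py, since the variable could instead store 1/T). The helper name is hypothetical.

```python
import math
from typing import List, Sequence


def calibrated_probs(logits: Sequence[float], temperature: float) -> List[float]:
    """Softmax of logits / temperature, computed in a numerically stable way."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # shift by the max so exp() cannot overflow
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]
```

Under the same assumption, the TensorFlow equivalent is tf.nn.softmax(logits / temp_var). A temperature above 1 softens the distribution without changing which class is predicted.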
Notes:
- ResNet_v1_110 is trained for 250 epochs; the remaining defaults follow the original ResNet paper.
- The identity shortcuts in ResNet_v1_110 are replaced with projection shortcuts, which adds two extra convolutional layers.
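The layer count can be checked with simple arithmetic. The original ResNet paper's CIFAR networks have 6n + 2 weighted layers (one stem convolution, three stages of n two-convolution blocks, and one fully connected layer), so ResNet-110 corresponds to n = 18. The two extra layers presumably come from the two stage transitions (16 to 32 and 32 to 64 channels), where the shortcut must change dimensions. The helper below is a hypothetical illustration, not code from this repo.

```python
def resnet_cifar_weight_layers(n: int, projection_shortcuts: int = 0) -> int:
    """Weighted-layer count of a CIFAR ResNet from the original ResNet paper.

    6n + 2 = 1 stem conv + 3 stages * n blocks * 2 convs + 1 FC layer;
    each projection shortcut is assumed to add one extra convolution.
    """
    return 6 * n + 2 + projection_shortcuts
```

With n = 18 this gives 110 layers, or 112 once the two projection convolutions are counted.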
- Validation accuracy and test accuracy on CIFAR-100 are around 70%.
- Issues are welcome!