Apply an ensemble classification model to detect out-of-distribution samples.
Following "Simple and scalable predictive uncertainty estimation using deep ensembles" [1], this project implements an ensemble model for a classification task in PyTorch. The experiments show how out-of-distribution detection behaves under different loss functions and optimizers.
Three datasets are used for demonstration purposes: MNIST, FASHION-MNIST, and NOT-MNIST.
Two widely used optimizers are compared in the experiments:
- Adam Optimizer
- SGD Optimizer
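To make the difference between the two optimizers concrete, here is a minimal pure-Python sketch of their update rules (in the actual code they would be used via `torch.optim.SGD` and `torch.optim.Adam`; the learning rates below are illustrative assumptions, not the repo's settings):

```python
import math

def sgd_step(w, g, lr=0.1):
    # Vanilla SGD: move each weight against its gradient.
    return [wi - lr * gi for wi, gi in zip(w, g)]

def adam_step(w, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # Adam: keep exponential moving averages of the gradient (m) and its
    # square (v), bias-correct them, then take a per-weight scaled step.
    m = [b1 * mi + (1 - b1) * gi for mi, gi in zip(m, g)]
    v = [b2 * vi + (1 - b2) * gi * gi for vi, gi in zip(v, g)]
    mhat = [mi / (1 - b1 ** t) for mi in m]
    vhat = [vi / (1 - b2 ** t) for vi in v]
    w = [wi - lr * mh / (math.sqrt(vh) + eps)
         for wi, mh, vh in zip(w, mhat, vhat)]
    return w, m, v
```

Adam's adaptive per-weight step size often converges faster, while SGD's uniform steps can generalize differently, which is one plausible reason the two behave differently on out-of-distribution detection.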
Two loss functions are explored:
- Brier Score
- Softmax Cross Entropy
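The two loss functions can be sketched in a few lines of plain Python (for a single example with a one-hot target; the repo would compute these batched in PyTorch):

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def brier_score(probs, label, num_classes):
    # Mean squared difference between the predicted probability
    # vector and the one-hot target.
    onehot = [1.0 if k == label else 0.0 for k in range(num_classes)]
    return sum((p - y) ** 2 for p, y in zip(probs, onehot)) / num_classes

def cross_entropy(probs, label):
    # Negative log-probability of the true class.
    return -math.log(probs[label])
```

The Brier score is bounded (each squared term is at most 1), whereas cross entropy can grow without bound as the true-class probability approaches zero, which changes how heavily confident mistakes are penalized.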
Details can be found in the slide.
An ensemble model consisting of 5 single NNs is trained on the MNIST training set, then tested on the MNIST test set, FASHION-MNIST, and NOT-MNIST to demonstrate the out-of-distribution detection effect. The model is trained for a total of 20 epochs. The following figures show two metrics: test accuracy and the averaged probability score of the predicted labels. Both the ensemble net and a single net are evaluated with these two metrics.
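The ensemble prediction follows the paper's recipe of averaging the per-model class probabilities; a minimal sketch (the 5-model setup and the returned "probability score of the predicted label" match the description above, but the function name is illustrative):

```python
def ensemble_predict(prob_lists):
    # Average the class probabilities produced by each ensemble member:
    # p(y|x) = (1/M) * sum_m p_m(y|x), as in Lakshminarayanan et al. [1].
    num_models = len(prob_lists)
    num_classes = len(prob_lists[0])
    mean = [sum(p[k] for p in prob_lists) / num_models
            for k in range(num_classes)]
    label = max(range(num_classes), key=lambda k: mean[k])
    # The probability score of the predicted label: low values signal
    # disagreement among members, i.e. a likely out-of-distribution input.
    return label, mean[label]
```

For an in-distribution MNIST digit the members tend to agree and the averaged score stays high; on FASHION-MNIST or NOT-MNIST inputs the members disagree, pulling the averaged score down.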
- Brier Score and Adam Optimizer
- Brier Score and SGD Optimizer
- Softmax Cross Entropy and Adam Optimizer
- The ensemble model achieves better accuracy, delays overfitting, and shows higher uncertainty on out-of-distribution samples.
- The SGD optimizer detects out-of-distribution examples better than the Adam optimizer.
- To make a single model more resilient to out-of-distribution examples, avoid overfitting it.
Fork this repo to add more cases, such as a regression task, adversarial training, and/or other related experiments.
[1] Lakshminarayanan, Balaji, Alexander Pritzel, and Charles Blundell. "Simple and scalable predictive uncertainty estimation using deep ensembles." Advances in Neural Information Processing Systems 30 (2017).