/PosthocUQ

Primary LanguagePythonMIT LicenseMIT

Post-hoc Uncertainty Learning using a Dirichlet Meta-Model

This repository contains official implementation of AAAI 2023 paper Post-hoc Uncertainty Learning using a Dirichlet Meta-Model.

Requirements

  • python == 3.8.8
  • pytorch == 1.10.0
  • torchvision == 0.11.1
  • numpy, scipy, sklearn, random, argparse, csv, os, time, sys

Usage

Training base-model

  • LeNet/MNIST
python train_base_model.py --model='LeNet_BaseModel' --name='MNIST' --dataset='MNIST' --lr=1e-2 --seed=0 --decay=5e-4 --epoch=20
  • VGG16/CIFAR10
python train_base_model.py --model='VGG16_BaseModel' --name='CIFAR10' --dataset='CIFAR10' --lr=1e-1 --seed=0 --decay=1e-4 --epoch=200
  • WideResNet-16/CIFAR100
python train_base_model.py --model='WideResNet_BaseModel' --name='CIFAR100' --dataset='CIFAR100' --lr=1e-1 --seed=0 --decay=1e-4 --epoch=200

Training meta-model

  • For different UQ tasks, simply change the "name", such as --name='CIFAR10_OOD' for OOD detection, and --name='CIFAR10_miss' for Misclassfication.
  • MNIST
python train_meta_model_combine.py --base_model='LeNet_BaseModel' --base_epoch=20 --meta_model='LeNet_MetaModel_combine' --name='MNIST_OOD' --dataset='MNIST' --lr=1e-1 --seed_trail=0 --decay=1e-4 --epoch=20 --lambda_KL=1e-1 
  • CIFAR10
python train_meta_model_combine.py --base_model='VGG16_BaseModel' --meta_model='VGG16_MetaModel_combine' --name='CIFAR10_OOD' --dataset='CIFAR10' --lr=1e-3 --seed_trail=0 --decay=1e-4 --epoch=20 --lambda_KL=1e-3 
  • CIFAR100
python train_meta_model_combine.py --base_model='WideResNet_BaseModel' --meta_model='WideResNet_MetaModel_combine' --name='CIFAR100_OOD' --dataset='CIFAR100' --lr=1e-2 --seed_trail=0 --decay=1e-4 --epoch=20 --lambda_KL=1e-3 

Evaluate

  • MNIST
python eval_meta_model.py --base_model='LeNet_BaseModel' --meta_model='LeNet_MetaModel_combine' --name='MNIST_OOD' --dataset='MNIST' --base_epoch=20
  • CIFAR10
python eval_meta_model.py --base_model='VGG16_BaseModel' --meta_model='VGG16_MetaModel_combine' --name='CIFAR10_OOD' --dataset='CIFAR10'
  • CIFAR100
python eval_meta_model.py --base_model='WideResNet_BaseModel' --meta_model='WideResNet_MetaModel_combine' --name='CIFAR100_OOD' --dataset='CIFAR100'

Datasets

  • Please manually download LSUN and Tiny ImageNet datasets.
  • The dataloader automatically downloads other datasets.

Reference

This code is based on the following repositories: