This is the official repo for
Unsupervised Out-of-Distribution Detection with Diffusion Inpainting (ICML 2023)
by Zhenzhen Liu*, Jin Peng Zhou*, Yufan Wang, and Kilian Q. Weinberger
*Equal Contribution
Unsupervised out-of-distribution (OOD) detection seeks to identify out-of-domain data by learning only from unlabeled in-domain data. We present Lift, Map, Detect (LMD), a novel approach for this task that leverages recent advancements in diffusion models. Diffusion models are a type of generative model. At their core, they learn an iterative denoising process that gradually maps a noisy image closer to their training manifold. LMD leverages this intuition for OOD detection. Specifically, LMD lifts an image off its original manifold by corrupting it, and maps it towards the in-domain manifold with a diffusion model. For an OOD image, the mapped image would lie far from its original manifold, and LMD would identify it as OOD accordingly. We show through extensive experiments that LMD achieves competitive performance across a broad variety of datasets.
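The three steps can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the repo's implementation: images are small 2-D lists of floats, `toy_inpaint` is a hypothetical stand-in for a diffusion model that simply fills erased pixels with a value typical of the in-domain data, mean squared error stands in for a perceptual metric such as LPIPS, and taking the median over several masks is an assumed aggregation rule.

```python
from statistics import median
from typing import Callable, List

Image = List[List[float]]
Mask = List[List[int]]  # 1 = keep pixel, 0 = erase and inpaint


def lift(image: Image, mask: Mask, fill: float = 0.0) -> Image:
    """Lift: corrupt the image by erasing every masked pixel."""
    return [[px if m else fill for px, m in zip(img_row, mask_row)]
            for img_row, mask_row in zip(image, mask)]


def mse(a: Image, b: Image) -> float:
    """Mean squared pixel error (a stand-in for LPIPS in this sketch)."""
    n = len(a) * len(a[0])
    return sum((x - y) ** 2 for ra, rb in zip(a, b) for x, y in zip(ra, rb)) / n


def lmd_score(image: Image, masks: List[Mask],
              inpaint: Callable[[Image, Mask], Image],
              distance: Callable[[Image, Image], float] = mse) -> float:
    """Higher score = more likely out-of-distribution."""
    dists = []
    for mask in masks:
        corrupted = lift(image, mask)          # Lift
        mapped = inpaint(corrupted, mask)      # Map
        dists.append(distance(image, mapped))  # Detect
    return median(dists)                       # assumed aggregation over attempts


def toy_inpaint(image: Image, mask: Mask, in_domain_value: float = 0.5) -> Image:
    """Hypothetical stand-in for a diffusion model trained on in-domain data:
    every erased pixel is replaced by a value typical of that data."""
    return [[px if m else in_domain_value for px, m in zip(img_row, mask_row)]
            for img_row, mask_row in zip(image, mask)]


if __name__ == "__main__":
    top = [[0] * 4 if i < 2 else [1] * 4 for i in range(4)]  # erase the top half
    bottom = [[1 - m for m in row] for row in top]            # erase the bottom half
    in_domain = [[0.5] * 4 for _ in range(4)]
    ood = [[1.0] * 4 for _ in range(4)]
    print(lmd_score(in_domain, [top, bottom], toy_inpaint))  # 0.0
    print(lmd_score(ood, [top, bottom], toy_inpaint))        # 0.125
```

The in-domain image is reconstructed exactly, so its score is zero, while the OOD image is pulled toward in-domain values and scores higher. A real diffusion inpainter is stochastic, which is why repeating the attempt with several masks and aggregating is useful.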
@article{liu2023unsupervised,
title={Unsupervised Out-of-Distribution Detection with Diffusion Inpainting},
author={Liu, Zhenzhen and Zhou, Jin Peng and Wang, Yufan and Weinberger, Kilian Q},
journal={arXiv preprint arXiv:2302.10326},
year={2023}
}
The environment can be created with Conda and pip by running the commands in create_environment.sh.
To install the SimCLR similarity metric, please see this and this. The checkpoint r50_1x_sk1.pth should be put under pretrained/.
To train a DDPM model on datasets such as CIFAR10:
python main.py --workdir results/cifar10/ --config configs/subvp/cifar10_ddpm_continuous.py --mode train
After training a model (or using a pretrained checkpoint), we can perform inpainting with:
python recon.py --config configs/subvp/cifar10_ddpm_continuous.py --ckpt_path results/cifar10/checkpoints/checkpoint_20.pth --in_domain CIFAR10 --out_of_domain SVHN --batch_size 200 --mask_type checkerboard_alt --mask_num_blocks 8 --reps_per_image 10 --workdir results/cifar10/CIFAR10_vs_SVHN/
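The masking flags can be illustrated with a small Python sketch. Its reading of the flags is an assumption to check against recon.py: `--mask_type checkerboard_alt` is taken to mean a checkerboard of `--mask_num_blocks` by `--mask_num_blocks` squares whose parity flips on alternate repetitions, so consecutive attempts erase complementary halves of the image, and `--reps_per_image` is the number of attempts per image. The output subdirectory naming is inferred by comparing this command with the `--result_path` passed to detect.py. All function names here are hypothetical.

```python
from typing import List

Mask = List[List[int]]  # 1 = keep pixel, 0 = erase and inpaint


def checkerboard_mask(size: int, num_blocks: int, invert: bool = False) -> Mask:
    """Split a size x size image into num_blocks x num_blocks squares and keep
    every other square (the complementary squares if invert=True)."""
    block = size // num_blocks
    return [
        [int((((i // block) + (j // block)) % 2 == 0) != invert) for j in range(size)]
        for i in range(size)
    ]


def checkerboard_alt_masks(size: int, num_blocks: int, reps: int) -> List[Mask]:
    """Assumed reading of checkerboard_alt: flip the checkerboard on every other repetition."""
    return [checkerboard_mask(size, num_blocks, invert=(r % 2 == 1)) for r in range(reps)]


def result_dir(workdir: str, mask_type: str, num_blocks: int, reps: int) -> str:
    """Subdirectory naming inferred from the recon.py and detect.py commands in this README."""
    return f"{workdir.rstrip('/')}/{mask_type}_blocks{num_blocks}_reps{reps}/"


if __name__ == "__main__":
    masks = checkerboard_alt_masks(size=32, num_blocks=8, reps=10)  # CIFAR10 images are 32x32
    print(sum(map(sum, masks[0])))  # 512 of 1024 pixels kept
    print(result_dir("results/cifar10/CIFAR10_vs_SVHN/", "checkerboard_alt", 8, 10))
```

Alternating the parity means that, across any two consecutive attempts, every pixel is inpainted once, so no region of the image escapes the check.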
To evaluate ROC-AUC with the LPIPS similarity metric, run:
python detect.py --result_path results/cifar10/CIFAR10_vs_SVHN/checkerboard_alt_blocks8_reps10/ --reps 10 --metric LPIPS
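ROC-AUC measures how well the per-image scores separate the two datasets: it is the probability that a randomly chosen out-of-domain image receives a higher score than a randomly chosen in-domain image. The pure-Python sketch below assumes each image has already been reduced to a single score where larger means more OOD (true for a distance such as LPIPS; a similarity score would need to be negated first). It is an illustration of the metric, not a reproduction of detect.py.

```python
from typing import Sequence


def roc_auc(in_scores: Sequence[float], ood_scores: Sequence[float]) -> float:
    """ROC-AUC with OOD as the positive class, computed by pairwise comparison
    (the Mann-Whitney U statistic); ties count as half a win."""
    if not in_scores or not ood_scores:
        raise ValueError("both score lists must be non-empty")
    wins = 0.0
    for o in ood_scores:
        for i in in_scores:
            if o > i:
                wins += 1.0
            elif o == i:
                wins += 0.5
    return wins / (len(in_scores) * len(ood_scores))


if __name__ == "__main__":
    # Made-up scores for illustration only.
    print(roc_auc(in_scores=[0.1, 0.3], ood_scores=[0.2, 0.4]))  # 0.75
```

This gives the same value as `sklearn.metrics.roc_auc_score` with OOD images labeled 1 and in-domain images labeled 0. The pairwise loop is quadratic, so the library call is preferable for large test sets.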
Most datasets are downloaded automatically. ImageNet and CelebA-HQ must be downloaded manually; afterwards, running inpainting on them with recon.py also requires a .txt file that lists the path to one image per line.
This work builds upon excellent open-source implementations from Yang Song. Specifically, we adapted his PyTorch DDPM implementation (link). We also adapted great work from SimMIM and SimCLRv2-Pytorch.