Code for Efficient Image-to-Image Diffusion Classifier for Adversarial Robustness

Efficient Image-to-Image Diffusion Classifier for Adversarial Robustness

arXiv: https://arxiv.org/pdf/2408.08502

Hefei Mei, Minjing Dong, Chang Xu

Requirements

We evaluate the robustness of IDC with torchattacks (PGD, FGSM, MIFGSM, CW, AutoAttack) and DiffPure (BPDA+EOT only). If you only need the classifier or the torchattacks-based robustness evaluation, follow the requirements of BBDM and install torchattacks. Otherwise, install the combined environment below.

  • CUDA=10.2

  • Python 3.9

  1. Create environment:
conda create -n IDC python=3.9
conda activate IDC
pip install -r requirements
  2. Put the file _pytree.py in anaconda/envs/IDC/lib/python3.9/site-packages/torch/utils
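
The anaconda/envs/IDC/... prefix depends on where conda is installed, so the target directory for _pytree.py may differ on your machine. The following is a minimal sketch for locating it, not part of the repository; the helper name `package_subdir` is hypothetical. It uses only the standard library and does not import the package it looks up.

```python
import importlib.util
from pathlib import Path


def package_subdir(package: str, subdir: str) -> Path:
    """Return <install dir of package>/<subdir> without importing the package."""
    spec = importlib.util.find_spec(package)
    # Plain modules and missing packages have no search locations.
    if spec is None or not spec.submodule_search_locations:
        raise ModuleNotFoundError(f"{package!r} is not an installed package")
    return Path(next(iter(spec.submodule_search_locations))) / subdir
```

With the IDC environment activated, `package_subdir("torch", "utils")` should return the torch/utils directory of that environment, which is where step 2 expects _pytree.py.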

Train

  • cd IDC/

If you wish to train from the beginning

# For CIFAR-10
python main.py --config configs/Template-IDC-cifar10.yaml \
    --train \
    --sample_at_start \
    --save_top \
    --gpu_ids 0,1,2,3
# For CIFAR-100
python main.py --config configs/Template-IDC-cifar100.yaml \
    --train \
    --sample_at_start \
    --save_top \
    --gpu_ids 0,1,2,3

If you wish to continue training

# For CIFAR-10
python main.py --config configs/Template-IDC-cifar10.yaml \
    --train \
    --sample_at_start \
    --save_top \
    --gpu_ids 0,1,2,3 \
    --resume_model path/to/model_ckpt \
    --resume_optim path/to/optim_ckpt
# For CIFAR-100
python main.py --config configs/Template-IDC-cifar100.yaml \
    --train \
    --sample_at_start \
    --save_top \
    --gpu_ids 0,1,2,3 \
    --resume_model path/to/model_ckpt \
    --resume_optim path/to/optim_ckpt
  • We use auto_lr_scheduler=True with a batch size of 256 (64*4). If you train the model with a different batch size and learning rate, you can set auto_lr_scheduler=False to reach similar accuracy.
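
The 256 (64*4) figure reads as 64 samples on each of the 4 GPUs passed via --gpu_ids 0,1,2,3. That reading is an assumption; the README does not say whether the batch size in the config file is per GPU or total. Under that assumption, a small sketch (the function name `effective_batch_size` is hypothetical) shows how the total changes with the GPU list:

```python
def effective_batch_size(gpu_ids: str, per_gpu_batch: int) -> int:
    """Total batch size, assuming each listed GPU processes per_gpu_batch samples."""
    # Count the non-empty entries of a comma-separated list such as "0,1,2,3".
    n_gpus = len([g for g in gpu_ids.split(",") if g.strip()])
    if n_gpus == 0:
        raise ValueError("gpu_ids must list at least one device")
    return n_gpus * per_gpu_batch


print(effective_batch_size("0,1,2,3", 64))  # 256
```

Training on a single GPU with the same per-GPU value would give a total of 64 under this reading, which is the kind of change the note above refers to.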

Test

  • cd IDC/
# For CIFAR-10
bash test_torchattacks_cifar10.sh
# For CIFAR-100
bash test_torchattacks_cifar100.sh
  • cd IDC-BPDA/
# For CIFAR-10
bash test_bpda_cifar10.sh
# For CIFAR-100
bash test_bpda_cifar100.sh

Pre-trained Model

Dataset | Log | Checkpoint
CIFAR-10 | log_cifar10 [Google, Quark] | ckpt_cifar10 [Google, Quark]
CIFAR-100 | log_cifar100 [Google, Quark] | ckpt_cifar100 [Google, Quark]

Acknowledgement

Our code is built on BBDM, DiffPure, and torchattacks. We thank the authors for their excellent work.