AdvUnlearn

Official implementation of "Defensive Unlearning with Adversarial Training for Robust Concept Erasure in Diffusion Models"

Primary language: Jupyter Notebook. License: Creative Commons Attribution 4.0 International (CC-BY-4.0).

Defensive Unlearning with Adversarial Training for Robust Concept Erasure in Diffusion Models

Our proposed robust unlearning framework, AdvUnlearn, enhances diffusion models' safety by robustly erasing unwanted concepts through adversarial training, achieving an optimal balance between concept erasure and image generation quality.
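The trade-off described above can be written, schematically, as a bilevel adversarial-training problem. This is our paraphrase for orientation, not notation defined in this README: θ denotes the trainable DM parameters (for example, the text encoder selected by --train_method), c is the prompt for the concept to erase, c* is an adversarial prompt found within an edit budget ε, ℓ_atk is the attack loss the adversary minimizes to regenerate the concept, ℓ_u is the unlearning loss, ℓ_R is a utility-retaining loss on the retain prompt set (--dataset_retain), and γ plays the role of --retain_loss_w. Consult the paper for the exact losses and constraint set.

```latex
\min_{\theta}\; \ell_{u}(\theta, c^{*}) + \gamma\, \ell_{R}(\theta)
\quad \text{s.t.} \quad
c^{*} = \arg\min_{\lVert c' - c \rVert_{0} \le \epsilon} \ell_{\mathrm{atk}}(\theta, c')
```

In these terms, the lower-level problem is the adversarial prompt generation governed by --attack_method and --attack_step, and the upper-level problem updates the model.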

This repository contains the code for AdvUnlearn, our robust diffusion model (DM) unlearning framework. It builds on the code bases of Stable Diffusion (SD) and Erased Stable Diffusion (ESD).

Simple Usage of AdvUnlearn Text Encoders (HuggingFace Model)

from transformers import CLIPTextModel
cache_path = ".cache"

# Base model of our unlearned text encoders

model_name_or_path = "CompVis/stable-diffusion-v1-4"

text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="text_encoder", cache_dir=cache_path)

# AdvUnlearn (Ours): unlearned text encoders
# Each call below loads a different unlearned encoder; keep only the one you need.
model_name_or_path = "OPTML-Group/AdvUnlearn"

# Nudity-Unlearned
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="nudity_unlearned", cache_dir=cache_path)

# Style-Unlearned
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="vangogh_unlearned", cache_dir=cache_path)

# Object-Unlearned
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="church_unlearned", cache_dir=cache_path)
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="garbage_truck_unlearned", cache_dir=cache_path)
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="parachute_unlearned", cache_dir=cache_path)
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="tench_unlearned", cache_dir=cache_path)
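The loading calls above can be wrapped in a small helper that picks the subfolder by concept name and, optionally, drops the unlearned encoder into a diffusers pipeline. This is a hypothetical sketch: `subfolder_for` and `load_unlearned_pipeline` are our own names, and the pipeline step assumes the `diffusers` package is installed, which this README does not mention. Passing `text_encoder=` to `StableDiffusionPipeline.from_pretrained` replaces the encoder bundled with the base model while keeping its other components.

```python
"""Hypothetical convenience helpers around the AdvUnlearn text encoders (not part of the repo)."""

ADVUNLEARN_REPO = "OPTML-Group/AdvUnlearn"
BASE_MODEL = "CompVis/stable-diffusion-v1-4"

# Concept key -> subfolder in the AdvUnlearn Hugging Face repo
# (subfolder names copied from the usage example above).
SUBFOLDERS = {
    "nudity": "nudity_unlearned",
    "vangogh": "vangogh_unlearned",
    "church": "church_unlearned",
    "garbage_truck": "garbage_truck_unlearned",
    "parachute": "parachute_unlearned",
    "tench": "tench_unlearned",
}


def subfolder_for(concept: str) -> str:
    """Map a concept name such as 'Garbage Truck' to its subfolder, e.g. 'garbage_truck_unlearned'."""
    key = concept.strip().lower().replace(" ", "_")
    if key not in SUBFOLDERS:
        raise KeyError(f"unknown concept {concept!r}; expected one of {sorted(SUBFOLDERS)}")
    return SUBFOLDERS[key]


def load_unlearned_pipeline(concept: str, cache_dir: str = ".cache"):
    """Build an SD v1.4 pipeline whose text encoder is swapped for an AdvUnlearn one.

    Imports are deferred so the mapping above works without transformers/diffusers installed.
    """
    from transformers import CLIPTextModel
    from diffusers import StableDiffusionPipeline

    text_encoder = CLIPTextModel.from_pretrained(
        ADVUNLEARN_REPO, subfolder=subfolder_for(concept), cache_dir=cache_dir
    )
    # Components passed as keyword arguments override those stored in the base model repo.
    return StableDiffusionPipeline.from_pretrained(
        BASE_MODEL, text_encoder=text_encoder, cache_dir=cache_dir
    )
```

For example, `pipe = load_unlearned_pipeline("church")` followed by `pipe("a photo of a church").images[0]` would generate an image with the church-unlearned encoder in place of the original one.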

Prepare

Environment Setup

A suitable conda environment named AdvUnlearn can be created and activated with:

conda env create -f environment.yaml
conda activate AdvUnlearn

Files Download

  • Base model (SD v1.4): download it from here and move it to models/sd-v1-4-full-ema.ckpt.
  • COCO-10k (for CLIP score and FID): extract the image subset from the COCO dataset or download it from here, then move it to data/imgs/coco_10k.
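Before training or evaluation, it can help to confirm that both downloads landed where the scripts expect them. The sketch below only checks that the two paths listed above exist relative to the repository root; `missing_files` is a hypothetical helper, and it does not verify the checkpoint's integrity or that the COCO folder really holds 10k images.

```python
from pathlib import Path

# Locations from the "Files Download" list, relative to the repository root.
EXPECTED_PATHS = {
    "SD v1.4 checkpoint": Path("models/sd-v1-4-full-ema.ckpt"),
    "COCO-10k images": Path("data/imgs/coco_10k"),
}


def missing_files(repo_root: str = ".") -> list:
    """Return the names of expected files/folders that do not exist under repo_root."""
    root = Path(repo_root)
    return [name for name, rel in EXPECTED_PATHS.items() if not (root / rel).exists()]


if __name__ == "__main__":
    missing = missing_files()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All expected files are in place.")
```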

Code Implementation

Step 1: AdvUnlearn [Train]

Hyperparameters:

  • Concept to be unlearned: --prompt (e.g., 'nudity')
  • Trainable module within the DM: --train_method (e.g., 'text_encoder_full')
  • Attack generation strategy: --attack_method (e.g., fast_at)
  • Number of attack steps for adversarial prompt generation: --attack_step
  • Adversarial prompting strategy: --attack_type ('prefix_k', 'replace_k', 'add')
  • Retaining prompt dataset: --dataset_retain (e.g., 'coco_object')
  • Utility regularization parameter: --retain_loss_w (e.g., 0.3)
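When sweeping these hyperparameters, it can be convenient to build the command as an argument list rather than a hand-edited string. The sketch below is a hypothetical helper (`build_train_command` is not part of the repository): it validates --attack_type against the three strategies listed above and forwards every other flag verbatim, with the defaults for the script path and --train_method taken from the command examples. It does not know which flags the training script actually accepts.

```python
import shlex

# The three --attack_type strategies listed in the hyperparameters above.
ATTACK_TYPES = {"prefix_k", "replace_k", "add"}


def build_train_command(prompt, train_method="text_encoder_full", attack_type=None,
                        script="train-scripts/AdvUnlearn.py", **flags):
    """Return an argv list for the AdvUnlearn training script (hypothetical helper)."""
    args = ["python", script, "--prompt", prompt, "--train_method", train_method]
    if attack_type is not None:
        if attack_type not in ATTACK_TYPES:
            raise ValueError(f"attack_type must be one of {sorted(ATTACK_TYPES)}")
        args += ["--attack_type", attack_type]
    # Remaining keyword arguments become --name value pairs, in the order given.
    for name, value in flags.items():
        args += [f"--{name}", str(value)]
    return args


def format_command(args):
    """Render an argv list as a shell-safe string for logging or copy-pasting."""
    return shlex.join(args)


if __name__ == "__main__":
    cmd = build_train_command("nudity", attack_init="random", attack_step=30,
                              retain_train="reg", dataset_retain="coco_object",
                              retain_loss_w=0.3)
    print(format_command(cmd))
```

The list form can be passed directly to `subprocess.run(cmd, check=True)`, which avoids shell quoting issues when prompts contain spaces.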

a) Command Example: Multi-step Attack

python train-scripts/AdvUnlearn.py --attack_init random --attack_step 30 --retain_train 'reg' --dataset_retain 'coco_object' --prompt 'nudity' --train_method 'text_encoder_full' --retain_loss_w 0.3

b) Command Example: Fast AT variant

python train-scripts/AdvUnlearn.py --attack_method fast_at --attack_init random --attack_step 30 --retain_train 'reg' --dataset_retain 'coco_object' --prompt 'nudity' --train_method 'text_encoder_full' --retain_loss_w 0.3
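The two commands differ only in --attack_method fast_at. Assuming "Fast AT" refers to fast adversarial training in the usual sense, the contrast is that a multi-step attack takes many gradient steps to find a perturbation before each model update, while the fast variant takes a single step from a random start. The toy below illustrates that contrast on a 1-D min-max problem under our own assumptions (a scalar parameter, a quadratic loss, signed-gradient steps); it is not how train-scripts/AdvUnlearn.py implements either option, since the real attack operates on prompts.

```python
import random

EPS = 0.5   # perturbation budget: |delta| <= EPS
LR = 0.05   # defender (outer) learning rate


def loss(theta, delta):
    # Toy loss: the adversary shifts the input by delta to make (theta - delta)^2 large.
    return (theta - delta) ** 2


def grad_delta(theta, delta):
    return -2.0 * (theta - delta)


def grad_theta(theta, delta):
    return 2.0 * (theta - delta)


def clip(x, lo, hi):
    return max(lo, min(hi, x))


def sign(x):
    return (x > 0) - (x < 0)


def multi_step_attack(theta, steps=30, alpha=EPS / 10):
    """Multi-step attack: repeated signed-gradient ascent on delta, starting from 0."""
    delta = 0.0
    for _ in range(steps):
        delta = clip(delta + alpha * sign(grad_delta(theta, delta)), -EPS, EPS)
    return delta


def fast_attack(theta, rng, alpha=1.25 * EPS):
    """Fast-AT-style attack: a single signed-gradient step from a random start."""
    delta = rng.uniform(-EPS, EPS)
    return clip(delta + alpha * sign(grad_delta(theta, delta)), -EPS, EPS)


def adversarial_train(attack, theta=2.0, outer_steps=200):
    """Alternate inner maximization (attack) and outer minimization (gradient step on theta)."""
    for _ in range(outer_steps):
        delta = attack(theta)
        theta -= LR * grad_theta(theta, delta)
    return theta


def robust_loss(theta):
    # The loss is convex in delta, so its maximum over [-EPS, EPS] sits at an endpoint.
    return max(loss(theta, -EPS), loss(theta, EPS))


if __name__ == "__main__":
    rng = random.Random(0)
    theta_multi = adversarial_train(multi_step_attack)
    theta_fast = adversarial_train(lambda t: fast_attack(t, rng))
    print(f"multi-step: theta={theta_multi:.4f}, robust loss={robust_loss(theta_multi):.4f}")
    print(f"fast AT:    theta={theta_fast:.4f}, robust loss={robust_loss(theta_fast):.4f}")
```

Both variants drive the worst-case loss toward its minimum of EPS² at θ = 0; the fast variant trades attack strength per update (its random start sometimes points the wrong way) for a much cheaper inner loop.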

Step 2: Attack Evaluation [Robustness Evaluation]

Follow the instructions in UnlearnDiffAtk to attack DMs equipped with the AdvUnlearn text encoder for robustness evaluation.

Step 3: Image Generation Quality Evaluation [Model Utility Evaluation]

Generate 10k images for FID & CLIP evaluation

bash jobs/fid_10k_generate.sh

Calculate FID & CLIP scores using T2IBenchmark

bash jobs/tri_quality_eval.sh
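For reference, FID compares Gaussian fits to Inception features of real and generated images. The standard definition is below, where μ_r, Σ_r and μ_g, Σ_g are the feature mean and covariance of the real set (here COCO-10k) and of the generated set; lower is better. We have not checked how T2IBenchmark computes it internally (for example, which Inception weights or image resizing it uses), so treat this as the textbook formula rather than the script's exact implementation.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```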

Checkpoints

All checkpoints for the different DM unlearning tasks can be found here.

DM unlearning methods (tasks: Nudity, Van Gogh, Objects):

  • ESD (Erased Stable Diffusion)
  • FMN (Forget-Me-Not)
  • AC (Ablating Concepts)
  • UCE (Unified Concept Editing)
  • SalUn (Saliency Unlearning)
  • SH (ScissorHands)
  • ED (EraseDiff)
  • SPM (concept-SemiPermeable Membrane)
  • AdvUnlearn (Ours)

Cite Our Work

The preprint can be cited as follows:

@article{zhang2024defensive,
  title={Defensive Unlearning with Adversarial Training for Robust Concept Erasure in Diffusion Models},
  author={Zhang, Yimeng and Chen, Xin and Jia, Jinghan and Zhang, Yihua and Fan, Chongyu and Liu, Jiancheng and Hong, Mingyi and Ding, Ke and Liu, Sijia},
  journal={arXiv preprint arXiv:2405.15234},
  year={2024}
}