VDPG: Adapting to Distribution Shift by Visual Domain Prompt Generation (ICLR 2024)

Paper / Project

💡 Abstract

In this paper, we aim to adapt a model at test time using a few unlabeled data to address distribution shifts. To tackle the challenges of extracting domain knowledge from a limited amount of data, it is crucial to utilize correlated information from pre-trained backbones and source domains. Previous studies fail to utilize recent foundation models with strong out-of-distribution generalization. Additionally, domain-centric designs are not favored in their works. Furthermore, they separate the process of modelling source domains and the process of learning to adapt into disjoint training stages. In this work, we propose an approach on top of the pre-computed features of the foundation model. Specifically, we build a knowledge bank to learn the transferable knowledge from source domains. Conditioned on few-shot target data, we introduce a domain prompt generator to condense the knowledge bank into a domain-specific prompt. The domain prompt then directs the visual features towards a particular domain via a guidance module. Moreover, we propose a domain-aware contrastive loss and employ meta-learning to facilitate domain knowledge extraction. Extensive experiments are conducted to validate the domain knowledge extraction. The proposed method outperforms previous work on 5 large-scale benchmarks including WILDS and DomainNet.
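To make the pipeline above concrete, below is a minimal, hypothetical PyTorch sketch of the prompt-generation step: a learnable knowledge bank is condensed into a domain-specific prompt, conditioned on few-shot target features. All module names, shapes, and the use of cross-attention here are illustrative assumptions, not the authors' implementation.

```python
import torch
import torch.nn as nn

class DomainPromptGenerator(nn.Module):
    """Hypothetical sketch: condense a learnable knowledge bank into a
    domain-specific prompt, conditioned on few-shot target features."""
    def __init__(self, num_entries=64, dim=512):
        super().__init__()
        # Knowledge bank: learnable vectors holding transferable source knowledge.
        self.knowledge_bank = nn.Parameter(torch.randn(num_entries, dim))
        self.attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)

    def forward(self, target_feats):  # (B, N, dim) few-shot target features
        bank = self.knowledge_bank.unsqueeze(0).expand(target_feats.size(0), -1, -1)
        # Cross-attention: target features query the bank, yielding a domain prompt.
        prompt, _ = self.attn(query=target_feats, key=bank, value=bank)
        return prompt.mean(dim=1, keepdim=True)  # (B, 1, dim) condensed domain prompt

# Hypothetical usage: 16 few-shot target tokens of width 512.
gen = DomainPromptGenerator()
prompt = gen(torch.randn(2, 16, 512))  # -> (2, 1, 512)
```

At test time, the generated prompt would then steer the frozen visual features toward the target domain via the guidance module described above.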

💡 News

  • [2024/05] Paper is on arXiv.
  • [2024/05] VDPG code is released.
  • [2024/01] VDPG has been accepted by ICLR 2024.

Installation

Requirements

pip install -r torch21_cu118_py39.yml

Download CLIP's pretrained weights

Download the OpenAI pretrained weights from open_clip and save them to

./modelzoo/openai_clip
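If you prefer to fetch the weights from a script, a minimal sketch using open_clip is shown below (assuming the ViT-B-16 / openai checkpoint, matching the ViT_B16 configs used later):

```python
import open_clip

# Downloads the OpenAI CLIP ViT-B/16 weights into ./modelzoo/openai_clip
# on first use, and loads the model from the cache afterwards.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-16", pretrained="openai", cache_dir="./modelzoo/openai_clip"
)
```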

Download datasets

WILDS

Please follow the download instructions provided by the WILDS benchmark to download iWildCam, Camelyon17, FMoW and PovertyMap into the ./data folder.
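Alternatively, the wilds package can download the datasets programmatically. A minimal sketch (dataset names follow the WILDS registry, where PovertyMap is registered as poverty):

```python
from wilds import get_dataset

# Download the four WILDS benchmarks used in the paper into ./data.
for name in ["iwildcam", "camelyon17", "fmow", "poverty"]:
    get_dataset(dataset=name, root_dir="./data", download=True)
```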

DomainNet

  1. Download the official DomainNet benchmark to ./data
  2. Generate a metadata.csv for DomainNet by running the WILDS preprocessing script (a sketch of the resulting file is shown after the note below).

Note that we evaluate on the official test split of DomainNet instead of the random split of the target domain as in the DomainBed codebase.
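For reference, the sketch below shows one way such a metadata.csv can be assembled from DomainNet's official split lists (<domain>_train.txt / <domain>_test.txt, one "<path> <label>" pair per line). The column layout and the ./data/domainnet root are assumptions chosen to mirror WILDS-style metadata, not necessarily the exact schema produced by the preprocessing script.

```python
import csv
from pathlib import Path

DOMAINS = ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"]
root = Path("./data/domainnet")  # assumed dataset root

# Hypothetical WILDS-style metadata: one row per image with its domain,
# official split, relative image path, and integer class label.
with open(root / "metadata.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["domain", "split", "path", "y"])
    for domain in DOMAINS:
        for split in ["train", "test"]:
            # Official split lists, e.g. sketch_train.txt.
            for line in (root / f"{domain}_{split}.txt").read_text().splitlines():
                path, label = line.rsplit(" ", 1)
                writer.writerow([domain, split, path, int(label)])
```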

Project Structure

Reference: lightning-hydra-template

├── configs                   <- Hydra configs
│   ├── callbacks                <- Callbacks configs
│   ├── data                     <- Data configs
│   ├── experiment               <- Experiment configs
│   ├── extras                   <- Extra utilities configs
│   ├── logger                   <- Logger configs
│   ├── model                    <- Model configs
│   ├── paths                    <- Project paths configs
│   ├── trainer                  <- Trainer configs
│   ├── eval.yaml             <- Main config for evaluation
│   └── train.yaml            <- Main config for training
│
├── logs                   <- Logs generated by hydra and lightning loggers
├── modelzoo               <- Pretrained model weights
├── data                   <- Downloaded datasets
├── src                    <- Source code
│   ├── datasets                <- Benchmark datasets
│   ├── models                  <- Model components
│   ├── lightning               <- LightningModule, DataModule, Callbacks
│   ├── solver                  <- Losses, solvers, schedulers
│   └── utils                   <- Utility modules
├── train.py                    <- Run training
├── eval.py                     <- Run evaluation
│
├── .env                      <- Example of file for storing private environment variables
├── .project-root             <- File for inferring the position of project root directory
├── requirements.yml          <- File for installing pip environment
└── README.md

Evaluation

Download pretrained checkpoints

Download the checkpoints from OneDrive and save them in the ./modelzoo folder.

DomainNet

python eval.py model=vdpg_ViT_B16_CLIP.yaml paths.data_dir="./data" data=<data_name> ckpt_path=./modelzoo/<ckpt_name>

For example, to run the evaluation using DomainNet's sketch domain as the out-of-distribution domain:

python eval.py model=vdpg_ViT_B16_CLIP.yaml paths.data_dir="./data" data=domainnet_sketch_contrastive.yaml ckpt_path=./modelzoo/domainnet_sketch.ckpt

WILDS

iWildCam:

python eval.py model.model.num_prompts=100 paths.data_dir="./data" data=iwild_contrastive ckpt_path=./modelzoo/iWildCam.ckpt

Camelyon:

python eval.py model.model.num_prompts=5 paths.data_dir="./data" data=camelyon17_contrastive ckpt_path=./modelzoo/Camelyon.ckpt

FMoW:

python eval.py model.model.num_prompts=5 paths.data_dir="./data" data=fmow_contrastive ckpt_path=./modelzoo/FMoW.ckpt

Reproduced results using pretrained checkpoints

Dataset               OOD Top-1 Acc   Others
iWildCam              0.7898          OOD F1-score: 0.4678
FMoW                  0.6235          OOD WR acc: 0.4689
Camelyon17            0.9604          -
DomainNet clipart     0.7643          -
DomainNet infograph   0.4923          -
DomainNet painting    0.6792          -
DomainNet quickdraw   0.1756          -
DomainNet real        0.8182          -
DomainNet sketch      0.6681          -

Train (To be added)

Citation

If you use this code in your research, please consider citing our papers:

@inproceedings{chi_2024_ICLR,
  title={Adapting to Distribution Shift by Visual Domain Prompt Generation},
  author={Zhixiang Chi and Li Gu and Tao Zhong and Huan Liu and Yuanhao Yu and Konstantinos N Plataniotis and Yang Wang},
  booktitle={Proceedings of the Twelfth International Conference on Learning Representations},
  year={2024}
}

@inproceedings{zhong2022metadmoe,
  title={Meta-{DM}oE: Adapting to Domain Shift by Meta-Distillation from Mixture-of-Experts},
  author={Tao Zhong and Zhixiang Chi and Li Gu and Yang Wang and Yuanhao Yu and Jin Tang},
  booktitle={Advances in Neural Information Processing Systems},
  editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
  year={2022},
  url={https://openreview.net/forum?id=_ekGcr07Dsp}
}