
[ICCV 2023] Generative Prompt Model for Weakly Supervised Object Localization

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

Generative Prompt Model for Weakly Supervised Object Localization

This is the official implementation of the paper Generative Prompt Model for Weakly Supervised Object Localization, accepted at ICCV 2023. This repository contains PyTorch training code, evaluation code, pre-trained models, and visualization methods.

arXiv preprint | Python 3.8 | PyTorch 1.11 | LICENSE


1. Contents

2. Introduction

Weakly supervised object localization (WSOL) remains challenging when object localization models must be learned from image category labels alone. Conventional methods that train activation models discriminatively ignore object parts that are representative yet less discriminative. In this study, we propose a generative prompt model (GenPromp), the first generative pipeline for localizing less discriminative object parts, which formulates WSOL as a conditional image denoising procedure. During training, GenPromp converts image category labels into learnable prompt embeddings, which are fed to a generative model that conditionally recovers the noise-corrupted input image, so that representative embeddings are learned. During inference, GenPromp combines the representative embeddings with discriminative embeddings (queried from an off-the-shelf vision-language model) to obtain both representative and discriminative capacity. The combined embeddings are finally used to generate multi-scale, high-quality attention maps, which facilitate localizing the full object extent. Experiments on CUB-200-2011 and ILSVRC show that GenPromp outperforms the best discriminative models on both datasets, setting a solid baseline for WSOL with generative models.
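The inference step above can be illustrated with a minimal pure-Python sketch. It assumes (1) that the representative embedding f_r and the discriminative embedding f_d are mixed as a convex combination weighted by the combine_ratio option from Section 4.5, with the ratio applied to f_r, and (2) that the multi-scale attention maps are fused by nearest-neighbour upsampling to a common size, averaging, and min-max normalization. Both are assumptions for illustration only; the function names are hypothetical, and the code under models/ may weight or resize differently.

```python
from typing import List, Sequence

Grid = List[List[float]]


def combine_embeddings(f_r: Sequence[float], f_d: Sequence[float], ratio: float) -> List[float]:
    """Hypothetical mix: ratio * f_r + (1 - ratio) * f_d (assumed form, not the repo's code)."""
    if not 0.0 <= ratio <= 1.0:
        raise ValueError("ratio must lie in [0, 1]")
    if len(f_r) != len(f_d):
        raise ValueError("f_r and f_d must have the same length")
    return [ratio * r + (1.0 - ratio) * d for r, d in zip(f_r, f_d)]


def upsample_nearest(m: Sequence[Sequence[float]], size: int) -> Grid:
    """Nearest-neighbour resize of an h x w map to size x size."""
    h, w = len(m), len(m[0])
    return [[m[i * h // size][j * w // size] for j in range(size)] for i in range(size)]


def fuse_multiscale(maps: Sequence[Sequence[Sequence[float]]], size: int) -> Grid:
    """Average attention maps of different resolutions, then min-max normalize to [0, 1]."""
    ups = [upsample_nearest(m, size) for m in maps]
    avg = [[sum(u[i][j] for u in ups) / len(ups) for j in range(size)] for i in range(size)]
    lo = min(min(row) for row in avg)
    hi = max(max(row) for row in avg)
    if hi == lo:  # flat map: nothing to localize
        return [[0.0] * size for _ in range(size)]
    return [[(v - lo) / (hi - lo) for v in row] for row in avg]


if __name__ == "__main__":
    print(combine_embeddings([1.0, 0.0], [0.0, 1.0], 0.25))  # [0.25, 0.75]
    coarse = [[0.0, 1.0], [2.0, 3.0]]     # e.g. a low-resolution attention map
    fine = [[0.0] * 4 for _ in range(4)]  # a higher-resolution attention map
    for row in fuse_multiscale([coarse, fine], 4):
        print([round(v, 2) for v in row])
```

Under this assumed form, combine_ratio = 0 would use f_d alone; check the repository before relying on that interpretation.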

3. Results

We re-trained GenPromp with a better learning schedule on 4 × A100 GPUs, which further improves its performance on CUB-200-2011.

Method              | Dataset      | Cls. Backbone   | Top-1 Loc (%) | Top-5 Loc (%) | GT-known Loc (%)
GenPromp            | CUB-200-2011 | EfficientNet-B7 | 87.0          | 96.1          | 98.0
GenPromp (Re-train) | CUB-200-2011 | EfficientNet-B7 | 87.2 (+0.2)   | 96.3 (+0.2)   | 98.3 (+0.3)
GenPromp            | ImageNet     | EfficientNet-B7 | 65.2          | 73.4          | 75.0

4. Getting Started

4.1 Installation

We use conda to manage the dependencies of GenPromp; our experiments were run with CUDA 11.3. Run the following commands to set up the environment and install GenPromp:

conda create -n gpm python=3.8 -y && conda activate gpm
pip install --upgrade pip
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install --upgrade "diffusers[torch]==0.13.1"
pip install transformers==4.29.2 accelerate==0.19.0
pip install matplotlib opencv-python OmegaConf tqdm
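After installation, a quick way to confirm that the pinned versions were actually picked up is a short check with the standard-library importlib.metadata module (Python 3.8+). The PINNED table simply copies the versions from the pip commands above, and check_versions is a hypothetical helper, not part of the repository.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Versions copied from the pip commands above; local tags such as "+cu113" are ignored.
PINNED = {
    "torch": "1.11.0",
    "torchvision": "0.12.0",
    "torchaudio": "0.11.0",
    "diffusers": "0.13.1",
    "transformers": "4.29.2",
    "accelerate": "0.19.0",
}


def check_versions(pins: Dict[str, str]) -> Dict[str, Tuple[Optional[str], str]]:
    """Return {package: (installed_or_None, expected)} for every pin that does not match."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            mismatches[name] = (None, expected)  # not installed at all
            continue
        # Drop a local version label such as "+cu113" before comparing.
        if installed.split("+")[0] != expected:
            mismatches[name] = (installed, expected)
    return mismatches


if __name__ == "__main__":
    problems = check_versions(PINNED)
    if not problems:
        print("All pinned packages match.")
    for name, (installed, expected) in problems.items():
        print(f"{name}: installed={installed}, expected={expected}")
```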

4.2 Dataset and Files Preparation

To train GenPromp from the pre-trained weights, or to run inference with the released weights, download the files listed in the table below and arrange them according to the file tree. (Uploading)

Dataset & Files                   | Download                                         | Usage
data/ImageNet_ILSVRC2012 (146GB)  | Official Link                                    | Benchmark dataset
data/CUB_200_2011 (1.2GB)         | Official Link                                    | Benchmark dataset
ckpts/pretrains (5.2GB)           | Official Link, Google Drive, Baidu Drive (o9ei)  | Stable Diffusion pre-trained weights
ckpts/classifications (2.3GB)     | Google Drive, Baidu Drive (o9ei)                 | Classification results on the benchmark datasets
ckpts/imagenet750 (3.3GB)         | Google Drive, Baidu Drive (o9ei)                 | Weights that achieve 75.0% GT-known Loc on ImageNet
ckpts/cub983 (3.3GB)              | Google Drive, Baidu Drive (o9ei)                 | Weights that achieve 98.3% GT-known Loc on CUB-200-2011
    |--GenPromp/
      |--data/
        |--ImageNet_ILSVRC2012/
          |--ILSVRC2012_list/
          |--train/
          |--val/
        |--CUB_200_2011/
          |--attributes/
          |--images/
          ...
      |--ckpts/
        |--pretrains/
          |--stable-diffusion-v1-4/
        |--classifications/
          |--cub_efficientnetb7.json
          |--imagenet_efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k.json
        |--imagenet750/
          |--tokens/
            |--49408.bin
            |--49409.bin
            ...
          |--unet/
        |--cub983/
          |--tokens/
            |--49408.bin
            |--49409.bin
            ...
          |--unet/
      |--configs/
      |--datasets/
      |--models/
      |--main.py
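Before training or testing, it can help to verify the layout programmatically. The sketch below is a hypothetical helper (not part of the repository) that lists which paths from the tree above are missing under a given root and reads the numeric ids of the token files. The token files start at 49408, which equals the vocabulary size of the CLIP tokenizer used by Stable Diffusion v1.4 (ids 0 to 49407), so each <id>.bin presumably stores the embedding of one newly added concept token; that reading is inferred from the file names only.

```python
from pathlib import Path
from typing import Iterable, List, Union

PathLike = Union[str, Path]

# Relative paths copied from the file tree above (directories and files).
EXPECTED_PATHS = [
    "data/ImageNet_ILSVRC2012/ILSVRC2012_list",
    "data/ImageNet_ILSVRC2012/train",
    "data/ImageNet_ILSVRC2012/val",
    "data/CUB_200_2011/attributes",
    "data/CUB_200_2011/images",
    "ckpts/pretrains/stable-diffusion-v1-4",
    "ckpts/classifications/cub_efficientnetb7.json",
    "ckpts/classifications/imagenet_efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k.json",
    "ckpts/imagenet750/tokens",
    "ckpts/imagenet750/unet",
    "ckpts/cub983/tokens",
    "ckpts/cub983/unet",
]


def find_missing(root: PathLike, rel_paths: Iterable[str]) -> List[str]:
    """Return the relative paths that do not exist under root."""
    base = Path(root)
    return [p for p in rel_paths if not (base / p).exists()]


def list_token_ids(tokens_dir: PathLike) -> List[int]:
    """Sorted integer ids of the <id>.bin files in a tokens/ directory ([] if absent)."""
    d = Path(tokens_dir)
    if not d.is_dir():
        return []
    return sorted(int(p.stem) for p in d.glob("*.bin") if p.stem.isdigit())


if __name__ == "__main__":
    missing = find_missing(".", EXPECTED_PATHS)  # run from the GenPromp/ directory
    print("missing:", missing if missing else "none")
    for ckpt in ("ckpts/imagenet750", "ckpts/cub983"):
        ids = list_token_ids(Path(ckpt) / "tokens")
        if ids:
            print(f"{ckpt}: {len(ids)} token files, ids {ids[0]}..{ids[-1]}")
```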

4.3 Training

Here is a training example of GenPromp on ImageNet.

accelerate config
accelerate launch main.py --function train_token --config configs/imagenet.yml --opt "{'train': {'save_path': 'ckpts/imagenet/'}}"
accelerate launch main.py --function train_unet --config configs/imagenet_stage2.yml --opt "{'train': {'load_token_path': 'ckpts/imagenet/tokens/', 'save_path': 'ckpts/imagenet/'}}"

accelerate is used for multi-GPU training; run accelerate config once to describe your hardware before launching. In the first training stage, the weights of the concept tokens of the representative embeddings are learned and saved to ckpts/imagenet/. In the second stage, the learned concept tokens are loaded from ckpts/imagenet/tokens/, and the UNet is then fine-tuned and its weights saved to ckpts/imagenet/. Other settings can be found in the config files (configs/imagenet.yml and configs/imagenet_stage2.yml) and can be overridden with --opt and a parameter dict (see 4.5 Extra Options).
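The two stages depend on each other only through the file system: stage 2 reads the tokens that stage 1 wrote to <save_path>tokens/. A small hypothetical helper (not part of the repository) that builds both commands from one output directory keeps those paths consistent. It assumes the config naming used above (configs/<dataset>.yml and configs/<dataset>_stage2.yml) and only prints the commands; passing each argv list to subprocess.run would launch it.

```python
from typing import List


def two_stage_commands(dataset: str, out_dir: str) -> List[List[str]]:
    """Return argv lists for stage 1 (train_token) and stage 2 (train_unet)."""
    out_dir = out_dir.rstrip("/") + "/"
    stage1_opt = {"train": {"save_path": out_dir}}
    # Stage 2 must load the tokens that stage 1 saved under <out_dir>tokens/.
    stage2_opt = {"train": {"load_token_path": out_dir + "tokens/", "save_path": out_dir}}
    return [
        ["accelerate", "launch", "main.py", "--function", "train_token",
         "--config", f"configs/{dataset}.yml", "--opt", repr(stage1_opt)],
        ["accelerate", "launch", "main.py", "--function", "train_unet",
         "--config", f"configs/{dataset}_stage2.yml", "--opt", repr(stage2_opt)],
    ]


if __name__ == "__main__":
    for argv in two_stage_commands("imagenet", "ckpts/imagenet/"):
        # Wrap the --opt dict in double quotes, as in the commands above.
        print(" ".join(argv[:-1]) + ' "' + argv[-1] + '"')
        # To launch instead of printing: subprocess.run(argv, check=True)
```

repr() of a dict yields exactly the single-quoted form used in this README, so the printed commands match the ones above.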

Here is a training example of GenPromp on CUB_200_2011.

accelerate config
accelerate launch main.py --function train_token --config configs/cub.yml --opt "{'train': {'save_path': 'ckpts/cub/'}}"
accelerate launch main.py --function train_unet --config configs/cub_stage2.yml --opt "{'train': {'load_token_path': 'ckpts/cub/tokens/', 'save_path': 'ckpts/cub/'}}"

4.4 Inference

Here is an inference example of GenPromp on ImageNet.

python main.py --function test --config configs/imagenet_stage2.yml --opt "{'test': {'load_token_path': 'ckpts/imagenet750/tokens/', 'load_unet_path': 'ckpts/imagenet750/unet/', 'save_log_path': 'ckpts/imagenet750/log.txt'}}"

During inference, the learned concept tokens are loaded from ckpts/imagenet750/tokens/, the fine-tuned UNet is loaded from ckpts/imagenet750/unet/, and the log file is saved to ckpts/imagenet750/log.txt. Because random noise is added to each test image, the results may fluctuate within a small range (about ±0.1).
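If you need repeatable numbers, one option is to fix the random seeds at the start of the test run. The sketch below is a generic helper, not part of the repository, and main.py is not documented as exposing a seed option, so you would have to call it yourself; it removes the fluctuation only if all of the added noise is drawn from the seeded global generators.

```python
import random


def seed_everything(seed: int) -> None:
    """Seed Python's, NumPy's and PyTorch's global RNGs (the latter two only if installed)."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)           # global PyTorch generator
        torch.cuda.manual_seed_all(seed)  # all GPUs; silently ignored without CUDA
    except ImportError:
        pass


if __name__ == "__main__":
    seed_everything(0)
    first = random.random()
    seed_everything(0)
    assert random.random() == first  # same seed, same draw
    print("seeding is repeatable")
```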

Here is an inference example of GenPromp on CUB_200_2011.

python main.py --function test --config configs/cub_stage2.yml --opt "{'test': {'load_token_path': 'ckpts/cub983/tokens/', 'load_unet_path': 'ckpts/cub983/unet/', 'save_log_path': 'ckpts/cub983/log.txt'}}"

4.5 Extra Options

Many extra options are available for training and inference. Their default values are set in the yml config file, and --opt adds or overrides them with a parameter dict. The most commonly used options are listed below.

Option | Scope | Usage
{'data': {'keep_class': [0, 9]}} | data | Keep only the categories whose ids lie in [0, 9], i.e. 0, 1, 2, ..., 9.
{'train': {'batch_size': 2}} | train | Train with a batch size of 2.
{'train': {'num_train_epochs': 1}} | train | Train the model for 1 epoch.
{'train': {'save_steps': 200}} | train_unet | Save the trained UNet every 200 steps.
{'train': {'max_train_steps': 600}} | train_unet | Stop training after at most 600 steps.
{'train': {'gradient_accumulation_steps': 2}} | train | Accumulate gradients over 2 steps, doubling the effective batch size when GPU memory is limited.
{'train': {'learning_rate': 5.0e-08}} | train | Set the learning rate to 5.0e-8.
{'train': {'scale_lr': True}} | train | If True, the learning rate is multiplied by the batch size.
{'train': {'load_pretrain_path': 'stable-diffusion/'}} | train | Load the pre-trained model from stable-diffusion/.
{'train': {'load_token_path': 'ckpt/tokens/'}} | train | Load the trained concept tokens from ckpt/tokens/.
{'train': {'save_path': 'ckpt/'}} | train | Save the trained weights to ckpt/.
{'test': {'batch_size': 2}} | test | Test with a batch size of 2.
{'test': {'cam_thr': 0.25}} | test | Test with a CAM threshold of 0.25.
{'test': {'combine_ratio': 0.6}} | test | Set the combine ratio between $f_r$ (representative) and $f_d$ (discriminative embeddings) to 0.6.
{'test': {'load_class_path': 'imagenet_efficientnet.json'}} | test | Load the classification results from imagenet_efficientnet.json.
{'test': {'load_pretrain_path': 'stable-diffusion/'}} | test | Load the pre-trained model from stable-diffusion/.
{'test': {'load_token_path': 'ckpt/tokens/'}} | test | Load the trained concept tokens from ckpt/tokens/.
{'test': {'load_unet_path': 'ckpt/unet/'}} | test | Load the trained UNet from ckpt/unet/.
{'test': {'save_vis_path': 'ckpt/vis/'}} | test | Save the visualized predictions to ckpt/vis/.
{'test': {'save_log_path': 'ckpt/log.txt'}} | test | Save the log file to ckpt/log.txt.
{'test': {'eval_mode': 'top1'}} | test | top1 (default): evaluate with the predicted top-1 category of each test image; top5: evaluate with the predicted top-5 categories; gtk: evaluate with the ground-truth category, which does not require classification results.
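To make the three eval_mode settings concrete, here is a simplified sketch of per-image scoring. It assumes the usual WSOL criterion: a predicted box counts as localized when its IoU with the ground-truth box is at least 0.5; top1 additionally requires the top-1 predicted category to be correct, top5 requires the ground-truth category to be among the top-5 predictions, and gtk uses the ground-truth category only. The box is taken as the tightest box around all pixels of a [0, 1]-normalized attention map that reach cam_thr. This is a simplification (many WSOL implementations keep only the largest connected region), and all function names are hypothetical rather than the repository's API.

```python
from typing import Optional, Sequence, Tuple

Box = Tuple[int, int, int, int]  # (x0, y0, x1, y1), inclusive pixel coordinates


def box_from_map(attn: Sequence[Sequence[float]], cam_thr: float) -> Optional[Box]:
    """Tightest box around every pixel whose normalized attention is >= cam_thr."""
    pts = [(x, y) for y, row in enumerate(attn) for x, v in enumerate(row) if v >= cam_thr]
    if not pts:
        return None
    xs, ys = zip(*pts)
    return (min(xs), min(ys), max(xs), max(ys))


def box_area(b: Box) -> int:
    return (b[2] - b[0] + 1) * (b[3] - b[1] + 1)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two inclusive pixel boxes."""
    w = min(a[2], b[2]) - max(a[0], b[0]) + 1
    h = min(a[3], b[3]) - max(a[1], b[1]) + 1
    inter = max(0, w) * max(0, h)
    return inter / (box_area(a) + box_area(b) - inter)


def is_correct(pred_box: Optional[Box], gt_box: Box, gt_label: int,
               topk_labels: Sequence[int], eval_mode: str = "top1",
               iou_thr: float = 0.5) -> bool:
    """Per-image localization correctness for eval_mode in {'top1', 'top5', 'gtk'}."""
    if eval_mode not in ("top1", "top5", "gtk"):
        raise ValueError(f"unknown eval_mode: {eval_mode}")
    loc_ok = pred_box is not None and iou(pred_box, gt_box) >= iou_thr
    if eval_mode == "gtk":
        return loc_ok  # no classification results needed
    if eval_mode == "top1":
        return loc_ok and topk_labels[0] == gt_label
    return loc_ok and gt_label in topk_labels[:5]


if __name__ == "__main__":
    attn = [[0.0, 0.1, 0.0],
            [0.0, 0.9, 0.8],
            [0.0, 0.3, 0.0]]
    box = box_from_map(attn, cam_thr=0.25)
    print(box)  # (1, 1, 2, 2)
    print(is_correct(box, (1, 1, 2, 2), gt_label=3,
                     topk_labels=[7, 3, 1, 4, 5], eval_mode="top5"))  # True
```

Here pred_box should be the box localized for the ground-truth category; whenever the class condition holds, it coincides with the box for the predicted category. Averaging is_correct over the test set and multiplying by 100 gives a percentage in the style of the results table.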

These options can be combined by simply merging the dicts. For example, to evaluate GenPromp with the config file configs/imagenet_stage2.yml on categories [0, 1, 2, ..., 9], loading the concept tokens from ckpts/imagenet750/tokens/ and the UNet from ckpts/imagenet750/unet/, saving the log file of the evaluated metrics to ckpts/imagenet750/log0-9.txt, setting the combine ratio to 0, and saving the visualization results to ckpts/imagenet750/vis, use the following command:

python main.py --function test --config configs/imagenet_stage2.yml --opt "{'data': {'keep_class': [0, 9]}, 'test': {'load_token_path': 'ckpts/imagenet750/tokens/', 'load_unet_path': 'ckpts/imagenet750/unet/', 'save_log_path': 'ckpts/imagenet750/log0-9.txt', 'combine_ratio': 0, 'save_vis_path': 'ckpts/imagenet750/vis'}}"
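The merge behaviour described above can be sketched as follows. The --opt strings in this README use single quotes, so they are Python dict literals rather than JSON; ast.literal_eval parses them without executing code. deep_merge is a hypothetical helper that adds or replaces nested keys while keeping the other defaults. The repository installs OmegaConf, whose OmegaConf.merge behaves similarly, but how main.py actually applies --opt is an assumption here, as are the placeholder default values.

```python
import ast
import copy
from typing import Any, Dict


def parse_opt(opt: str) -> Dict[str, Any]:
    """Parse an --opt string such as "{'test': {'cam_thr': 0.25}}" into a dict."""
    value = ast.literal_eval(opt)
    if not isinstance(value, dict):
        raise ValueError("--opt must be a dict literal")
    return value


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of base in which nested keys from override are added or replaced."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)  # recurse into sections
        else:
            merged[key] = copy.deepcopy(value)  # leaf value: the override wins
    return merged


if __name__ == "__main__":
    # Placeholder defaults standing in for configs/imagenet_stage2.yml (hypothetical values).
    defaults = {"data": {}, "test": {"batch_size": 1, "combine_ratio": 0.6, "eval_mode": "top1"}}
    opt = ("{'data': {'keep_class': [0, 9]}, "
           "'test': {'combine_ratio': 0, 'save_vis_path': 'ckpts/imagenet750/vis'}}")
    print(deep_merge(defaults, parse_opt(opt)))
```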

5. Contacts

If you have any questions about our work or this repository, please don't hesitate to contact us by email or open an issue in this repository.

6. Acknowledgment

7. Citation

@article{zhao2023generative,
  title={Generative Prompt Model for Weakly Supervised Object Localization},
  author={Zhao, Yuzhong and Ye, Qixiang and Wu, Weijia and Shen, Chunhua and Wan, Fang},
  journal={arXiv preprint arXiv:2307.09756},
  year={2023}
}
@InProceedings{Zhao_2023_ICCV,
    author    = {Zhao, Yuzhong and Ye, Qixiang and Wu, Weijia and Shen, Chunhua and Wan, Fang},
    title     = {Generative Prompt Model for Weakly Supervised Object Localization},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
    pages     = {6351-6361}
}