DIaM for Generalized Few-Shot Semantic Segmentation

This repository contains the code for our CVPR 2023 paper, A Strong Baseline for Generalized Few-Shot Semantic Segmentation.

Abstract: This paper introduces a generalized few-shot segmentation framework with a straightforward training process and an easy-to-optimize inference phase. In particular, we propose a simple yet effective model based on the well-known InfoMax principle, where the Mutual Information (MI) between the learned feature representations and their corresponding predictions is maximized. In addition, the terms derived from our MI-based formulation are coupled with a knowledge distillation term to retain the knowledge on base classes. With a simple training process, our inference model can be applied on top of any segmentation network trained on base classes. The proposed inference yields substantial improvements on the popular few-shot segmentation benchmarks PASCAL-5i and COCO-20i. Particularly, for novel classes, the improvement gains range from 7% to 26% (PASCAL-5i) and from 3% to 12% (COCO-20i) in the 1-shot and 5-shot scenarios, respectively. Furthermore, we propose a more challenging setting, where performance gaps are further exacerbated.
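For intuition, the MI between features and predictions is typically maximized through an entropy-based surrogate, H(predictions) - H(predictions | features), alongside a distillation term that keeps base-class predictions close to the pre-trained model's. Below is a minimal, illustrative PyTorch sketch of such an objective; the function name, the assumption that the first columns of the predictions correspond to base classes, and the weighting are ours, not the repository's actual implementation (see src/ for that):

import torch
import torch.nn.functional as F

def infomax_kd_loss(probs, base_probs_old, lam=1.0):
    """Illustrative sketch only.
    probs:          (N, C) soft predictions over N pixels; rows sum to 1.
    base_probs_old: (N, C_base) frozen base-model predictions, used for distillation.
    """
    eps = 1e-8
    # Conditional entropy H(Y|X): pushes each pixel toward a confident prediction.
    cond_ent = -(probs * torch.log(probs + eps)).sum(dim=1).mean()
    # Marginal entropy H(Y): pushes batch-level class usage toward diversity.
    marginal = probs.mean(dim=0)
    marg_ent = -(marginal * torch.log(marginal + eps)).sum()
    # Maximizing MI ~ H(Y) - H(Y|X) amounts to minimizing cond_ent - marg_ent.
    mi_term = cond_ent - marg_ent
    # Knowledge distillation on base classes (assumed to be the first columns).
    new_base = probs[:, :base_probs_old.shape[1]]
    new_base = new_base / (new_base.sum(dim=1, keepdim=True) + eps)
    kd_term = F.kl_div(torch.log(new_base + eps), base_probs_old, reduction='batchmean')
    return mi_term + lam * kd_term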

🎬 Getting Started

1๏ธโƒฃ Requirements

We used Python 3.9 in our experiments, and the list of required packages is available in the requirements.txt file. You can install them with pip install -r requirements.txt.

2๏ธโƒฃ Download data

Pre-processed data from Google Drive

We provide the versions of PASCAL VOC 2012 and MS-COCO 2017 used in this work here. You can download the full .zip and extract it directly into the data/ folder.

From scratch

Alternatively, you can prepare the datasets yourself. The data/ folder should have the following structure:

data
├── coco
│   ├── annotations
│   ├── train
│   ├── train2014
│   ├── val
│   └── val2014
└── pascal
    ├── JPEGImages
    └── SegmentationClassAug

PASCAL: The JPEG images are part of the PASCAL VOC 2012 toolkit, which can be downloaded at PASCAL VOC 2012; the pre-processed ground-truth masks are available as SegmentationClassAug.

COCO: The COCO 2014 train images, validation images, and annotations can be downloaded at COCO. Once this is done, you need to generate the coco/train and coco/val subfolders (ground-truth masks). Both folders can be generated by running the Python script data/coco/create_masks.py (note that this script uses the pycocotools package; a simplified sketch of the conversion it performs follows the commands below):

cd data/coco
python create_masks.py
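For reference, here is a simplified sketch of what such an annotation-to-mask conversion does with pycocotools; it is illustrative only, not the repository's create_masks.py (the annotation path and output naming are hypothetical):

from pycocotools.coco import COCO
import numpy as np
from PIL import Image

coco = COCO('annotations/instances_train2014.json')  # hypothetical annotation path
for img_id in coco.getImgIds():
    info = coco.loadImgs(img_id)[0]
    # Start from an all-background mask, then paint each annotation's category id.
    mask = np.zeros((info['height'], info['width']), dtype=np.uint8)
    for ann in coco.loadAnns(coco.getAnnIds(imgIds=img_id)):
        mask[coco.annToMask(ann) == 1] = ann['category_id']
    Image.fromarray(mask).save(info['file_name'].replace('.jpg', '.png'))  # hypothetical naming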

About the train/val splits

The train/val splits are directly provided in lists/. How they were obtained is explained at https://github.com/Jia-Research-Lab/PFENet.

3๏ธโƒฃ Download pre-trained models

Pre-trained backbone and models

We provide the pre-trained backbone and models at https://drive.google.com/file/d/1WuKaJbj3Y3QMq4yw_Tyec-KyTchjSVUG/view?usp=share_link. Download the archive and extract it at the root of this repo; this will create two folders: initmodel/ and model_ckpt/.

🗺 Overview of the repo

Default configuration files can be found in config/. Data are located in data/. lists/ contains the train/val splits for each dataset. All the code is provided in src/. The testing script is located at the root of the repo.

⚙ Training (optional)

This step is optional if you use the provided pre-trained models. Our contribution lies in the inference phase, and our approach is modular, i.e., it can be applied on top of any segmentation model trained on the base classes. We use a simple training scheme that minimizes a standard cross-entropy over the base classes. To this end, we used the train_base.py script and the base learner models of BAM (see this issue for more info).
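As a rough illustration of what "minimizing a standard cross-entropy over the base classes" amounts to, here is a minimal, self-contained sketch; the one-layer stand-in "network", the class count, and the hyperparameters are placeholders, not the settings of train_base.py:

import torch
import torch.nn as nn

# Placeholder stand-ins: any segmentation network and base-class loader work here.
model = nn.Conv2d(3, 17, kernel_size=1)            # toy "network": 16 base classes + background
optimizer = torch.optim.SGD(model.parameters(), lr=2.5e-3, momentum=0.9, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=255)  # 255 commonly marks unlabeled pixels

images = torch.randn(2, 3, 64, 64)                 # dummy batch of base-class images
labels = torch.randint(0, 17, (2, 64, 64))         # dummy integer ground-truth masks

logits = model(images)                             # (B, C, H, W) class scores
loss = criterion(logits, labels)                   # standard pixel-wise cross-entropy
optimizer.zero_grad()
loss.backward()
optimizer.step()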

🧪 Testing

To test the model, use the test.sh script, whose general syntax is:

bash test.sh {benchmark} {shot} {pi_estimation_strategy} {[gpu_ids]} {log_path}

This script tests successively on all folds of the benchmark and reports the results individually. The overall performance is the average over all the folds. Some example commands are presented below, with their description in the comments.

bash test.sh pascal5i 1 self [0] out.log  # PASCAL-5i benchmark, 1-shot, pi estimated from the model's output
bash test.sh pascal10i 5 self [0] out.log  # PASCAL-10i benchmark, 5-shot, pi estimated from the model's output
bash test.sh coco20i 5 upperbound [0] out.log  # COCO-20i benchmark, 5-shot, the upper-bound model mentioned in the paper

If you run out of memory, reduce batch_size_val in the config files.
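For instance, assuming the config files are YAML (the exact key layout may differ across files), the entry to lower looks like:

batch_size_val: 10  # illustrative value; decrease until validation fits in GPU memory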

📊 Results

To reproduce the results, please first download the pre-trained models from here (also mentioned in the "download pre-trained models" section) and then run the test.sh script with different inputs, as explained above.

Benchmark    Fold        1-Shot                  5-Shot
                     Base   Novel   Mean     Base   Novel   Mean
PASCAL-5i    0       71.33  29.36   50.35    71.06  53.72   62.39
             1       69.54  46.72   58.13    69.63  63.33   66.48
             2       69.10  27.07   48.09    69.12  54.01   61.57
             3       73.60  37.30   55.45    73.60  50.19   61.90
             mean    70.89  35.11   53.00    70.85  55.31   63.08
COCO-20i     0       49.01  15.89   32.45    48.90  24.86   36.88
             1       46.83  19.50   33.17    47.10  33.94   40.52
             2       48.82  16.93   32.88    49.12  27.15   38.14
             3       48.45  16.57   32.51    48.37  28.95   38.66
             mean    48.28  17.22   32.75    48.37  28.73   38.55
PASCAL-10i   0       68.69  34.40   51.55    68.49  55.94   62.22
             1       71.83  28.17   50.00    72.00  47.84   59.92
             mean    70.26  31.29   50.77    70.25  51.89   61.07

🙏 Acknowledgments

We gratefully thank the authors of RePRI, BAM, PFENet, and PyTorch Semantic Segmentation, whose code inspired parts of ours.

📚 Citation

If you find this project useful, please consider citing:

@inproceedings{hajimiri2023diam,
  title={A Strong Baseline for Generalized Few-Shot Semantic Segmentation},
  author={Hajimiri, Sina and Boudiaf, Malik and Ben Ayed, Ismail and Dolz, Jose},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={11269--11278},
  year={2023}
}