
PEM: Prototype-based Efficient MaskFormer for Image Segmentation (CVPR 2024)

Niccolò Cavagnero*, Gabriele Rosi*, Claudia Cuttano, Francesca Pistilli, Marco Ciccone, Giuseppe Averta, Fabio Cermelli

* Equal Contribution

[Project Page] [Paper]

This is the official PyTorch implementation of our work "PEM: Prototype-based Efficient MaskFormer for Image Segmentation" accepted at CVPR 2024.


Prototype-based Efficient MaskFormer (PEM) is an efficient transformer-based architecture that can handle multiple segmentation tasks. PEM introduces a novel prototype-based cross-attention, which exploits the redundancy of visual features to restrict the computation and improve efficiency without harming performance.

Figure: PEM architecture.
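The idea behind prototype-based cross-attention can be illustrated with a toy, pure-Python comparison. This is a minimal sketch under simplifying assumptions, not the PEM implementation: the function names `standard_cross_attention` and `prototype_cross_attention` are hypothetical, the prototype is assumed to be the single pixel feature most similar to each query (argmax of the dot product), and learned projections, masking and normalization are omitted. Refer to the paper and the code in this repository for the actual layer.

```python
import math
from typing import List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Inner product of two feature vectors of equal length."""
    return sum(x * y for x, y in zip(a, b))


def standard_cross_attention(queries: List[Vector], pixels: List[Vector]) -> List[Vector]:
    """Dense cross-attention: each query aggregates ALL pixel features,
    weighted by a softmax over scaled dot-product scores (values = keys here)."""
    d = len(queries[0])
    updated = []
    for q in queries:
        scores = [dot(q, p) / math.sqrt(d) for p in pixels]
        top = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        context = [sum(w * p[i] for w, p in zip(weights, pixels)) / total for i in range(d)]
        updated.append([qi + ci for qi, ci in zip(q, context)])  # residual update
    return updated


def prototype_cross_attention(
    queries: List[Vector], pixels: List[Vector]
) -> Tuple[List[Vector], List[int]]:
    """Hypothetical simplified variant: each query interacts with a single
    'prototype', the pixel feature with the highest dot-product similarity,
    instead of a softmax-weighted sum over every pixel."""
    updated, chosen = [], []
    for q in queries:
        scores = [dot(q, p) for p in pixels]
        idx = max(range(len(pixels)), key=scores.__getitem__)  # most similar pixel
        updated.append([qi + pi for qi, pi in zip(q, pixels[idx])])  # residual update
        chosen.append(idx)
    return updated, chosen


queries = [[1.0, 0.0], [0.0, 1.0]]
pixels = [[0.75, 0.25], [0.25, 0.5], [0.5, 0.375]]
dense = standard_cross_attention(queries, pixels)
sparse, prototypes = prototype_cross_attention(queries, pixels)
print(prototypes)  # [0, 1]
print(sparse)      # [[1.75, 0.25], [0.25, 1.5]]
```

In this toy, both versions score every pixel, but the prototype version replaces the softmax-weighted aggregation over all pixels with a single gather per query, which is where the work is saved.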


Installation

The code has been tested with python>=3.8 and pytorch==1.12.0. To prepare the conda environment, run the following commands:

conda create --name pem python=3.10 -y
conda activate pem

conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'

git clone https://github.com/NiccoloCavagnero/PEM.git
cd PEM
pip install -r requirements.txt
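After installation, a quick sanity check can confirm that the pinned packages are importable. The following is a minimal sketch using only the standard library (`importlib.metadata`, available from Python 3.8); the function name `environment_report` and the package list are illustrative choices, not part of the repository.

```python
import importlib.metadata
import importlib.util
import sys
from typing import Dict, Optional, Tuple

PACKAGES: Tuple[str, ...] = ("torch", "torchvision", "detectron2")


def environment_report(packages: Tuple[str, ...] = PACKAGES) -> Dict[str, Optional[str]]:
    """Collect the Python version and the installed versions of key packages.
    A value of None means the package cannot be imported."""
    report: Dict[str, Optional[str]] = {
        "python": ".".join(str(v) for v in sys.version_info[:3]),
    }
    for name in packages:
        if importlib.util.find_spec(name) is None:
            report[name] = None
            continue
        try:
            report[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            # importable (e.g. from a source checkout) but without distribution metadata
            report[name] = "unknown"
    return report


if __name__ == "__main__":
    report = environment_report()
    for key, value in report.items():
        print(f"{key}: {value if value is not None else 'NOT INSTALLED'}")
    if report["torch"] is not None:
        import torch
        print("cuda available:", torch.cuda.is_available())
```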

Data preparation

For dataset preparation, please refer to the Mask2Former guide.
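Following the detectron2 conventions adopted by Mask2Former, datasets are looked up under the directory given by the `DETECTRON2_DATASETS` environment variable, which defaults to `./datasets`. A small check like the one below can catch a misplaced dataset before a long training run. It is a sketch: the function name is hypothetical and the directory list is an assumption covering only the top-level Cityscapes and ADE20K folders, not every file the guide's preparation scripts generate.

```python
import os
from pathlib import Path
from typing import Dict, List, Optional

# Assumed minimal layout, following the detectron2/Mask2Former dataset conventions.
# The Mask2Former guide lists additional files created by its preparation scripts.
EXPECTED_DIRS: Dict[str, List[str]] = {
    "cityscapes": ["cityscapes/gtFine", "cityscapes/leftImg8bit"],
    "ade20k": ["ADEChallengeData2016/images", "ADEChallengeData2016/annotations"],
}


def missing_dataset_dirs(root: Optional[str] = None) -> Dict[str, List[str]]:
    """Return, per dataset, the expected sub-directories that do not exist.
    The root defaults to $DETECTRON2_DATASETS, or ./datasets when it is unset."""
    base = Path(root if root is not None else os.environ.get("DETECTRON2_DATASETS", "datasets"))
    return {
        name: [rel for rel in rels if not (base / rel).is_dir()]
        for name, rels in EXPECTED_DIRS.items()
    }


if __name__ == "__main__":
    for dataset, missing in missing_dataset_dirs().items():
        status = "ok" if not missing else "missing: " + ", ".join(missing)
        print(f"{dataset}: {status}")
```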

Training

  1. Before starting the training, you need to download the pretrained weights for the backbone. The following commands download the pretrained weights for the STDC1 and STDC2 backbones and convert them to the detectron2 format (read more about them here). For ResNet50, the pretrained weights are downloaded automatically from the detectron2 repository.

    mkdir pretrained_models
    cd pretrained_models
    gdown 1DFoXcV42zy-apUcMh5P8WhsXMRJofgl8
    gdown 1Y5belNkq3Dn-EYgSKY-ICiPsN4TZXoXO
    python ../tools/convert-pretrained-stdc-model-to-d2.py STDCNet813M_73.91.tar STDC1.pkl
    python ../tools/convert-pretrained-stdc-model-to-d2.py STDCNet1446_76.47.tar STDC2.pkl
    cd ..
  2. To train the model with train_net.py, run the following command:

    python train_net.py --num-gpus 4 \
      --config-file configs/cityscapes/semantic-segmentation/pem_R50_bs32_90k.yaml 
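The conversion step in step 1 produces a `.pkl` file that detectron2's checkpointer can read. Below is a minimal sketch of the general idea, assuming the convention used by detectron2's own conversion tools: a pickled dict with a `model` entry mapping parameter names to NumPy arrays, plus `__author__` and `matching_heuristics` keys so that parameter names are matched heuristically at load time. The helper names and the `strip_prefix` argument are hypothetical; the actual key renaming performed by `tools/convert-pretrained-stdc-model-to-d2.py` may differ.

```python
import pickle
from typing import Any, Dict


def to_numpy(value: Any) -> Any:
    """Convert a torch tensor to a NumPy array; leave other objects untouched."""
    if hasattr(value, "detach"):
        value = value.detach().cpu()
    return value.numpy() if hasattr(value, "numpy") else value


def build_d2_checkpoint(
    state_dict: Dict[str, Any], author: str, strip_prefix: str = ""
) -> Dict[str, Any]:
    """Wrap a backbone state dict in the layout detectron2 expects for .pkl files."""
    model = {}
    for name, value in state_dict.items():
        if strip_prefix and name.startswith(strip_prefix):
            name = name[len(strip_prefix):]  # e.g. drop a hypothetical "module." prefix
        model[name] = to_numpy(value)
    # matching_heuristics=True asks detectron2 to align these names with the model's.
    return {"model": model, "__author__": author, "matching_heuristics": True}


def save_d2_checkpoint(checkpoint: Dict[str, Any], path: str) -> None:
    """Pickle the checkpoint dict to `path` (e.g. STDC1.pkl)."""
    with open(path, "wb") as f:
        pickle.dump(checkpoint, f)


# Usage with plain lists standing in for tensors:
ckpt = build_d2_checkpoint({"module.conv1.weight": [0.1, 0.2]}, "STDC", strip_prefix="module.")
print(sorted(ckpt["model"]))  # ['conv1.weight']
```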

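In detectron2, `SOLVER.IMS_PER_BATCH` is the total batch size across all GPUs, so it is split among the processes launched with `--num-gpus`. Assuming the config name encodes the schedule (reading `bs32_90k` as 32 images per iteration for 90k iterations, an assumption based on the file name rather than the config contents), the sketch below computes the per-GPU batch size. The function names are hypothetical.

```python
import re
from typing import Tuple


def parse_schedule(config_name: str) -> Tuple[int, int]:
    """Read (total batch size, iterations) from a name like 'pem_R50_bs32_90k.yaml'.
    Assumes the hypothetical pattern 'bs<batch>_<iters>k'."""
    match = re.search(r"bs(\d+)_(\d+)k", config_name)
    if match is None:
        raise ValueError(f"no 'bs<batch>_<iters>k' pattern in {config_name!r}")
    return int(match.group(1)), int(match.group(2)) * 1000


def per_gpu_batch(total_batch: int, num_gpus: int) -> int:
    """Split detectron2's SOLVER.IMS_PER_BATCH (a global batch) evenly across GPUs."""
    if num_gpus <= 0 or total_batch % num_gpus != 0:
        raise ValueError(f"batch {total_batch} cannot be split evenly over {num_gpus} GPUs")
    return total_batch // num_gpus


total, iters = parse_schedule("pem_R50_bs32_90k.yaml")
print(total, iters, per_gpu_batch(total, 4))  # 32 90000 8
```

With the command above, each of the 4 GPUs therefore processes 8 images per iteration. Like `MODEL.WEIGHTS`, the total can be overridden with trailing config options (e.g. `SOLVER.IMS_PER_BATCH 16`), though this changes the training recipe.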
Testing

To test the model, run train_net.py with the --eval-only flag and pass the checkpoint of the trained model via MODEL.WEIGHTS:

python train_net.py --eval-only \
  --config-file configs/cityscapes/semantic-segmentation/pem_R50_bs32_90k.yaml \
  MODEL.WEIGHTS /path/to/checkpoint_file
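To evaluate several checkpoints (for example, intermediate snapshots) with the same config, the command above can be scripted. This is a minimal sketch using `subprocess`; it only reuses the options shown in this README (`--eval-only`, `--config-file`, `--num-gpus`, and the `MODEL.WEIGHTS` override), while the helper names and checkpoint paths are hypothetical.

```python
import subprocess
import sys
from typing import List


def eval_command(config: str, weights: str, num_gpus: int = 1) -> List[str]:
    """Build the train_net.py evaluation command for a single checkpoint."""
    return [
        sys.executable, "train_net.py",
        "--num-gpus", str(num_gpus),
        "--eval-only",
        "--config-file", config,
        "MODEL.WEIGHTS", weights,  # trailing KEY VALUE pairs override the config
    ]


def evaluate_all(config: str, checkpoints: List[str], num_gpus: int = 1) -> None:
    """Run the evaluations sequentially, stopping at the first failure."""
    for ckpt in checkpoints:
        subprocess.run(eval_command(config, ckpt, num_gpus), check=True)
```

For example, `evaluate_all("configs/cityscapes/semantic-segmentation/pem_R50_bs32_90k.yaml", ["/path/to/checkpoint_file"], num_gpus=4)` runs the evaluation above for each listed checkpoint on 4 GPUs.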

Results

Panoptic segmentation

Table 1. Panoptic segmentation on Cityscapes with 19 categories.

Table 2. Panoptic segmentation on ADE20K with 150 categories.

Semantic segmentation

Table 3. Semantic segmentation on Cityscapes with 19 categories.

Table 4. Semantic segmentation on ADE20K with 150 categories.

Citation

If you find this project helpful for your research, please consider citing the following BibTeX entry.

@article{cavagnero2024pem,
  title={PEM: Prototype-based Efficient MaskFormer for Image Segmentation},
  author={Cavagnero, Niccol{\`o} and Rosi, Gabriele and Cuttano, Claudia and 
  Pistilli, Francesca and Ciccone, Marco and Averta, Giuseppe and Cermelli, Fabio},
  journal={arXiv preprint arXiv:2402.19422},
  year={2024}
}

Acknowledgement

The code is largely based on Mask2Former; we thank its authors for their excellent work.