Niccolò Cavagnero*, Gabriele Rosi*, Claudia Cuttano, Francesca Pistilli, Marco Ciccone, Giuseppe Averta, Fabio Cermelli
* Equal Contribution
[Project Page] [Paper]
This is the official PyTorch implementation of our work "PEM: Prototype-based Efficient MaskFormer for Image Segmentation" accepted at CVPR 2024.
Prototype-based Efficient MaskFormer (PEM) is an efficient transformer-based architecture that can handle multiple segmentation tasks. PEM introduces a novel prototype-based cross-attention that exploits the redundancy of visual features to reduce computation and improve efficiency without sacrificing performance.
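To make the idea of prototype-based cross-attention more concrete, here is a minimal, dependency-free sketch. It assumes that each query's prototype is simply the pixel feature with the highest dot-product similarity, and that the query is then updated with a residual element-wise product. Both rules, and the name `prototype_cross_attention`, are illustrative assumptions, not PEM's actual layer. What the sketch shows is the general principle: each query interacts with one selected feature instead of a softmax-weighted sum over all H×W pixel features.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Dot-product similarity between two feature vectors."""
    return sum(x * y for x, y in zip(a, b))


def prototype_cross_attention(queries: List[Vector], pixels: List[Vector]) -> List[Vector]:
    """Illustrative prototype-based update (hypothetical, not the PEM layer).

    Standard cross-attention mixes all len(pixels) features per query through a
    softmax; here each query keeps only its single most similar pixel feature.
    """
    updated = []
    for q in queries:
        # Similarity of this query to every pixel feature.
        scores = [dot(q, p) for p in pixels]
        # Prototype selection: index of the most similar pixel (assumed rule).
        best = max(range(len(pixels)), key=scores.__getitem__)
        proto = pixels[best]
        # Assumed update: residual element-wise product with the prototype,
        # so no softmax or weighted sum over all pixels is needed.
        updated.append([qi + qi * pi for qi, pi in zip(q, proto)])
    return updated


if __name__ == "__main__":
    queries = [[1.0, 0.0], [0.0, 1.0]]               # N = 2 object queries, C = 2
    pixels = [[2.0, 0.0], [0.0, 3.0], [1.0, 1.0]]    # H*W = 3 pixel features
    print(prototype_cross_attention(queries, pixels))  # [[3.0, 0.0], [0.0, 4.0]]
```

In a real model the queries and features would be batched tensors with learned projections. The per-query loop here only keeps the selection step easy to follow.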
The code has been tested with python>=3.8 and pytorch==1.12.0. To prepare the conda environment, please run the following:
conda create --name pem python=3.10 -y
conda activate pem
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
git clone https://github.com/NiccoloCavagnero/PEM.git
cd PEM
pip install -r requirements.txt
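After installation, you may want a quick sanity check of the environment. The sketch below is a hypothetical helper, not part of the repository. It compares the installed versions against the ones reported above (python>=3.8, pytorch==1.12.0) using only the standard library, and it needs Python 3.8+ itself because it relies on `importlib.metadata`.

```python
import re
import sys
from importlib import metadata

TESTED_TORCH = (1, 12, 0)  # version reported as tested in this README


def version_tuple(version: str) -> tuple:
    """Parse e.g. '1.12.0+cu113' into (1, 12, 0), ignoring local build tags."""
    core = version.split("+")[0]
    numbers = []
    for piece in core.split(".")[:3]:
        match = re.match(r"\d+", piece)
        numbers.append(int(match.group()) if match else 0)
    return tuple(numbers)


def check_environment() -> list:
    """Return human-readable warnings; an empty list means nothing was flagged."""
    warnings = []
    if sys.version_info < (3, 8):
        warnings.append(f"Python {sys.version.split()[0]} is older than 3.8")
    for package in ("torch", "detectron2"):
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            warnings.append(f"{package} is not installed")
            continue
        if package == "torch" and version_tuple(installed) != TESTED_TORCH:
            warnings.append(f"torch {installed} differs from the tested 1.12.0")
    return warnings


if __name__ == "__main__":
    for message in check_environment() or ["environment looks as expected"]:
        print(message)
```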
For dataset preparation, please refer to the Mask2Former guide.
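detectron2's builtin dataset registration looks for datasets under the directory named by the `DETECTRON2_DATASETS` environment variable, defaulting to `./datasets`. The hypothetical helper below checks that a few top-level folders exist there. The folder list is a minimal subset assumed from the Mask2Former guide, so treat a clean result as a smoke test rather than a full validation.

```python
import os
from pathlib import Path

# Minimal subset of the folders the Mask2Former guide asks for (assumption).
EXPECTED_FOLDERS = {
    "cityscapes": ["cityscapes/gtFine", "cityscapes/leftImg8bit"],
    "ade20k": ["ADEChallengeData2016/images", "ADEChallengeData2016/annotations"],
}


def dataset_root() -> Path:
    """Root used by detectron2's builtin datasets: $DETECTRON2_DATASETS or ./datasets."""
    return Path(os.getenv("DETECTRON2_DATASETS", "datasets")).expanduser()


def missing_folders(name: str, root=None) -> list:
    """List the expected sub-folders of dataset `name` that do not exist under `root`."""
    base = Path(root) if root is not None else dataset_root()
    return [rel for rel in EXPECTED_FOLDERS[name] if not (base / rel).is_dir()]


if __name__ == "__main__":
    for name in EXPECTED_FOLDERS:
        missing = missing_folders(name)
        print(name, "OK" if not missing else f"missing: {missing}")
```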
Before starting the training, you have to download the pretrained weights for the backbone. The following commands download the pretrained weights for the STDC1 and STDC2 backbones (read more about them here) and convert them to the detectron2 format. For ResNet50, the pretrained weights are downloaded automatically from the detectron2 repository.
mkdir pretrained_models
cd pretrained_models
gdown 1DFoXcV42zy-apUcMh5P8WhsXMRJofgl8
gdown 1Y5belNkq3Dn-EYgSKY-ICiPsN4TZXoXO
python ../tools/convert-pretrained-stdc-model-to-d2.py STDCNet813M_73.91.tar STDC1.pkl
python ../tools/convert-pretrained-stdc-model-to-d2.py STDCNet1446_76.47.tar STDC2.pkl
cd ..
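The conversion step turns a classification checkpoint into a `.pkl` file that detectron2 can load. Conversion scripts in detectron2's own `tools/` folder usually rename keys, convert tensors to NumPy arrays, and pickle a dict with `"model"`, `"__author__"` and `"matching_heuristics"` fields. The sketch below mimics only that wrapping step. It is not the repository's `convert-pretrained-stdc-model-to-d2.py`. The prefixes it drops (`fc.`, `linear.`) are hypothetical classifier-head names, and plain lists stand in for the tensors a real script would load with `torch.load`.

```python
import pickle


def to_d2_checkpoint(state_dict, author="STDC", drop_prefixes=("fc.", "linear.")):
    """Hypothetical key clean-up and wrapping in the detectron2 .pkl layout."""
    model = {}
    for key, value in state_dict.items():
        if key.startswith("module."):           # prefix added by DataParallel
            key = key[len("module."):]
        if key.startswith(drop_prefixes):       # assumed classifier-head names
            continue
        model[key] = value
    # detectron2 treats a .pkl holding "model" and "__author__" as its own format;
    # "matching_heuristics" asks it to align parameter names heuristically on load.
    return {"model": model, "__author__": author, "matching_heuristics": True}


def save_pkl(checkpoint, path):
    """Serialise the wrapped checkpoint with pickle."""
    with open(path, "wb") as f:
        pickle.dump(checkpoint, f)


if __name__ == "__main__":
    toy = {"module.conv1.weight": [0.1, 0.2], "module.fc.weight": [1.0]}  # stand-in values
    ckpt = to_d2_checkpoint(toy)
    print(sorted(ckpt["model"]))  # ['conv1.weight']
```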
To train the model with train_net.py, run the following:
python train_net.py --num-gpus 4 \
--config-file configs/cityscapes/semantic-segmentation/pem_R50_bs32_90k.yaml
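detectron2-style training scripts accept trailing `KEY VALUE` pairs that override config entries, which is also how `MODEL.WEIGHTS` is passed during evaluation. Note that detectron2's `SOLVER.IMS_PER_BATCH` is the total batch size across all GPUs. The sketch below, with hypothetical helper names and a hypothetical output folder, builds the launch command programmatically and computes the per-GPU share. The batch size of 32 is read from the `bs32` part of the config file name and is an assumption.

```python
import shlex


def train_command(config_file, num_gpus=4, overrides=None):
    """Build the train_net.py argv; `overrides` become trailing KEY VALUE pairs."""
    cmd = ["python", "train_net.py", "--num-gpus", str(num_gpus), "--config-file", config_file]
    for key, value in (overrides or {}).items():
        cmd += [key, str(value)]
    return cmd


def images_per_gpu(total_batch, num_gpus):
    """SOLVER.IMS_PER_BATCH in detectron2 is the total batch over all GPUs."""
    if total_batch % num_gpus != 0:
        raise ValueError("total batch size must be divisible by the number of GPUs")
    return total_batch // num_gpus


if __name__ == "__main__":
    cfg = "configs/cityscapes/semantic-segmentation/pem_R50_bs32_90k.yaml"
    cmd = train_command(cfg, num_gpus=4, overrides={"OUTPUT_DIR": "output/pem_r50_city"})
    print(shlex.join(cmd))
    print(images_per_gpu(32, 4))  # 8 images per GPU if the config uses a batch of 32
```

Passing the result to `subprocess.run(cmd, check=True)` would launch the same training as the shell command above.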
To test the model, run train_net.py with the --eval-only flag and pass the checkpoint of the trained model via MODEL.WEIGHTS:
python train_net.py --eval-only \
--config-file configs/cityscapes/semantic-segmentation/pem_R50_bs32_90k.yaml \
MODEL.WEIGHTS /path/to/checkpoint_file
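To find the checkpoint to evaluate, note that detectron2's checkpointer writes a small text file named `last_checkpoint` into `OUTPUT_DIR` (default `./output`) holding the filename of the newest checkpoint, and saves `model_final.pth` at the end of training. The sketch below, with hypothetical helper names, resolves that path and builds the evaluation command. It assumes training used the default checkpointing behaviour.

```python
from pathlib import Path


def latest_checkpoint(output_dir="output"):
    """Return the newest checkpoint in a detectron2 OUTPUT_DIR, or None."""
    out = Path(output_dir)
    marker = out / "last_checkpoint"        # text file holding the latest filename
    if marker.is_file():
        return out / marker.read_text().strip()
    final = out / "model_final.pth"         # saved at the end of training
    return final if final.is_file() else None


def eval_command(config_file, weights):
    """Build the --eval-only invocation with MODEL.WEIGHTS set to `weights`."""
    return ["python", "train_net.py", "--eval-only",
            "--config-file", config_file, "MODEL.WEIGHTS", str(weights)]


if __name__ == "__main__":
    ckpt = latest_checkpoint("output")
    if ckpt is None:
        print("no checkpoint found in ./output")
    else:
        print(" ".join(eval_command(
            "configs/cityscapes/semantic-segmentation/pem_R50_bs32_90k.yaml", ckpt)))
```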
Table 1. Panoptic segmentation on Cityscapes with 19 categories.
Table 2. Panoptic segmentation on ADE20K with 150 categories.
Table 3. Semantic segmentation on Cityscapes with 19 categories.
Table 4. Semantic segmentation on ADE20K with 150 categories.
If you find this project helpful for your research, please consider citing the following BibTeX entry.
@article{cavagnero2024pem,
title={PEM: Prototype-based Efficient MaskFormer for Image Segmentation},
author={Cavagnero, Niccol{\`o} and Rosi, Gabriele and Cuttano, Claudia and
Pistilli, Francesca and Ciccone, Marco and Averta, Giuseppe and Cermelli, Fabio},
journal={arXiv preprint arXiv:2402.19422},
year={2024}
}
The code is largely based on Mask2Former, whose authors we thank for their excellent work.