Our arXiv version is now available. Please check it out! 🔥🔥🔥
This repository contains the official PyTorch implementation of ClusterFormer: Clustering As A Universal Visual Learner. Our work is built upon mmclassification and other repositories under the OpenMMLab framework; we thank them for their great work.
This paper presents ClusterFormer, a universal vision model based on the CLUSTERing paradigm with TransFORMER. It comprises two novel designs: 1) recurrent cross-attention clustering, which reformulates the cross-attention mechanism in the Transformer and enables recursive updates of cluster centers to facilitate strong representation learning; and 2) feature dispatching, which uses the updated cluster centers to redistribute image features through similarity-based metrics, resulting in a transparent pipeline. This elegant design streamlines an explainable and transferable workflow capable of tackling heterogeneous vision tasks.
Figure 1: (a) Overall pipeline of ClusterFormer. (b) Each Recurrent Cross-Attention Clustering layer carries out T iterations of cross-attention clustering (E-step) and center updating (M-step) (see Eq. 3). (c) Feature dispatching redistributes the feature embeddings on top of the updated cluster centers (see Eq. 6).
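For readers who prefer code to figures, below is a minimal PyTorch sketch of the recurrent cross-attention clustering and feature dispatching described above. The module name, projections, and shapes are illustrative assumptions rather than the repository's exact implementation; please refer to the source code for the real layers.

```python
# A minimal sketch (our own illustration, not the official layer) of recurrent
# cross-attention clustering (E-step / M-step) followed by feature dispatching.
import torch
import torch.nn as nn


class RecurrentCrossAttentionClustering(nn.Module):
    """Alternates an E-step (assign features to centers via cross-attention)
    and an M-step (update centers from the soft assignments) for T iterations,
    then dispatches features based on similarity to the updated centers."""

    def __init__(self, dim: int, num_centers: int, num_iters: int = 3):
        super().__init__()
        self.num_iters = num_iters
        self.centers = nn.Parameter(torch.randn(num_centers, dim))  # initial cluster centers
        self.to_q = nn.Linear(dim, dim)      # projects centers to queries
        self.to_k = nn.Linear(dim, dim)      # projects features to keys
        self.to_v = nn.Linear(dim, dim)      # projects features to values
        self.dispatch = nn.Linear(dim, dim)  # mixes center context back into features

    def forward(self, x: torch.Tensor):
        # x: (batch, num_tokens, dim) flattened image features
        b = x.size(0)
        centers = self.centers.unsqueeze(0).expand(b, -1, -1)
        k, v = self.to_k(x), self.to_v(x)
        for _ in range(self.num_iters):
            q = self.to_q(centers)
            # E-step: soft assignment of every feature to every center
            attn = torch.softmax(q @ k.transpose(1, 2) / q.size(-1) ** 0.5, dim=-1)
            # M-step: centers become the assignment-weighted mean of the features
            centers = attn @ v
        # Feature dispatching: redistribute features using their similarity
        # to the updated centers (a similarity-weighted sum of center embeddings).
        sim = torch.softmax(x @ centers.transpose(1, 2) / x.size(-1) ** 0.5, dim=-1)
        x = x + self.dispatch(sim @ centers)
        return x, centers


if __name__ == "__main__":
    feats = torch.randn(2, 196, 256)  # e.g. 14x14 patch tokens with 256 channels
    layer = RecurrentCrossAttentionClustering(dim=256, num_centers=8)
    out, centers = layer(feats)
    print(out.shape, centers.shape)   # (2, 196, 256) (2, 8, 256)
```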
Below are quick steps for installation:
```shell
conda create -n open-mmlab python=3.8 pytorch=1.10 cudatoolkit=11.3 torchvision==0.11.0 -c pytorch -y
conda activate open-mmlab
pip3 install openmim
mim install mmcv-full
cd clusterformer
pip3 install -e .
```
Please refer to install.md for more detailed installation and dataset preparation.
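As a quick sanity check (our own suggestion, not part of the official instructions), you can verify that the environment created above resolves correctly before training:

```python
# Hypothetical sanity check: confirm PyTorch, CUDA, and mmcv are importable
# inside the open-mmlab environment created above.
import torch
import mmcv

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("mmcv-full:", mmcv.__version__)
```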
We follow the common usage of mmclassification; please check the mmclassification documentation for more training details.
In particular, we train our models with Slurm, a widely used job scheduling system for computing clusters. On a cluster managed by Slurm, you can use slurm_train.sh to spawn training jobs; it supports both single-node and multi-node training. The basic usage is as follows.
```shell
OMP_NUM_THREADS=1 [GPUS=${GPUS}] ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR}
```
When using Slurm, the port option needs to be set in one of the following ways:
- Set the port through `--options`. This is recommended since it does not change the original configs.

  ```shell
  OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR} --options 'dist_params.port=29500'
  OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR} --options 'dist_params.port=29501'
  ```

- Modify the config files to set different communication ports.

  In `config1.py`, set

  ```python
  dist_params = dict(backend='nccl', port=29500)
  ```

  In `config2.py`, set

  ```python
  dist_params = dict(backend='nccl', port=29501)
  ```

  Then you can launch two jobs with `config1.py` and `config2.py`.

  ```shell
  OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR}
  OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR}
  ```
Note that:
- The configs are made for 8-GPU training. To train with a different number of GPUs, change `GPUS` accordingly.
- If you want to measure the inference time, please change the number of GPUs to 1 for inference.
- We set `OMP_NUM_THREADS=1` by default, which achieves the best speed on our machines; please change it as needed.
If you find our work helpful in your research, please cite it as:
```bibtex
@inproceedings{liang2023clusterformer,
  title={ClusterFormer: Clustering As A Universal Visual Learner},
  author={Liang, James C and Cui, Yiming and Wang, Qifan and Geng, Tong and Wang, Wenguan and Liu, Dongfang},
  booktitle={Neural Information Processing Systems (NeurIPS)},
  year={2023}
}
```