
PyTorch implementation of "MHS-VM: Multi-Head Scanning in Parallel Subspaces for Vision Mamba"

Primary language: Python. License: Apache-2.0.

This is the official code repository for MHS-VM.

MHS-VM

Multi-Head Scanning in Parallel Subspaces for Vision Mamba

Paper: https://arxiv.org/pdf/2406.05992

Release

  • News (2024/07/05): MHS module and MHS-VM released.

Introduction

A Multi-Head Scan (MHS) mechanism is introduced to enhance visual representation learning.

Figure: MHS module

A richer array of scan patterns is introduced to capture the diverse visual patterns present in vision data.

Scan Patterns

A Scan Route Attention (SRA) mechanism is introduced to enable the model to attenuate or screen out trivial features, thereby enhancing its ability to capture complex structures in images.

Embedding Section Fusion

In our experiments, we examine the coefficient of variation (CV) of the relative deviations of the $k$ values, which gives insight into how variable or consistent an embedding's responses are along the different scan routes. A multiplicative gating mechanism based on this relative CV lets the module selectively filter or attenuate information. The process is formulated as:

$$ \begin{equation} z = (\sum_{i=1}^{k} y_i) \odot \sigma(y_{cv}) \end{equation} $$

where $y_{cv} = \text{std}([y_i]) / \text{avg}([y_i-\text{min}([y_i])])$ is the relative CV, $\odot$ denotes the element-wise product between tensors, and $\sigma(\cdot)$ is a monotone function such as Sigmoid, ReLU, a power function, or the exponential function $\exp(\cdot)$. This monotone function is introduced to prompt the Mamba block to extract position-aware features. One choice is a thresholded ReLU:

$$ \begin{equation} \sigma(x, t) = \text{ReLU}(x-t) = \text{max}(0, x-t) \end{equation} $$

This function returns $0$ when $x < t$ and $x-t$ otherwise. The threshold $t$ can be either a fixed hyperparameter or a learnable parameter. This strategy can be viewed as a novel regularization technique that helps prevent overfitting and improve generalization.
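The gating formula above can be traced with a small numeric sketch. It is written in plain Python so that it runs without PyTorch, and it is not the repository's implementation: the function name `sra_gate`, the use of the population standard deviation, and the small `eps` added to the denominator are all assumptions made for illustration.

```python
import statistics


def sra_gate(route_outputs, t=0.0, eps=1e-8):
    """Fuse k per-route outputs as z = (sum_i y_i) * ReLU(y_cv - t).

    route_outputs: list of k equal-length lists, one per scan route.
    t and eps are illustrative; eps only guards against division by zero.
    """
    fused = []
    for values in zip(*route_outputs):  # the k responses at one position
        total = sum(values)
        std = statistics.pstdev(values)  # assumption: population std
        lowest = min(values)
        spread = statistics.fmean(v - lowest for v in values)
        y_cv = std / (spread + eps)  # relative CV
        gate = max(0.0, y_cv - t)  # sigma(x, t) = ReLU(x - t)
        fused.append(total * gate)
    return fused


routes = [[1.0, 2.0], [1.0, 4.0], [1.0, 6.0]]  # k = 3 routes, 2 positions
z = sra_gate(routes, t=0.5)
print(z)
```

At the first position all three routes agree, so the relative CV is zero and the gate screens the feature out. At the second position the routes disagree, the relative CV exceeds the threshold, and the summed response passes through scaled by the gate.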

Main Environments

The environment can be set up by following the instructions of VM-UNet, or with the steps below:

conda create -n mhsvm python=3.10
conda activate mhsvm
pip install torch==2.0.1 torchvision==0.15.2
pip install packaging==24.0
pip install timm==1.0.3
pip install triton==2.0.0
pip install causal_conv1d==1.2.0 
pip install mamba_ssm==1.2.0
pip install tensorboardX  
pip install pytest chardet yacs termcolor
pip install scikit-learn matplotlib thop h5py SimpleITK scikit-image medpy

Datasets

For datasets, please refer to VM-UNet for further details.

Scan Route Dictionary

Since the scan routes are fixed within the model, we opt to pre-generate the route hierarchy and store it in a dictionary. To accommodate various resolutions, you can generate the scan routes using the following command:

python routegen.py --w 512 --h 512
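The idea of pre-generating routes and storing them in a dictionary can be illustrated with a minimal sketch. It does not reproduce routegen.py or the scan patterns used in the paper: the three routes shown (`raster`, `column`, `snake`) and the names `build_routes`, `get_routes`, and `ROUTE_CACHE` are hypothetical and chosen only to show the caching pattern.

```python
def build_routes(h, w):
    """Return a dict of scan routes over an h x w grid of flattened indices."""
    # Row by row, left to right.
    raster = [r * w + c for r in range(h) for c in range(w)]
    # Column by column, top to bottom.
    column = [r * w + c for c in range(w) for r in range(h)]
    # Row by row, reversing direction on every other row.
    snake = []
    for r in range(h):
        row = [r * w + c for c in range(w)]
        snake.extend(row if r % 2 == 0 else row[::-1])
    return {"raster": raster, "column": column, "snake": snake}


# Cache keyed by resolution so each grid size is generated only once.
ROUTE_CACHE = {}


def get_routes(h, w):
    if (h, w) not in ROUTE_CACHE:
        ROUTE_CACHE[(h, w)] = build_routes(h, w)
    return ROUTE_CACHE[(h, w)]


routes = get_routes(2, 3)
print(routes["column"])  # [0, 3, 1, 4, 2, 5]
```

Each route is a permutation of the flattened token indices, so reordering a token sequence along a route is a single indexing operation, and the dictionary lookup avoids recomputing the permutation on every forward pass.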

Train

cd MHS-VM
python train.py

Test

cd MHS-VM
python test.py --h 4 --d isic2018 --p best_4h.pth
miou: 0.8085252327081669, f1_or_dsc: 0.8941265712919525

Interestingly, the model trained on the isic2018 dataset also performs notably well when evaluated on the test set of isic2017:

cd MHS-VM
python test.py --h 4 --d isic2017 --p best_4h.pth
miou: 0.8201665691022297, f1_or_dsc: 0.9011994649553033
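The two reported numbers in each run are closely related. As a sketch, assuming IoU and Dice are both computed from the same foreground pixel counts (an inference from the figures above, not something read from test.py), the Dice score follows from IoU as 2·IoU / (1 + IoU). The helper names below are hypothetical.

```python
def iou_and_dice(tp, fp, fn):
    """IoU and Dice (F1) of the foreground class from pixel counts."""
    iou = tp / (tp + fp + fn)
    dice = 2 * tp / (2 * tp + fp + fn)
    return iou, dice


def dice_from_iou(iou):
    """Identity that holds when both metrics share the same counts."""
    return 2 * iou / (1 + iou)


# Reported (miou, f1_or_dsc) pairs from the two test runs above.
reported = {
    "isic2018": (0.8085252327081669, 0.8941265712919525),
    "isic2017": (0.8201665691022297, 0.9011994649553033),
}
for name, (miou, dsc) in reported.items():
    print(name, abs(dice_from_iou(miou) - dsc) < 1e-9)
```

Because one metric determines the other under this assumption, the two always rank models identically; reporting both is mainly a convenience for comparison with prior work.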

Citation

If you find this repository useful, please cite our work:

@misc{ji2024mhsvmmultiheadscanningparallel,
      title={MHS-VM: Multi-Head Scanning in Parallel Subspaces for Vision Mamba}, 
      author={Zhongping Ji},
      year={2024},
      eprint={2406.05992},
      archivePrefix={arXiv},
      primaryClass={eess.IV},
      url={https://arxiv.org/abs/2406.05992}, 
}

Acknowledgments

This code is based on VM-UNet.