
Code for MambaVC: Learned Visual Compression with Selective State Spaces


MambaVC: Learned Visual Compression with Selective State Spaces

This is the PyTorch repository of the paper "MambaVC: Learned Visual Compression with Selective State Spaces".

Please feel free to contact Shiyu Qin (qinsy23@mails.tsinghua.edu.cn) if you have any questions.

Abstract

Learned visual compression is an important and active task in multimedia. Existing approaches have explored various CNN- and Transformer-based designs to model content distribution and eliminate redundancy, where balancing efficacy (i.e., rate-distortion trade-off) and efficiency remains a challenge. Recently, state-space models (SSMs) have shown promise due to their long-range modeling capacity and efficiency. Inspired by this, we take the first step to explore SSMs for visual compression. We introduce MambaVC, a simple, strong and efficient compression network based on SSM. MambaVC develops a visual state space (VSS) block with a 2D selective scanning (2DSS) module as the nonlinear activation function after each downsampling, which helps to capture informative global contexts and enhances compression. On compression benchmark datasets, MambaVC achieves superior rate-distortion performance with lower computational and memory overheads. Specifically, it outperforms CNN and Transformer variants by 9.3% and 15.6% on Kodak, respectively, while reducing computation by 42% and 24%, and saving 12% and 71% of memory. MambaVC shows even greater improvements with high-resolution images, highlighting its potential and scalability in real-world applications. We also provide a comprehensive comparison of different network designs, underscoring MambaVC's advantages.
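To make the idea of "2D selective scanning" more concrete, here is a toy, pure-Python sketch. It assumes the 2DSS module follows VMamba's four-route cross-scan: the H x W feature map is flattened along four traversal orders (row-major, column-major, and their reverses), each sequence goes through a selective state-space recurrence whose step size and projections depend on the input, and the four outputs are mapped back and summed per position. Everything here (one channel, a scalar state, the weights `a`, `w_delta`, `b_delta`, `w_b`, `w_c`, and the helper names) is hypothetical and chosen for illustration; the real module uses learned multi-channel projections and the CUDA selective-scan kernel.

```python
import math


def softplus(v):
    # Simple softplus; adequate for the small toy values used here.
    return math.log1p(math.exp(v))


def selective_scan_1d(xs, a=-1.0, w_delta=1.0, b_delta=0.0, w_b=1.0, w_c=1.0):
    """Toy selective SSM over a 1D sequence: one channel, scalar state, hypothetical weights."""
    h = 0.0
    ys = []
    for x in xs:
        delta = softplus(w_delta * x + b_delta)  # input-dependent step size (the "selective" part)
        b_t = w_b * x                            # input-dependent input projection B_t
        c_t = w_c * x                            # input-dependent output projection C_t
        # Discretized update: h_t = exp(delta * A) * h_{t-1} + delta * B_t * x_t
        h = math.exp(delta * a) * h + delta * b_t * x
        ys.append(c_t * h)                       # y_t = C_t * h_t
    return ys


def scan_2d(grid):
    """Scan an H x W map along four routes and sum the outputs at each position."""
    height, width = len(grid), len(grid[0])
    row_major = [(i, j) for i in range(height) for j in range(width)]
    col_major = [(i, j) for j in range(width) for i in range(height)]
    routes = [row_major, row_major[::-1], col_major, col_major[::-1]]
    out = [[0.0] * width for _ in range(height)]
    for route in routes:
        seq = [grid[i][j] for i, j in route]  # flatten the map along this route
        for (i, j), y in zip(route, selective_scan_1d(seq)):
            out[i][j] += y  # map each output back to its 2D position and merge by summation
    return out
```

For example, `scan_2d([[1.0, 2.0], [3.0, 4.0]])` returns a 2 x 2 map. Each single route is causal, but because the forward and reverse routes are combined, every output position receives information from every input position, which is the sense in which the scan captures global context at linear cost in the number of pixels.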

Architectures

Figure: the overall framework of MambaVC.

Evaluation Results

Figure: rate-distortion (RD) curves on Kodak (models trained on Flickr30k).

Installation

This codebase was tested with the following environment configurations. It may work with other versions.

  • Ubuntu 20.04
  • CUDA 12.2
  • Python 3.8
  • PyTorch 2.2.0 + cu121

Install CompressAI from source:

git clone https://github.com/InterDigitalInc/CompressAI compressai
cd compressai
pip install -U pip && pip install -e .

Build and install the selective-scan kernels from VMamba:

git clone https://github.com/MzeroMiko/VMamba.git
cd VMamba
cd kernels/selective_scan && pip install .
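After installing, a quick sanity check can confirm that the packages are importable and that PyTorch sees a GPU. The sketch below is a hypothetical helper, not part of this repository; it only checks `torch` and `compressai` by default, because the module name installed by VMamba's selective-scan build is not stated here. Add that name to `packages` yourself if you want to check it too.

```python
import importlib.util


def check_environment(packages=("torch", "compressai")):
    """Map each top-level package name to whether it can be found on the Python path."""
    return {name: importlib.util.find_spec(name) is not None for name in packages}


def cuda_available():
    """Return True only if PyTorch is installed and reports a usable CUDA device."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch  # imported lazily so the check still runs when PyTorch is missing

    return torch.cuda.is_available()
```

Running `print(check_environment(), cuda_available())` in the training environment should show `True` for every entry if the steps above succeeded.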

Dataset

The dataset directory is expected to be organized as below:

dataset_root/
  • train/
    • train_1.jpg
    • train_2.jpg
    • ...
  • test/
    • test_1.jpg
    • test_2.jpg
    • ...
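A small script can verify this layout before launching training. The sketch below is hypothetical (the helper name and the accepted suffixes are assumptions; the example tree only shows `.jpg` files) and simply lists the image files found in one split, failing loudly if the split directory is missing.

```python
from pathlib import Path

# Assumption: accept common image formats; the example layout only shows .jpg files.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def list_split_images(dataset_root, split):
    """Return the sorted image paths under <dataset_root>/<split>/, ignoring other files."""
    split_dir = Path(dataset_root) / split
    if not split_dir.is_dir():
        raise FileNotFoundError(f"missing split directory: {split_dir}")
    return sorted(
        p for p in split_dir.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
```

For instance, `len(list_split_images("<dataset_root>", "train"))` reports how many training images will be visible.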

Training

CUDA_VISIBLE_DEVICES=0 python3 train.py --cuda -d <dataset_root> \
    -n 128 --lambda 0.05 --epochs 500 --lr_epoch 450 490 --batch-size 8 \
    --save_path <ckpt_to_path> --save \
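    --checkpoint <resumed_ckpt_path> --continue_train

The last line (--checkpoint and --continue_train) is only needed when resuming from a saved checkpoint; omit it to train from scratch.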
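The `--lambda` and `--lr_epoch` flags control the rate-distortion trade-off and the learning-rate decay. The sketch below shows what they plausibly mean, under two stated assumptions: that the loss follows CompressAI's example `RateDistortionLoss` (lambda * 255^2 * MSE + bpp, with MSE on images scaled to [0, 1]), and that `--lr_epoch 450 490` are milestones of a MultiStepLR-style schedule. The base learning rate `1e-4` and decay factor `gamma=0.1` are assumptions, not values taken from this repository, and the helper names are hypothetical.

```python
import bisect
import math


def bits_per_pixel(likelihoods, num_pixels):
    """Estimated rate: total -log2(likelihood) of all latent symbols, divided by image pixels."""
    return sum(-math.log2(p) for p in likelihoods) / num_pixels


def rd_loss(bpp, mse, lmbda):
    """Assumed CompressAI-style objective: a larger lambda favors distortion over rate."""
    return lmbda * 255 ** 2 * mse + bpp


def lr_at_epoch(epoch, base_lr=1e-4, milestones=(450, 490), gamma=0.1):
    """MultiStepLR-style schedule: multiply by gamma once for each milestone already reached."""
    return base_lr * gamma ** bisect.bisect_right(sorted(milestones), epoch)
```

Under these assumptions, epochs 0 to 449 use the base rate, epochs 450 to 489 use one tenth of it, and epochs 490 to 499 use one hundredth. Training one model per `--lambda` value yields the different points on an RD curve.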

Testing

CUDA_VISIBLE_DEVICES=0 python3 eval.py --cuda --data <dataset_root> --checkpoint <pretrained_ckpt_path>
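Evaluation of learned codecs is usually reported as PSNR versus bits per pixel (bpp). The sketch below gives the standard definitions; it is a hypothetical helper and not necessarily the exact code in `eval.py` (for example, whether PSNR is averaged per image or computed over the whole set is not stated here).

```python
import math


def psnr(mse, max_val=1.0):
    """Peak signal-to-noise ratio in dB for images scaled to [0, max_val]; mse must be > 0."""
    return 10 * math.log10(max_val ** 2 / mse)


def bpp_from_bytes(num_bytes, height, width):
    """Actual bitstream size converted to bits per pixel."""
    return num_bytes * 8 / (height * width)
```

For a 768 x 512 Kodak image compressed to 1000 bytes, `bpp_from_bytes(1000, 512, 768)` gives about 0.0203 bpp, and an MSE of one 8-bit gray level, `psnr(1 / 255 ** 2)`, is about 48.13 dB.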

Notes

We use calflops to calculate MACs, FLOPs, and model parameters.
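calflops traces a full model to produce these numbers; for a single convolution they can also be checked by hand. The sketch below (hypothetical helper names) counts multiply-accumulates (MACs) and parameters for one Conv2d layer and uses the common convention FLOPs = 2 x MACs; bias additions are ignored in the MAC count, which is a simplifying assumption.

```python
def conv2d_out_size(size, kernel, stride=1, padding=0):
    """Output spatial size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def conv2d_cost(c_in, c_out, kernel, h_in, w_in, stride=1, padding=0, bias=True):
    """MACs, FLOPs (2 x MACs) and parameter count of one square-kernel Conv2d layer."""
    h_out = conv2d_out_size(h_in, kernel, stride, padding)
    w_out = conv2d_out_size(w_in, kernel, stride, padding)
    # Each output element needs c_in * kernel * kernel multiply-accumulates.
    macs = c_out * h_out * w_out * c_in * kernel * kernel
    params = c_in * c_out * kernel * kernel + (c_out if bias else 0)
    return {"macs": macs, "flops": 2 * macs, "params": params}
```

For example, a 5 x 5, stride-2 convolution from 3 to 128 channels on a 256 x 256 input (padding 2) produces a 128 x 128 output and costs 157,286,400 MACs with 9,728 parameters. Because MACs scale with output resolution, efficiency comparisons are only meaningful at a fixed input size.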

Citation

@article{qin2024mambavc,
  title={MambaVC: Learned Visual Compression with Selective State Spaces},
  author={Qin, Shiyu and Wang, Jinpeng and Zhou, Yiming and Chen, Bin and Luo, Tianci and An, Baoyi and Dai, Tao and Xia, Shutao and Wang, Yaowei},
  journal={arXiv preprint arXiv:2405.15413},
  year={2024}
}

Acknowledgement

Our code is based on the implementations of CompressAI, Mamba, and VMamba. We thank the authors for open-sourcing their code.