audio-mamba-official

This is the official repository for our paper "Audio Mamba: Selective State Spaces for Self-Supervised Audio Representations", set to appear in Proc. INTERSPEECH 2024.

Contents

  • Setup
  • Extract features
  • Run downstream experiments
  • Get results
  • Extracting features from your own audio files
  • Pretraining

Setup

Environment

  • Requires CUDA 11.x or newer and cuDNN 8.2 or newer.
  • Create a new conda environment with Python 3.10 or later.
  • Requires PyTorch (torch) 2.1.2 or newer.

Follow these steps:

conda create -n mamba-env python=3.10 -y
conda activate mamba-env

pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118

pip install -r requirements.txt

# install hear-eval-kit specific requirements
pip install -r external_sources/hear-eval-kit/requirements.txt

# install hear-eval-kit, WITHOUT AUTO DEPS
cd external_sources/hear-eval-kit && pip install --no-deps . && cd -

# install causal-conv1d
pip install git+https://github.com/Dao-AILab/causal-conv1d.git@v1.1.3.post1

# install mamba-ssm
pip install git+https://github.com/state-spaces/mamba.git@v1.1.3.post1
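
Optionally, sanity-check the installation by importing the compiled extensions (mamba_ssm and causal_conv1d are the module names published by the packages installed above):

python -c "import torch, mamba_ssm, causal_conv1d; print(torch.__version__, torch.cuda.is_available())"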

Get 16000 Hz data from HEAR

  • Follow https://hearbenchmark.com/hear-tasks.html to get the data. By default, the data on HEAR's Zenodo page is 48000 Hz.
  • We recommend downloading the data directly from HEAR's GCS bucket, which hosts preprocessed 16000 Hz data.
  • Extract all the files to a folder $TASKS_DIR (a download sketch is given below).
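
For example, the 16000 Hz archives can be fetched with gsutil and unpacked into $TASKS_DIR. This is only a sketch: the bucket path and archive names below are placeholders, so substitute the ones listed in the HEAR documentation.

export TASKS_DIR=/path/to/hear_tasks
mkdir -p $TASKS_DIR

# placeholders: use the bucket path and file names from https://hearbenchmark.com
gsutil -m cp "gs://<hear-preprocessed-bucket>/*16000*.tar.gz" $TASKS_DIR/
for f in $TASKS_DIR/*.tar.gz; do tar -xzf "$f" -C "$TASKS_DIR"; done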

Get pretrained weights

  • Pre-trained weights can be downloaded from Google Drive.
  • Download the entire folder and export its path as $PT_MAMBA_MODEL_DIR (see the sketch below).
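
For example (the path is illustrative; the ls is only a quick check that the checkpoints were downloaded):

export PT_MAMBA_MODEL_DIR=/path/to/pretrained_weights
ls $PT_MAMBA_MODEL_DIR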

Extract features

export PT_MAMBA_MODEL_DIR=/path/to/pretrained_weights
./extract_features.sh $TASKS_DIR $OUTPUT_DIR

where TASKS_DIR is the directory you extracted the HEAR-2021 tasks to, and OUTPUT_DIR is the base directory where the output features will be stored. The provided script extracts features for the SSAST and SSAM Tiny configurations; you can change it as needed. It also prepares a todo_audioset directory in OUTPUT_DIR, which sets up downstream classification on 10 seeds.

Run downstream experiments

After extracting features, run the downstream experiments for a specific config with the following command:

./downstream_experiments.sh ssam_tiny_200_16x4 $OUTPUT_DIR/todo_audioset

This will run downstream experiments on all the extracted features for the SSAM Tiny configuration, over 10 random seeds.
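
If you extracted features for more than one configuration, you can loop over them in the same way. This is a sketch: the SSAST config name below is hypothetical, so use whichever config names extract_features.sh actually produced.

for cfg in ssast_tiny_200_16x4 ssam_tiny_200_16x4; do
    ./downstream_experiments.sh "$cfg" "$OUTPUT_DIR/todo_audioset"
done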

Get results

Finally, run the following script to aggregate the results of the downstream experiments for the two models:

python stats_aggregation_v2.py --base_dir ${OUTPUT_DIR}/todo_audioset --output_dir ${OUTPUT_DIR}/parsed_results

Extracting features from your own audio files

The hear_api can be used to extract features from your own audio files:

import torchaudio
from importlib import import_module

from hear_api import RuntimeSSAST

# build a model from a config and load the pretrained weights
config = import_module("configs.ssam_tiny_200_16x4").get_config()
ssam = RuntimeSSAST(config, "path/to/pretrained_dir").cuda()

# alternatively, just use the following if you have the paths set up right
# ssam = import_module("hear_configs.ssam_tiny_200_16x4").load_model().cuda()

x, sr = torchaudio.load("path/to/audio.wav")
x = x.cuda()
o = ssam.get_scene_embeddings(x)
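
The released models are pretrained on 16000 Hz audio, so if your file is at a different sample rate you may want to resample it first, for example with ffmpeg (this sketch assumes the hear_api does not resample internally; the file names are illustrative):

ffmpeg -i input.wav -ar 16000 -ac 1 audio_16k.wav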

Pretraining

Pretraining code is included in the release. Each model configuration (for instance, ssam_tiny_200_16x4) was trained with the following command:

torchrun --standalone --nnodes=1 --nproc-per-node=4 train.py --config configs.ssam_tiny_200_16x4 --workdir $EXP_DIR/ssam_tiny_200_16x4_4x256_fp16_r1 --precision float16 --print_freq 50 --num_workers 16 --no_wandb

We use a torchdata-based datapipe for data loading, operating on precomputed log-mel spectrogram features stored in webdataset archive(s). You can adapt the data loading to your own use case.
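
If you are preparing your own archives, a quick way to inspect what a shard contains is to list the members of one of the webdataset tars (the path below is illustrative):

tar tf /path/to/webdataset_shards/train-000000.tar | head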