
Code base for MinD-Vis

Primary language: Python. License: MIT.

Seeing Beyond the Brain: Masked Modeling Conditioned Diffusion Model for Human Vision Decoding

MinD-Vis

MinD-Vis is a framework for decoding human visual stimuli from brain recordings. This document introduces the procedures required for replicating the results in Seeing Beyond the Brain: Masked Modeling Conditioned Diffusion Model for Human Vision Decoding.

Abstract

Decoding visual stimuli from brain recordings aims to deepen our understanding of the human visual system and build a solid foundation for bridging human and computer vision through the Brain-Computer Interface. However, due to the scarcity of data annotations and the complexity of underlying brain information, it is challenging to decode images with faithful details and meaningful semantics. In this work, we present MinD-Vis: Sparse Masked Brain Modeling with Double-Conditioned Latent Diffusion Model for Human Vision Decoding. Specifically, by boosting the information capacity of feature representations learned from a large-scale resting-state fMRI dataset, we show that our MinD-Vis can reconstruct highly plausible images with semantically matching details from brain recordings with very few paired annotations. We benchmarked our model qualitatively and quantitatively; the experimental results indicate that our method outperformed the state of the art in both semantic mapping (100-way semantic classification) and generation quality (FID) by 66% and 41%, respectively.

Overview

Our framework consists of two main stages:

  • Stage A: Sparse-Coded Masked Brain Modeling (SC-MBM)
  • Stage B: Double-Conditioned Latent Diffusion Model (DC-LDM)

The data folder and pretrains folder are not included in this repository. Please download them from FigShare and put them in the root directory of this repository as shown below.



/data
┣ 📂 HCP
┃   ┣ 📂 npz
┃   ┃   ┣ 📂 dummy_sub_01
┃   ┃   ┃   ┗ HCP_visual_voxel.npz
┃   ┃   ┣ 📂 dummy_sub_02
┃   ┃   ┃   ┗ ...

┣ 📂 Kamitani
┃   ┣ 📂 npz
┃   ┃   ┗ 📜 sbj_1.npz
┃   ┃   ┗ 📜 sbj_2.npz
┃   ┃   ┗ 📜 sbj_3.npz
┃   ┃   ┗ 📜 sbj_4.npz
┃   ┃   ┗ 📜 sbj_5.npz
┃   ┃   ┗ 📜 images_256.npz
┃   ┃   ┗ 📜 imagenet_class_index.json
┃   ┃   ┗ 📜 imagenet_training_label.csv
┃   ┃   ┗ 📜 imagenet_testing_label.csv

┣ 📂 BOLD5000
┃   ┣ 📂 BOLD5000_GLMsingle_ROI_betas
┃   ┃   ┣ 📂 py
┃   ┃   ┃   ┗ CSI1_GLMbetas-TYPED-FITHRF-GLMDENOISE-RR_allses_LHEarlyVis.npy
┃   ┃   ┃   ┗ ...
┃   ┃   ┃   ┗ CSIx_GLMbetas-TYPED-FITHRF-GLMDENOISE-RR_allses_xx.npy
┃   ┣ 📂 BOLD5000_Stimuli
┃   ┃   ┣ 📂 Image_Labels
┃   ┃   ┣ 📂 Scene_Stimuli
┃   ┃   ┣ 📂 Stimuli_Presentation_Lists


/pretrains
┣ 📂 ldm
┃   ┣ 📂 label2img  (ImageNet pre-trained label-conditioned LDM)
┃   ┃   ┗ 📜 config.yaml
┃   ┃   ┗ 📜 model.ckpt

┣ 📂 GOD
┃   ┗ 📜 fmri_encoder.pth  (SC-MBM pre-trained fMRI encoder)
┃   ┗ 📜 finetuned.pth     (finetuned fMRI encoder + finetuned LDM)

┣ 📂 BOLD5000
┃   ┗ 📜 fmri_encoder.pth  (SC-MBM pre-trained fMRI encoder)
┃   ┗ 📜 finetuned.pth     (finetuned fMRI encoder + finetuned LDM)


/code
┣ 📂 sc_mbm
┃   ┗ 📜 mae_for_fmri.py
┃   ┗ 📜 trainer.py
┃   ┗ 📜 utils.py

┣ 📂 dc_ldm
┃   ┗ 📜 ldm_for_fmri.py
┃   ┗ 📜 utils.py
┃   ┣ 📂 models
┃   ┃   ┗ (adopted from LDM)
┃   ┣ 📂 modules
┃   ┃   ┗ (adopted from LDM)

┗  📜 stageA1_mbm_pretrain.py   (main script for pre-training SC-MBM)
┗  📜 stageA2_mbm_finetune.py   (main script for tuning SC-MBM on fMRI only from test sets)
┗  📜 stageB_ldm_finetune.py    (main script for fine-tuning DC-LDM)
┗  📜 gen_eval.py               (main script for generating decoded images)

┗  📜 dataset.py                (functions for loading datasets)
┗  📜 eval_metrics.py           (functions for evaluation metrics)
┗  📜 config.py                 (configurations for the main scripts)
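Before running any stage, it can help to confirm that the downloaded folders landed where the tree above expects them. The following is a minimal sketch using only the standard library; the `REQUIRED` lists and the `missing_files` helper are hypothetical (not part of this repository) and only cover what we assume is needed to run Stage B and decoding with the released Subject 3 (GOD) and CSI4 (BOLD5000) checkpoints, so the HCP folder is not checked.

```python
from pathlib import Path

# Hypothetical list of paths (relative to the repository root) taken from the tree above.
# HCP is omitted because it is only needed for Stage A pre-training.
REQUIRED = {
    "GOD": [
        "data/Kamitani/npz/sbj_3.npz",
        "data/Kamitani/npz/images_256.npz",
        "data/Kamitani/npz/imagenet_class_index.json",
        "data/Kamitani/npz/imagenet_training_label.csv",
        "data/Kamitani/npz/imagenet_testing_label.csv",
        "pretrains/ldm/label2img/config.yaml",
        "pretrains/ldm/label2img/model.ckpt",
        "pretrains/GOD/fmri_encoder.pth",
        "pretrains/GOD/finetuned.pth",
    ],
    "BOLD5000": [
        "data/BOLD5000/BOLD5000_GLMsingle_ROI_betas/py",
        "data/BOLD5000/BOLD5000_Stimuli/Image_Labels",
        "data/BOLD5000/BOLD5000_Stimuli/Scene_Stimuli",
        "data/BOLD5000/BOLD5000_Stimuli/Stimuli_Presentation_Lists",
        "pretrains/ldm/label2img/config.yaml",
        "pretrains/ldm/label2img/model.ckpt",
        "pretrains/BOLD5000/fmri_encoder.pth",
        "pretrains/BOLD5000/finetuned.pth",
    ],
}


def missing_files(root, dataset):
    """Return the required paths for `dataset` that do not exist under `root`."""
    root = Path(root)
    return [p for p in REQUIRED[dataset] if not (root / p).exists()]
```

From the repository root, `missing_files(".", "GOD")` returns an empty list when everything in the list is in place; any returned path points at the folder that still needs to be extracted.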

Environment setup

Create and activate a conda environment named mind-vis from our env.yaml:

conda env create -f env.yaml
conda activate mind-vis

Download data and checkpoints

Due to size limits and license issues, the full fMRI pre-training dataset (required to replicate Stage A) needs to be downloaded from the official website of the Human Connectome Project (HCP). The pre-processing scripts are included in this repo.

We also provide checkpoints and finetuning data on FigShare so that finetuning and decoding can be run directly. Due to the size limit, we only release the checkpoints for Subject 3 of GOD and CSI4 of BOLD5000. Checkpoints for other subjects are available upon request. After downloading, extract data/ and pretrains/ into the project directory.

SC-MBM Pre-training on fMRI (Stage A)

The fMRI pre-training is performed with masked brain modeling on an fMRI dataset containing around 136,000 fMRI samples from 1205 subjects (HCP + GOD). To perform the pre-training from scratch with default parameters, run:

python code/stageA1_mbm_pretrain.py
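Since our masked brain modeling builds on Masked Autoencoders, pre-training follows the same basic recipe: the fMRI voxel vector is split into patches, a large share of the patches is hidden, and the encoder only sees the rest. The sketch below illustrates MAE-style random masking with the standard library. `patchify`, `random_mask`, the patch size of 16 and the 1600-voxel input are illustrative assumptions, not the code in code/sc_mbm/mae_for_fmri.py, and the sparse-coding part of SC-MBM is not shown.

```python
import random


def patchify(voxels, patch_size):
    """Split a 1D voxel vector into consecutive, non-overlapping patches.

    Trailing voxels that do not fill a whole patch are dropped in this sketch.
    """
    num_patches = len(voxels) // patch_size
    return [voxels[i * patch_size:(i + 1) * patch_size] for i in range(num_patches)]


def random_mask(num_patches, mask_ratio, rng=random):
    """MAE-style random masking.

    Shuffles the patch indices and keeps the first (1 - mask_ratio) share of them.
    Returns (kept, masked) index lists, each sorted in ascending order.
    """
    num_keep = int(num_patches * (1 - mask_ratio))
    order = list(range(num_patches))
    rng.shuffle(order)
    return sorted(order[:num_keep]), sorted(order[num_keep:])


# Usage with a hypothetical 1600-voxel recording and a patch size of 16.
rng = random.Random(0)
patches = patchify(list(range(1600)), patch_size=16)  # 100 patches
kept, masked = random_mask(len(patches), mask_ratio=0.75, rng=rng)
visible = [patches[i] for i in kept]  # only these patches are fed to the encoder
```

The `--mask_ratio` argument accepted by the pre-training script controls the same trade-off: a higher ratio leaves fewer visible patches to reconstruct from.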

Hyper-parameters can be changed with command-line arguments:

python code/stageA1_mbm_pretrain.py --mask_ratio 0.65 --num_epoch 800 --batch_size 200

Alternatively, the parameters can be changed in code/config.py.
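One common way to let command-line flags override the defaults kept in a config file is sketched below. `PretrainConfig`, its placeholder default values and `parse_overrides` are hypothetical and do not mirror code/config.py; only the three flag names (`--mask_ratio`, `--num_epoch`, `--batch_size`) come from the pre-training command shown earlier in this section.

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class PretrainConfig:
    # Placeholder defaults for illustration only; the real ones live in code/config.py.
    mask_ratio: float = 0.75
    num_epoch: int = 500
    batch_size: int = 100


def parse_overrides(argv=None):
    """Return a PretrainConfig where any flag given on the command line replaces the default."""
    parser = argparse.ArgumentParser()
    for field in fields(PretrainConfig):
        # default=None lets us tell "flag not given" apart from an explicit value
        parser.add_argument(f"--{field.name}", type=field.type, default=None)
    args = parser.parse_args(argv)
    config = PretrainConfig()
    for name, value in vars(args).items():
        if value is not None:
            setattr(config, name, value)
    return config
```

Using `None` as the parser default keeps the config class as the single source of default values, so a flag only matters when it is actually passed.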

Multi-GPU (DDP) training is supported; run with:

python -m torch.distributed.launch --nproc_per_node=NUM_GPUS code/stageA1_mbm_pretrain.py
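With DDP, the launcher starts one process per GPU and exports each process's rank and the total number of processes as the `RANK` and `WORLD_SIZE` environment variables. The sketch below reads them and shows how a dataset is commonly split across processes, mimicking PyTorch's `DistributedSampler` with shuffling disabled. `dist_info` and `shard_indices` are hypothetical helpers; we do not claim this is exactly how the pre-training script partitions its data.

```python
import math
import os


def dist_info():
    """Read the rank and world size exported by the launcher.

    Falls back to a single process (rank 0 of 1) when the variables are absent.
    """
    return int(os.environ.get("RANK", 0)), int(os.environ.get("WORLD_SIZE", 1))


def shard_indices(num_samples, rank, world_size):
    """Indices handled by one process, in the style of DistributedSampler(shuffle=False).

    The index list is padded by wrapping around so that every rank gets the same
    number of samples; rank r then takes every world_size-th index starting at r.
    Assumes num_samples > 0.
    """
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    indices = list(range(num_samples))
    padded = (indices * math.ceil(total / num_samples))[:total]
    return padded[rank:total:world_size]
```

Because of the wrap-around padding, a few samples may be processed twice per epoch when the dataset size is not divisible by the number of GPUs.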

The pre-training results will be saved locally at results/fmri_pretrain and remotely at wandb.

After pre-training on the large-scale fMRI dataset, we need to finetune the autoencoder with fMRI data from the testing set. Run the following:

python code/stageA2_mbm_finetune.py --dataset GOD --pretrain_mbm_path results/fmri_pretrain/RUN_FOLDER_NAME/checkpoints/checkpoint.pth

--dataset can be either GOD or BOLD5000, and RUN_FOLDER_NAME is the folder name generated during pre-training. For example:

python code/stageA2_mbm_finetune.py --dataset GOD --pretrain_mbm_path results/fmri_pretrain/01-08-2022-11:37:22/checkpoints/checkpoint.pth

The fMRI finetuning results will be saved locally at results/fmri_finetune and remotely at wandb.
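Because RUN_FOLDER_NAME is a timestamp, the newest run can be located programmatically instead of copying the folder name by hand. The sketch below assumes folder names follow the `%d-%m-%Y-%H:%M:%S` pattern suggested by the example 01-08-2022-11:37:22 (the day/month order is a guess) and that each run stores checkpoints/checkpoint.pth; `latest_checkpoint` is a hypothetical helper. The same layout appears to hold for results/fmri_pretrain and results/fmri_finetune.

```python
from datetime import datetime
from pathlib import Path

# Assumed run-folder naming, inferred from the example 01-08-2022-11:37:22.
RUN_FORMAT = "%d-%m-%Y-%H:%M:%S"


def latest_checkpoint(results_dir):
    """Return checkpoints/checkpoint.pth inside the newest timestamped run folder."""
    runs = []
    for folder in Path(results_dir).iterdir():
        if not folder.is_dir():
            continue
        try:
            runs.append((datetime.strptime(folder.name, RUN_FORMAT), folder))
        except ValueError:
            continue  # ignore folders that are not named after a timestamp
    if not runs:
        raise FileNotFoundError(f"no timestamped run folders in {results_dir}")
    _, newest = max(runs, key=lambda run: run[0])
    return newest / "checkpoints" / "checkpoint.pth"
```

For example, `latest_checkpoint("results/fmri_pretrain")` returns a path that can be passed to `--pretrain_mbm_path`.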

Finetune the Double-Conditioned LDM with Pre-trained fMRI Encoder (Stage B)

In this stage, the cross-attention heads and the pre-trained fMRI encoder are jointly optimized with fMRI-image pairs, and decoded images are generated. This stage can be run without downloading HCP; only the finetuning datasets and the pre-trained fMRI encoder shared via our FigShare link are required. Run this stage with our provided pre-trained fMRI encoder and default parameters:

python code/stageB_ldm_finetune.py --dataset GOD

--dataset can be either GOD or BOLD5000. The results and generated samples will be saved locally at results/generation and remotely at wandb.
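In Stage B only the fMRI encoder and the cross-attention heads are updated; we assume the remaining LDM weights stay frozen. A minimal sketch of such selective freezing follows. The parameter-name patterns are assumptions borrowed from CompVis LDM naming (`cond_stage_model` for the conditioning encoder, `attn2` for cross-attention inside transformer blocks) and may not match ldm_for_fmri.py; `freeze_except_cross_attention` is hypothetical. It accepts any iterable of `(name, parameter)` pairs, such as the output of `torch.nn.Module.named_parameters()`.

```python
def freeze_except_cross_attention(named_parameters,
                                  encoder_prefix="cond_stage_model.",
                                  cross_attn_key=".attn2."):
    """Enable gradients only for the fMRI encoder and the cross-attention layers.

    Each parameter only needs a writable `requires_grad` attribute, as torch
    parameters have. Returns the names of the parameters left trainable.
    """
    trainable = []
    for name, param in named_parameters:
        train = name.startswith(encoder_prefix) or cross_attn_key in name
        param.requires_grad = train
        if train:
            trainable.append(name)
    return trainable
```

An optimizer would then be built only from the parameters whose `requires_grad` is still `True`.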

Run with a custom pre-trained fMRI encoder and custom parameters:

python code/stageB_ldm_finetune.py --dataset GOD --pretrain_mbm_path results/fmri_finetune/RUN_FOLDER_NAME/checkpoints/checkpoint.pth --num_epoch 500 --batch_size 5

Run fMRI Decoding and Generate Images with Trained Checkpoints

Only the finetuning datasets and the trained checkpoints from our FigShare link are required. Note that images generated with the provided checkpoints give the same evaluation results as in the paper, but may not be exactly the same images as in the paper due to sampling variance. Run this stage with our provided checkpoints:

python code/gen_eval.py --dataset GOD

--dataset can be either GOD or BOLD5000. The results and generated samples will be saved locally at results/eval and remotely at wandb.
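The 100-way semantic classification metric mentioned in the abstract can be read as an n-way top-1 accuracy: a pre-trained ImageNet classifier scores each generated image, and a trial counts as correct when the ground-truth class beats n - 1 randomly drawn other classes. The sketch below implements that idea for a single probability vector. `n_way_top1`, the default of 50 trials and the assumption that the ground-truth class id is already known are ours, not the implementation in code/eval_metrics.py.

```python
import random


def n_way_top1(probs, gt_class, n_way=100, trials=50, rng=random):
    """Estimate n-way top-1 accuracy for one generated image.

    probs: class probabilities from a pre-trained classifier, indexed by class id.
    gt_class: the class id treated as ground truth for this stimulus.
    Each trial draws n_way - 1 distinct distractor classes; it is a hit when
    gt_class has the highest probability among the n_way candidates.
    """
    distractors = [c for c in range(len(probs)) if c != gt_class]
    hits = 0
    for _ in range(trials):
        candidates = rng.sample(distractors, n_way - 1) + [gt_class]
        # max() returns the first maximum, so a distractor tied with gt_class wins
        best = max(candidates, key=lambda c: probs[c])
        hits += best == gt_class
    return hits / trials
```

Averaging this score over all test images gives a dataset-level accuracy; chance level is 1 / n_way.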


Acknowledgement

We thank Kamitani Lab, Weizmann Vision Lab and the BOLD5000 team for making their raw and pre-processed data public. Our Masked Brain Modeling implementation is based on the Masked Autoencoders by Facebook Research. Our Conditional Latent Diffusion Model implementation is based on the Latent Diffusion Model implementation from CompVis. We thank these authors for making their code and checkpoints publicly available!