
PyTorch implementation of "Brain Decodes Deep Nets"


Brain Decodes Deep Nets
PyTorch Implementation

We visualize pre-trained vision models by mapping them onto the brain, thus exposing their hidden internals. The visualization is a by-product of a brain encoding model, which predicts brain fMRI measurements in response to images.


Brain Decodes Deep Nets
Huzheng Yang, James Gee*, Jianbo Shi*
University of Pennsylvania

This is follow-up work to Memory Encoding Model, a winning method of the Algonauts 2023 challenge. However, we go in the opposite direction: the challenge is about understanding the brain, whereas this work uses brain data to explain deep nets.

We provide a plug-and-play API (example.ipynb) that maps your vision model to the brain in about 30 minutes.

Methods

The brain encoding model in a nutshell:

  1. Input an image and extract features from a pre-trained deep net.
  2. Select features for each brain voxel (FactorTopy).
  3. Apply a linear transformation to the selected features to predict each brain voxel's response.

The intuition behind our visualization: each brain voxel asks the question, "Which network layer/space/scale/channel best predicts my brain response?"
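To make the voxel's question concrete, here is a minimal NumPy sketch of a deliberately simplified version: fit a separate ridge readout from each layer's features, then let every voxel pick the layer whose held-out predictions correlate best with its measured response. This hard, argmax-based selection is an illustration only and is not the repository's FactorTopy selector; the helper names, `alpha`, and the train/test split are assumptions.

```python
import numpy as np


def fit_ridge(X, Y, alpha=1.0):
    """Closed-form ridge regression W = (X^T X + alpha*I)^-1 X^T Y; one column of W per voxel."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + alpha * np.eye(d), X.T @ Y)


def pearson_per_column(a, b, eps=1e-8):
    """Pearson correlation between matching columns of a and b (one value per voxel)."""
    a = a - a.mean(axis=0)
    b = b - b.mean(axis=0)
    return (a * b).sum(axis=0) / (np.linalg.norm(a, axis=0) * np.linalg.norm(b, axis=0) + eps)


def best_layer_per_voxel(layer_feats, fmri, n_train, alpha=1.0):
    """Hypothetical helper: for each voxel, return the layer whose features best predict it.

    layer_feats: dict layer_name -> [N, C] features (one row per image)
    fmri:        [N, V] voxel responses for the same N images
    n_train:     the first n_train images fit the readout, the rest score it
    """
    names = list(layer_feats)
    scores = np.empty((len(names), fmri.shape[1]))
    for i, name in enumerate(names):
        X = layer_feats[name]
        W = fit_ridge(X[:n_train], fmri[:n_train], alpha)  # [C, V] linear readout
        pred = X[n_train:] @ W  # held-out predictions, [N - n_train, V]
        scores[i] = pearson_per_column(pred, fmri[n_train:])
    best = scores.argmax(axis=0)  # index of the winning layer for each voxel
    return [names[j] for j in best], scores


# Toy data: voxel 0 is driven by "early" features, voxel 1 by "late" features.
rng = np.random.default_rng(0)
feats = {"early": rng.normal(size=(200, 8)), "late": rng.normal(size=(200, 8))}
fmri = np.stack([feats["early"] @ rng.normal(size=8),
                 feats["late"] @ rng.normal(size=8)], axis=1)
fmri += 0.1 * rng.normal(size=fmri.shape)
layers, scores = best_layer_per_voxel(feats, fmri, n_train=150)
print(layers)  # ['early', 'late']
```

Per-voxel argmax over layers is the crudest possible answer to the voxel's question; it only extends to space, scale, and channel by enumerating many more candidate feature sets, which is one reason a learned selector is attractive.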

Results

Our analysis and visualization show:

  1. The inner-layer layouts of supervised and unsupervised models are different.
  2. Larger models have less efficient inner-layer layouts.
  3. Fine-tuning on small datasets changes the layer layouts.

API: Our API (example.ipynb) is plug-and-play for most image backbone models; it takes only about 30 minutes to run (RTX 4090, 8 GB VRAM).



This repository contains:

Data preparation

Algonauts 2023

Manually download and unzip subj01.zip (4 GB) from the Algonauts 2023 challenge; to get the download link, please fill in this form.

The provided dataset implementation expects the following directory layout:

<ROOT>/training_split
<ROOT>/training_split/training_fmri
<ROOT>/training_split/training_images
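Before loading, a quick sanity check of the layout above can save a failed run. This is a hypothetical helper, not part of brainnet; it only checks that the three expected directories exist under ROOT.

```python
from pathlib import Path


def check_alg23_layout(root):
    """Hypothetical helper: return the expected sub-directories that are missing under root."""
    root = Path(root)
    expected = [
        root / "training_split",
        root / "training_split" / "training_fmri",
        root / "training_split" / "training_images",
    ]
    return [str(p) for p in expected if not p.is_dir()]


missing = check_alg23_layout("/data/download/alg23/subj01")
if missing:
    print("Missing directories:", *missing, sep="\n  ")
```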

The dataset can be loaded with the following Python code:

from brainnet.dataset import BrainDataset

ROOT = "/data/download/alg23/subj01"  # path to the unzipped subj01 folder
dataset = BrainDataset(ROOT)
img, fmri = dataset[0]  # one image and its fMRI response
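If you want to inspect the raw fMRI arrays outside of BrainDataset, the sketch below loads them with NumPy. It assumes the Algonauts 2023 release stores the two hemispheres as lh_training_fmri.npy and rh_training_fmri.npy (rows are images, columns are voxels) and that sorting image file names matches the fMRI row order; check both against your unzipped folder. How BrainDataset itself combines the hemispheres is not documented here, so treat this as an independent, hypothetical loader.

```python
from pathlib import Path

import numpy as np


def load_alg23_fmri(root):
    """Hypothetical loader: stack both hemispheres into one [n_images, n_voxels] array.

    Assumes the file names lh_training_fmri.npy / rh_training_fmri.npy under
    <ROOT>/training_split/training_fmri.
    """
    fmri_dir = Path(root) / "training_split" / "training_fmri"
    lh = np.load(fmri_dir / "lh_training_fmri.npy")
    rh = np.load(fmri_dir / "rh_training_fmri.npy")
    if lh.shape[0] != rh.shape[0]:
        raise ValueError("hemispheres must cover the same images")
    return np.concatenate([lh, rh], axis=1)  # left-hemisphere voxels first


def list_training_images(root):
    """Sorted .png paths; sorted order is assumed to match the fMRI row order."""
    img_dir = Path(root) / "training_split" / "training_images"
    return sorted(img_dir.glob("*.png"))


# Usage (needs the real data): fmri = load_alg23_fmri(ROOT); images = list_training_images(ROOT)
```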


  • Note: the folding-brain animation in this README was made with the webview function from pycortex; the animation is a screen recording of a static website created by the cortex.webgl functions.

Installation

Training the brain encoding model requires PyTorch 2.0, pytorch-lightning, and torchmetrics. The implemented pre-trained models require dinov2, open_clip, timm, and segment_anything. To set up all the required dependencies for the brain encoding model and the pre-trained models, follow the instructions below:

conda - Clone the repository and then create and activate a brainnet conda environment using the provided environment definition:

conda env create -f conda_env.yml 
conda activate brainnet

docker - Pull the image from Docker Hub; the Docker image (15 GB) contains all the packages.

docker pull huzeeee/afo:latest
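After activating the conda environment (or starting the Docker container), a quick check that the dependencies are importable can catch a broken install early. The module names below are the usual import names for the packages listed above (an assumption worth confirming); dinov2 is loaded through torch.hub in this README, so it is not checked.

```python
import importlib.util

# Import names assumed for the packages listed in the installation section.
REQUIRED = ["torch", "pytorch_lightning", "torchmetrics", "open_clip", "timm", "segment_anything"]


def missing_modules(names):
    """Return the top-level module names that cannot be found in the current environment."""
    return [n for n in names if importlib.util.find_spec(n) is None]


missing = missing_modules(REQUIRED)
print("all dependencies found" if not missing else f"missing: {missing}")
```

`find_spec` only locates each module without importing it, so the check is fast even for heavy packages like torch.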

Brain-to-Network Mapping

The brain-to-network mapping is a by-product of the brain encoding model. plmodel.py is a PyTorch Lightning implementation of the training code. The following is a minimal example that runs the training:

from brainnet.plmodel import PLModel
from brainnet.config import get_cfg_defaults
from brainnet.backbone import ModifiedCLIP
import pytorch_lightning as pl

cfg = get_cfg_defaults()
cfg.DATASET.DATA_DIR = '/data/huze/download/alg23/subj01'  # path to the unzipped subj01 folder

backbone = ModifiedCLIP()  # pre-trained image backbone that implements get_tokens()
plmodel = PLModel(cfg, backbone)  # brain encoding model as a LightningModule

trainer = pl.Trainer()
trainer.fit(plmodel)



API for analyzing your model

We have implemented six models in brainnet.backbone, and all of them plug directly into the brain encoding model. To plug in your own model, implement a model.get_tokens() method as in the following code block. During training, the brain encoding model (in plmodel.forward) calls get_tokens() to obtain the local and global tokens from your pre-trained model.

A complete API example is in example.ipynb.

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange


class ModifiedDiNOv2(nn.Module):
    def __init__(self, ver="dinov2_vitb14", **kwargs) -> None:
        super().__init__()
        # Load a pre-trained DINOv2 ViT from torch.hub
        self.vision_model = torch.hub.load("facebookresearch/dinov2", ver)

    # get_tokens() needs to be implemented by the user.
    def get_tokens(self, x):
        """
        return:
        local_tokens: dict, key: str, value: [B, C, H, W]
        global_tokens: dict, key: str, value: [B, C]
        """
        # Patch-embed the image and prepend the CLS token: [B, 1 + N, C]
        x = self.vision_model.prepare_tokens_with_masks(x)

        local_tokens = {}
        global_tokens = {}
        for i, blk in enumerate(self.vision_model.blocks):
            x = blk(x)
            saved_x = x.clone()
            global_tokens[str(i)] = saved_x[:, 0, :]  # CLS token, [B, C]
            saved_x = saved_x[:, 1:, :]  # remove CLS token, [B, N, C]
            p = int(np.sqrt(saved_x.shape[1]))  # assumes a square patch grid, N = p * p
            saved_x = rearrange(saved_x, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
            local_tokens[str(i)] = saved_x  # [B, C, H, W]

        return local_tokens, global_tokens
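When writing get_tokens() for a new backbone, it helps to check its output against the documented return contract before launching training. The sketch below is hypothetical (not part of brainnet): `tokens_to_grid` is a NumPy equivalent of the `rearrange(..., "b (p1 p2) c -> b c p1 p2")` call above, and `check_tokens` asserts string keys, `[B, C, H, W]` local tokens, `[B, C]` global tokens, and one shared batch size. `check_tokens` only reads `.shape`, so it also accepts torch tensors; `tokens_to_grid` is NumPy-only (in PyTorch you would use `.permute(0, 3, 1, 2)`).

```python
import numpy as np


def tokens_to_grid(tokens):
    """NumPy equivalent of rearrange(x, "b (p1 p2) c -> b c p1 p2"): [B, N, C] -> [B, C, p, p]."""
    b, n, c = tokens.shape
    p = int(round(n ** 0.5))
    if p * p != n:
        raise ValueError(f"{n} patch tokens do not form a square grid")
    return tokens.reshape(b, p, p, c).transpose(0, 3, 1, 2)


def check_tokens(local_tokens, global_tokens):
    """Hypothetical checker for the documented get_tokens() return contract."""
    batch_sizes = set()
    for key, value in local_tokens.items():
        assert isinstance(key, str), "local_tokens keys should be strings"
        assert len(value.shape) == 4, f"local_tokens[{key!r}] should be [B, C, H, W]"
        batch_sizes.add(value.shape[0])
    for key, value in global_tokens.items():
        assert isinstance(key, str), "global_tokens keys should be strings"
        assert len(value.shape) == 2, f"global_tokens[{key!r}] should be [B, C]"
        batch_sizes.add(value.shape[0])
    assert len(batch_sizes) <= 1, f"inconsistent batch sizes: {batch_sizes}"


# Fake ViT output for one layer: CLS token + 16 patch tokens, batch of 2, 8 channels.
fake = np.random.default_rng(0).normal(size=(2, 1 + 16, 8))
local = {"0": tokens_to_grid(fake[:, 1:, :])}
glob = {"0": fake[:, 0, :]}
check_tokens(local, glob)
print(local["0"].shape)  # (2, 8, 4, 4)
```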

Automatic caching

Caching trades memory for speed. PLModel(cached=True) enables automatic caching of the local_tokens and global_tokens returned by the image backbone model. The tokens are computed only once and kept in RAM (not VRAM); a 12-layer, 768-dim model takes ~20 GB of RAM for caching. Nothing is written to the hard disk: the cache lives in RAM and is deleted after the run.
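The idea behind caching can be sketched in a few lines: call get_tokens() once per image index and serve repeats from a dictionary in process memory. This is not the PLModel(cached=True) implementation; `TokenCache` and `cache_bytes` are hypothetical names. The size estimate assumes one `[C, H, W]` local map plus one `[C]` global token per layer per image; the real footprint depends on the token grid size, dtype, and number of images, which are not stated here.

```python
class TokenCache:
    """Hypothetical in-RAM cache: call backbone.get_tokens() once per image index."""

    def __init__(self, backbone):
        self.backbone = backbone
        self.store = {}  # image index -> (local_tokens, global_tokens); never written to disk

    def get(self, index, image):
        if index not in self.store:
            # With PyTorch you would move the returned tensors to CPU (t.cpu()) before
            # storing them, so the cache occupies RAM rather than VRAM.
            self.store[index] = self.backbone.get_tokens(image)
        return self.store[index]


def cache_bytes(n_images, n_layers, channels, height, width, bytes_per_value=4):
    """Rough cache size: per layer and image, C*H*W local values plus C global values."""
    per_layer = channels * (height * width + 1)
    return n_images * n_layers * per_layer * bytes_per_value


print(cache_bytes(10, 2, 4, 3, 3))  # 3200
```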


Acknowledgement

We thank the Natural Scenes Dataset team and the Algonauts 2023 organizers for providing the data. Compute resources were provided by the Penn GRASP Laboratory and the Penn Image Computing & Science Laboratory.

License

The code and model weights are licensed under CC-BY-NC. See LICENSE.txt for details.

BibTeX

@article{yang_brain_2023,
  title={Brain Decodes Deep Nets},
  author={Yang, Huzheng and Gee, James and Shi, Jianbo},
  year={2023},
  journal={arXiv preprint arXiv:2312.01280},
}