This repository contains a PyTorch implementation of Grad-CAM (Gradient-weighted Class Activation Mapping) for Vision Transformers (ViT). Grad-CAM is a technique for visualizing the regions of an image that contribute most to the prediction of a particular class. It works by computing the gradients of the target class score with respect to the output of a ViT norm layer and using those gradients to weight that layer's activations.
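To make the weighting step concrete, here is a minimal NumPy sketch of the Grad-CAM computation on ViT tokens. It is an illustration under stated assumptions, not the code in `gradcam.py`: the function name `grad_cam_from_tokens` is hypothetical, and random arrays stand in for the activations and gradients that would normally be captured from the model. For `vit_base_patch16_224`, a 224x224 input yields a 14x14 grid of patches plus one class token, so 197 tokens of 768 channels each.

```python
import numpy as np


def grad_cam_from_tokens(activations, gradients, grid=(14, 14)):
    """Hypothetical helper: turn token activations and gradients into a heatmap.

    activations, gradients: arrays of shape (1 + H*W, C), class token first.
    Returns an (H, W) map scaled to [0, 1].
    """
    h, w = grid
    # Drop the class token and restore the 2D patch layout.
    acts = activations[1:].reshape(h, w, -1)
    grads = gradients[1:].reshape(h, w, -1)
    # One weight per channel: the gradient averaged over all patch positions.
    weights = grads.mean(axis=(0, 1))
    # Weighted sum over channels, then ReLU to keep only positive evidence.
    cam = np.maximum((acts * weights).sum(axis=-1), 0.0)
    # Scale to [0, 1]; an all-zero map is returned unchanged.
    cam = cam - cam.min()
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Random stand-ins for real activations and gradients (assumption).
rng = np.random.default_rng(0)
acts = rng.standard_normal((197, 768))
grads = rng.standard_normal((197, 768))
cam = grad_cam_from_tokens(acts, grads)
print(cam.shape)
```

The ReLU keeps only the patches that push the class score up, which is what makes the map class-specific.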
- Python 3.x
- PyTorch
- Torchvision
- Timm
- NumPy
- OpenCV
- Scikit-image
- Clone this repository and move into it: `git clone https://github.com/Mikael17125/ViT-GradCAM.git`, then `cd ViT-GradCAM`
- Create the conda environment: `conda create --name vit-grad-cam python=3.10`
- Activate the environment: `conda activate vit-grad-cam`
- Install PyTorch: `conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia`
- Install the additional dependencies: `pip install -r requirements.txt`
- Ensure that your image is named `both.png` and placed in the root directory of the repository.
- Run the `main.py` script: `python main.py`
- The resulting heatmap with the Grad-CAM overlay will be saved as `result.jpg` in the same directory.
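The overlay step can be sketched as follows. This is a NumPy-only illustration, not the repository's implementation, which may rely on OpenCV for resizing and color mapping. The function name `overlay_heatmap`, the two-color map, and the `alpha` blend factor are all assumptions made for the example.

```python
import numpy as np


def overlay_heatmap(image, cam, alpha=0.5):
    """Hypothetical helper: blend a coarse heatmap onto an image.

    image: (H, W, 3) float array in [0, 1], RGB order.
    cam:   (h, w) float array in [0, 1]; H and W must be multiples of h and w.
    Returns a uint8 array of shape (H, W, 3).
    """
    H, W, _ = image.shape
    h, w = cam.shape
    # Nearest-neighbour upsampling: repeat each patch value over its pixel block.
    up = np.kron(cam, np.ones((H // h, W // w)))
    # Simple two-colour map (assumption): blue for low values, red for high.
    heat = np.stack([up, np.zeros_like(up), 1.0 - up], axis=-1)
    blended = (1.0 - alpha) * image + alpha * heat
    return (np.clip(blended, 0.0, 1.0) * 255).astype(np.uint8)


# A black 224x224 image with a single hot patch in the top-left corner.
image = np.zeros((224, 224, 3))
cam = np.zeros((14, 14))
cam[0, 0] = 1.0
out = overlay_heatmap(image, cam)
print(out.shape, out.dtype)
```

Nearest-neighbour upsampling leaves visible 16x16 blocks; a smoother interpolation would give a softer overlay at the cost of an extra dependency.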
- `main.py`: Python script that generates a Grad-CAM visualization for a given image using a pre-trained Vision Transformer model (`vit_base_patch16_224`).
- `gradcam.py`: Python module containing the `GradCam` class, which implements the Grad-CAM algorithm.
- `both.png`: Sample input image (replace with your own image).
- `result.jpg`: Output Grad-CAM visualization.
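For readers curious how a class like `GradCam` can capture activations and gradients, the sketch below shows one common pattern based on PyTorch hooks. It is not the code in `gradcam.py`: the names `HookedGradCam` and `TinyTokenModel` are hypothetical, and the tiny model merely stands in for a ViT so that the example runs without downloading weights. Which norm layer the repository actually hooks is not stated here, so the target layer is also an assumption.

```python
import torch
import torch.nn as nn


class TinyTokenModel(nn.Module):
    """Stand-in for a ViT (assumption): tokens -> LayerNorm -> mean pool -> head."""

    def __init__(self, dim=8, classes=3):
        super().__init__()
        self.embed = nn.Linear(4, dim)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, classes)

    def forward(self, x):  # x: (batch, tokens, 4), class token first
        tokens = self.norm(self.embed(x))
        return self.head(tokens.mean(dim=1))


class HookedGradCam:
    """Hypothetical sketch of a hook-based Grad-CAM wrapper."""

    def __init__(self, model, target_layer):
        self.model = model
        self.activations = None
        self.gradients = None
        target_layer.register_forward_hook(self._save_activations)

    def _save_activations(self, module, inputs, output):
        self.activations = output.detach()
        # The tensor hook fires during backward with d(score)/d(output).
        output.register_hook(self._save_gradients)

    def _save_gradients(self, grad):
        self.gradients = grad.detach()

    def __call__(self, x, class_idx=None):
        self.model.zero_grad()
        logits = self.model(x)
        if class_idx is None:
            class_idx = int(logits.argmax(dim=1)[0])
        logits[0, class_idx].backward()
        # Drop the class token, then average gradients over the patch tokens.
        acts = self.activations[0, 1:]
        grads = self.gradients[0, 1:]
        weights = grads.mean(dim=0)
        cam = torch.relu((acts * weights).sum(dim=-1))
        side = int(cam.numel() ** 0.5)  # assumes a square patch grid
        cam = cam.reshape(side, side)
        return cam / cam.max().clamp(min=1e-8)


torch.manual_seed(0)
model = TinyTokenModel().eval()
grad_cam = HookedGradCam(model, model.norm)
cam = grad_cam(torch.randn(1, 5, 4))  # 1 class token + 4 patch tokens
print(cam.shape)
```

With a real ViT the same wrapper would receive the model and one of its norm layers; the choice of layer matters, because a layer whose patch tokens do not feed the classification head receives zero gradient.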
Input Image | Output Image |
---|---|
- Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, and Dhruv Batra. "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." In Proceedings of the IEEE International Conference on Computer Vision (ICCV), 2017.