NielsRogge/Transformers-Tutorials

Visualizing SAM Mask decoder

nahidalam opened this issue

I was looking at your tutorial on visualizing DINO's self-attention.

I'm planning to do something similar to visualize the attention heads of SAM's mask decoder. Based on the SAM paper and the SAM codebase, the mask decoder should have 8 attention heads, but the code below shows only 1:

```python
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from transformers import SamVisionConfig, SamConfig, SamPromptEncoderConfig, SamMaskDecoderConfig

# load an image (placeholder path: replace with your own image)
image = Image.open("example.jpg").convert("RGB")

# image preprocessing: SamProcessor resizes the longest side to 1024 and pads to 1024x1024
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
pixel_values = processor(images=image, return_tensors="pt").pixel_values

# define SAM configs (SamMaskDecoderConfig defaults to num_attention_heads=8)
vision_config = SamVisionConfig(patch_size=16)
prompt_encoder_config = SamPromptEncoderConfig()
mask_decoder_config = SamMaskDecoderConfig()
samconfig = SamConfig(
    vision_config=vision_config,
    prompt_encoder_config=prompt_encoder_config,
    mask_decoder_config=mask_decoder_config,
)

# define model
model = SamModel.from_pretrained("facebook/sam-vit-base", config=samconfig)
model.eval()

# forward pass (inputs are already 1024x1024, so no positional-encoding interpolation is needed)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values, output_attentions=True, return_dict=True)
```
[Screenshot, 2024-01-25: the attention tensor returned by the forward pass above, with a second dimension of 1]

The second dimension of the tensor above should be 8, not 1.
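The expectation that an axis of size 8 should appear can be made concrete with a framework-free NumPy sketch. This is not SAM's or `transformers`' implementation. Per-head attention weights carry a head axis of length `num_heads`, whereas the attention *output*, once the heads are merged back into the hidden dimension, has no head axis at all. `multi_head_attention_weights` and `looks_like_attention_weights` are hypothetical helpers. The sketch omits the learned query/key/value projections, and the token counts (5 query tokens, a 64x64 image grid), the hidden size of 256, and the extra size-1 axis are illustrative assumptions.

```python
import numpy as np


def multi_head_attention_weights(queries, keys, num_heads):
    """Scaled dot-product attention weights with an explicit head axis.

    queries: (batch, num_queries, hidden); keys: (batch, num_keys, hidden).
    Returns an array of shape (batch, num_heads, num_queries, num_keys).
    Simplification: no learned query/key projections are applied.
    """
    batch, num_queries, hidden = queries.shape
    num_keys = keys.shape[1]
    if hidden % num_heads != 0:
        raise ValueError("hidden size must be divisible by num_heads")
    head_dim = hidden // num_heads
    # Split channels into (num_heads, head_dim), then move the head axis to position 1.
    q = queries.reshape(batch, num_queries, num_heads, head_dim).transpose(0, 2, 1, 3)
    k = keys.reshape(batch, num_keys, num_heads, head_dim).transpose(0, 2, 1, 3)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
    # Numerically stable softmax over the key axis.
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


def looks_like_attention_weights(array, num_heads, atol=1e-4):
    """Heuristic: is `array` a (batch, heads, queries, keys) softmax tensor?"""
    array = np.asarray(array)
    if array.ndim != 4 or array.shape[1] != num_heads:
        return False
    if (array < 0).any():
        return False
    # Softmax rows must sum to 1 over the key axis.
    return bool(np.allclose(array.sum(axis=-1), 1.0, atol=atol))


rng = np.random.default_rng(0)
hidden, heads = 256, 8  # 8 heads, as expected for the mask decoder
tokens = rng.normal(size=(1, 5, hidden))              # hypothetical query tokens
image_tokens = rng.normal(size=(1, 64 * 64, hidden))  # hypothetical 64x64 image grid

weights = multi_head_attention_weights(tokens, image_tokens, heads)
print(weights.shape)                                 # (1, 8, 5, 4096)
print(looks_like_attention_weights(weights, heads))  # True

# Attention output: apply the weights to per-head values, then merge heads again.
values = image_tokens.reshape(1, 64 * 64, heads, hidden // heads).transpose(0, 2, 1, 3)
output = (weights @ values).transpose(0, 2, 1, 3).reshape(1, 5, hidden)
output = output[:, None]  # hypothetical extra batch-like axis of size 1
print(output.shape)                                  # (1, 1, 5, 256)
print(looks_like_attention_weights(output, heads))   # False
```

To apply the heuristic to a PyTorch tensor returned by the forward pass, convert it first with `.detach().cpu().numpy()`. A `False` result only says the tensor is not laid out as per-head softmax weights; it does not by itself identify what the tensor actually contains.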

Am I missing something?