Visualizing SAM Mask decoder
nahidalam opened this issue · 0 comments
nahidalam commented
I was looking into your tutorial on visualizing DINO's self-attention and am planning to do something similar for the attention heads of SAM's mask decoder. Based on the SAM paper and the SAM codebase, there should be 8 attention heads, but the code below shows only 1 attention head:
```python
from PIL import Image
from transformers import ViTFeatureExtractor
from transformers import SamModel, SamProcessor
from transformers import SamVisionConfig, SamConfig, SamPromptEncoderConfig, SamMaskDecoderConfig

# load any RGB test image
image = Image.open("example.jpg")

# image feature extraction
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/sam-vit-base", size=1024)
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

# define SAM configs
vision_config = SamVisionConfig(patch_size=16)
prompt_encoder_config = SamPromptEncoderConfig()
mask_decoder_config = SamMaskDecoderConfig()
samconfig = SamConfig(vision_config, prompt_encoder_config, mask_decoder_config)

# define model
model = SamModel.from_pretrained("facebook/sam-vit-base", config=samconfig)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# forward pass
outputs = model(pixel_values, output_attentions=True, interpolate_pos_encoding=True, return_dict=True)
```
The 2nd dimension of the returned attention tensor should be 8, not 1. Am I missing something?
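As a quick sanity check on the expected head count, the mask decoder's configured number of attention heads can be read directly off the config object. This is a minimal sketch assuming the Hugging Face `SamMaskDecoderConfig` defaults (which should match `facebook/sam-vit-base`):

```python
from transformers import SamMaskDecoderConfig

# default mask-decoder config; the SAM paper specifies 8 heads,
# and the Hugging Face config class defaults are meant to match
cfg = SamMaskDecoderConfig()
print(cfg.num_attention_heads)
```

If this prints 8 while the attention tensors from the forward pass have a head dimension of 1, the discrepancy would be in how the attentions are collected or returned, not in the configuration.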