facebookresearch/segment-anything

Extracted embedding result does not make sense

benam2 opened this issue · 4 comments

Hi,
I'm trying to extract embeddings from the SAM model (the image encoder part only). To gauge the quality of these embeddings, I saved them in a data frame and ranked the top 5 most similar for each image. I had previously annotated a sample of 50 data points for evaluating the model. Surprisingly, the performance is quite poor, even though the images are not particularly challenging. I wonder if perhaps I'm doing something wrong?

Appreciate any input

Here is the main part of my code that extracts the embedding:

```python
import requests
from io import BytesIO
import cv2
import numpy as np
import torch

# predictor is a SamPredictor set up earlier; row comes from the data frame
response = requests.get(row['URL'], stream=True)
if response.status_code == 200:
    image_bytes = BytesIO(response.content)
    image_array = np.asarray(bytearray(image_bytes.read()), dtype=np.uint8)
    img = cv2.imdecode(image_array, cv2.IMREAD_COLOR)

    predictor.set_image(img)
    with torch.inference_mode():  # torch.no_grad()
        feature = predictor.features
        tensor_flattened = feature.view(-1)[:256]
```

I guess it depends on what part of the embedding you want to work with. One change I would recommend is moving the `set_image` call inside the `inference_mode` block, since that will stop gradients from being tracked on the image features.
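For example, something like this (just a rough sketch, reusing the variable names from your snippet):

```python
# set_image runs the (heavy) image encoder, so wrapping it in
# inference_mode avoids tracking gradients for those features
with torch.inference_mode():
    predictor.set_image(img)
    feature = predictor.features  # shape: (1, 256, 64, 64)
```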

Aside from that, the features you're indexing seem strange (to me at least). I think it's equivalent to:
`feature[0, 0, :4, :]` <--> `feature.view(-1)[:256]`

The output of SAM should be a 64x64 grid of 'pixels', each with 256 channels. So the indexing here would be interpreted as something like: 'take the first channel value of all (64) columns of the first 4 rows'. If you're trying to compare the first few rows specifically, then this is probably fine, but if you meant to check something else then maybe this is the problem?
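You can sanity-check that equivalence directly, for example:

```python
# both slices should match for a contiguous (1, 256, 64, 64) feature tensor
a = feature.view(-1)[:256]
b = feature[0, 0, :4, :].reshape(-1)
print(torch.equal(a, b))  # expect: True
```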

Thank you @heyoeyo for the detailed explanation, appreciate it. You are right, the extracted part does not make sense. I want to get the image representation from the encoder. However, I need to keep the dimensionality low (below 1000). Probably a stupid question, but any idea how to do that?

I'm benchmarking DINOv2 image representations vs. SAM, and I have already annotated the data to see which one performs better.

It could be tricky, since the SAM model doesn't include a global/cls token. The simplest (probably bad) idea I can think of would be to average the tokens spatially, into a single 256 length vector. No idea if that does anything useful to be honest, but it's easy to try at least.
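As a rough sketch of what I mean (assuming the usual 1x256x64x64 SAM feature shape):

```python
with torch.inference_mode():
    predictor.set_image(img)
    feature = predictor.features  # (1, 256, 64, 64)
    # average over the 64x64 spatial grid -> one 256-length vector per image
    avg_embedding = feature.mean(dim=(2, 3)).squeeze(0)  # (256,)
```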

The dinov2 repo/paper actually has this idea of using PCA to extract meaningful patch features from the model output, and maybe something similar (i.e. reducing the SAM features using PCA) could work for what you're trying to do. There's a notebook on the dinov2 repo (under a pull request that you'd have to grab) that might help with how to implement something similar.
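Just as a sketch of the idea (using sklearn's PCA here, fit on a single image's patch tokens; for image-to-image comparison you'd probably want to fit the PCA across your whole dataset instead):

```python
from sklearn.decomposition import PCA

# reshape the (1, 256, 64, 64) features into 4096 patch tokens of length 256
tokens = predictor.features.squeeze(0).flatten(1).T.cpu().numpy()  # (4096, 256)

pca = PCA(n_components=3)
reduced = pca.fit_transform(tokens)  # (4096, 3) low-dim projection per patch
```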

And keeping with the dinov2 theme, they have a paper "Vision Transformers Need Registers" that suggests high-norm tokens end up in the patches (for models without registers) and encode global information. It's sort of a hack idea, but you might be able to pull high norm tokens off the SAM features and use them as if they are good global encodings of the image (though SAM doesn't have the same high-norm token problem as dinov2 from what I've seen).
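If you wanted to try that, it might look something like this (the choice of top-4 tokens is arbitrary):

```python
tokens = predictor.features.squeeze(0).flatten(1).T  # (4096, 256) patch tokens
norms = tokens.norm(dim=1)                           # L2 norm of each patch token
top_idx = norms.topk(4).indices                      # highest-norm patches
global_embedding = tokens[top_idx].mean(dim=0)       # (256,) rough 'global' vector
```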

I think the 'formal' approach is usually to train a linear classifier on top of the models you're testing, and see how well the model + classifier works on your data, but obviously that's a pain if it's hard to label examples.
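With sklearn that can be fairly compact, for example (assuming X is a (num_images, embedding_dim) array of your embeddings and y is your 50 annotations):

```python
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# linear probe: the better the embeddings, the better this simple classifier scores
clf = LogisticRegression(max_iter=1000)
scores = cross_val_score(clf, X, y, cv=5)
print(scores.mean())
```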

Thanks so much for the suggestions. I will look into ViT-Reg.