facebookresearch/segment-anything

How to get the hidden states from every layer of the ViT in the SAM vision encoder?

jzssz opened this issue · 1 comment

jzssz commented

Thanks a lot.

The simplest way is probably to use PyTorch's forward hooks to grab the output of each Block module (assuming that's what you mean by the hidden state of each layer here).
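Forward hooks can be tried out without downloading a SAM checkpoint. The sketch below assumes only PyTorch; `ToyBlock`, `ToyEncoder` and `capture_outputs` are hypothetical stand-ins, not part of segment_anything. Like SAM's image encoder, the toy passes a (B, H, W, C) tensor through a list of blocks. The helper records each block's output in call order and removes its hooks on exit, so later forward passes don't keep appending.

```python
import contextlib

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a ViT block: (B, H, W, C) -> (B, H, W, C)."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(self.norm(x))


class ToyEncoder(nn.Module):
    """Hypothetical encoder that runs its blocks in sequence."""

    def __init__(self, depth: int = 4, dim: int = 8):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock(dim) for _ in range(depth)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for blk in self.blocks:
            x = blk(x)
        return x


@contextlib.contextmanager
def capture_outputs(model: nn.Module, module_type: type):
    """Collect the output of every `module_type` submodule, in call order."""
    outputs = []

    def hook(module, inp, out):
        outputs.append(out.detach())

    handles = [
        m.register_forward_hook(hook)
        for m in model.modules()
        if isinstance(m, module_type)
    ]
    try:
        yield outputs
    finally:
        # Remove the hooks even if the forward pass raises
        for h in handles:
            h.remove()


encoder = ToyEncoder(depth=4, dim=8)
x = torch.randn(2, 5, 5, 8)  # (B, H, W, C)
with torch.no_grad(), capture_outputs(encoder, ToyBlock) as hidden_states:
    final = encoder(x)

stacked = torch.stack(hidden_states)  # (num_blocks, B, H, W, C)
print(stacked.shape)  # torch.Size([4, 2, 5, 5, 8])
```

With SAM loaded, the same helper should work as `with capture_outputs(predictor.model.image_encoder, Block) as hidden_states: predictor.set_image(image)`. For the default 1024×1024 input, each entry should have shape (1, 64, 64, C), where C is the encoder width (1280 for vit_h).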

You can do something like:

import torch
from segment_anything.modeling.image_encoder import Block

# ... assuming SamPredictor & image data are already set up ...

# Use forward hooks to store 'Block' outputs when encoding image
captures = []
def hook_func(module, inp, out):
    captures.append(out.detach())

hook_handles = []
for m in predictor.model.image_encoder.modules():
    if isinstance(m, Block):
        hook_handles.append(m.register_forward_hook(hook_func))

with torch.no_grad():
    predictor.set_image(image)

# Remove the hooks so later set_image() calls don't keep appending
for h in hook_handles:
    h.remove()

print("Number of blocks captured:", len(captures))

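You can swap Block in the segment_anything import (and in the isinstance check) for other model components to grab their outputs instead, for example Attention from the same module for the attention layers. For the neck, which is a plain nn.Sequential, you can register a hook directly on predictor.model.image_encoder.neck.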
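Hooking modules by their qualified name is another way to pick components other than Block. The sketch below is a minimal illustration on a hypothetical toy model whose attribute names loosely mirror SAM's ImageEncoderViT (`blocks[i].attn`, `neck`); `hook_by_name` is a made-up helper built on `nn.Module.get_submodule` and `register_forward_hook`.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical block with an `attn` child, loosely mirroring SAM's Block."""

    def __init__(self, dim: int):
        super().__init__()
        self.attn = nn.Linear(dim, dim)  # stand-in for a real attention layer
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(x)
        return x + self.mlp(x)


class ToyEncoder(nn.Module):
    """Hypothetical encoder with `blocks` and a channels-first `neck`."""

    def __init__(self, depth: int = 3, dim: int = 8, out_dim: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock(dim) for _ in range(depth)])
        self.neck = nn.Conv2d(dim, out_dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, H, W, C)
        for blk in self.blocks:
            x = blk(x)
        return self.neck(x.permute(0, 3, 1, 2))  # -> (B, out_dim, H, W)


def hook_by_name(model: nn.Module, names):
    """Hook each named submodule; returns (dict name -> output, hook handles)."""
    outputs, handles = {}, []
    for name in names:
        module = model.get_submodule(name)

        def hook(mod, inp, out, name=name):  # default arg binds this loop's name
            outputs[name] = out.detach()

        handles.append(module.register_forward_hook(hook))
    return outputs, handles


model = ToyEncoder()
outputs, handles = hook_by_name(model, ["blocks.0.attn", "blocks.2.attn", "neck"])
with torch.no_grad():
    model(torch.randn(1, 6, 6, 8))
for h in handles:
    h.remove()  # stop recording on later forward passes

for name, t in outputs.items():
    print(name, tuple(t.shape))
# blocks.0.attn (1, 6, 6, 8)
# blocks.2.attn (1, 6, 6, 8)
# neck (1, 4, 6, 6)
```

On a real SamPredictor the corresponding names, relative to predictor.model, would presumably look like "image_encoder.blocks.0.attn" and "image_encoder.neck"; listing `predictor.model.named_modules()` is the safest way to confirm them. In blocks that use windowed attention, SAM partitions the tokens into windows before the attention layer, so attention outputs there may come back in the windowed layout rather than as (B, H, W, C).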