How to get the hidden_state from every layer of the ViT in the SAM vision encoder?
jzssz opened this issue · 1 comment
jzssz commented
thanks a lot
heyoeyo commented
The simplest way is probably to use PyTorch's forward hook functionality to grab the output of the Block
modules (assuming that's what you mean by the hidden state of the layers in this case).
You can do something like:
import torch
from segment_anything.modeling.image_encoder import Block

# ... assuming the SamPredictor ('predictor') & image data are already set up ...

# Use forward hooks to store 'Block' outputs when encoding the image
captures = []
hook_func = lambda m, inp, out: captures.append(out)
for m in predictor.model.modules():
    if isinstance(m, Block):
        m.register_forward_hook(hook_func)

# Running set_image triggers the image encoder, which fires the hooks
with torch.no_grad():
    predictor.set_image(image)

print("Number of blocks captured:", len(captures))
You can change the 'from ... import Block' line to import other model components and grab the corresponding outputs (for example the neck or the attention layers).
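For example, to capture the attention outputs instead, the same hook pattern can be pointed at the Attention modules (a sketch, assuming the Attention class defined in segment_anything.modeling.image_encoder):

from segment_anything.modeling.image_encoder import Attention

# Same hook pattern as above, but targeting the attention sub-modules
attn_captures = []
attn_hook = lambda m, inp, out: attn_captures.append(out)
for m in predictor.model.modules():
    if isinstance(m, Attention):
        m.register_forward_hook(attn_hook)

# Re-running predictor.set_image(image) will then fill attn_captures with
# one attention output per block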