Add support of Vision Transformer
Yung-zi opened this issue · 6 comments
🚀 Feature
Thanks for your great work! However, I have a question: can Layer-CAM be used with a Vision Transformer network? If it does work, what would I need to change?
Motivation & pitch
I'm working on a project related to CAMs.
Alternatives
No response
Additional context
No response
Hello @Yung-zi 👋
My apologies, I've been busy with other projects lately!
As of right now, the library is designed to work with CNNs. However, its design basically relies only on forward activation hooks and backpropagated gradient hooks. So to answer your question: I'd need to run some tests, but if the output activation of a given layer has shape (N, C, H, W), then however it was computed, as long as that computation doesn't break backprop (i.e. it remains differentiable), the library should work without much (perhaps any) change 🙂
Either way, I intend to spend more time on Vision Transformer compatibility for the next release 🙂
If you're interested in helping / or providing feedback once it's in progress, let me know!
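To illustrate the point above, here is a minimal sketch of the hook mechanism being described. The layer and variable names are illustrative, not torchcam's actual API: any differentiable module whose output is (N, C, H, W) can feed a forward-activation hook, and gradients still flow through it.

```python
import torch
from torch import nn

# Toy layer standing in for "any differentiable layer" in the discussion above.
layer = nn.Conv2d(3, 8, kernel_size=3, padding=1)

activations = {}

def hook(module, inputs, output):
    # CAM-style methods expect an activation of shape (N, C, H, W)
    activations["out"] = output

handle = layer.register_forward_hook(hook)

x = torch.rand(1, 3, 16, 16, requires_grad=True)
out = layer(x)
assert activations["out"].shape == (1, 8, 16, 16)

out.sum().backward()  # backprop is intact, so gradient hooks would work too
handle.remove()
```

As long as a transformer block's output can be exposed in that layout, the same hooks apply unchanged.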
I'm so sorry for the late reply. I tried modifying your code before, but the results didn't look right; maybe I made some mistakes. Have you managed to make it work on Vision Transformers?
Partially yes!
But I have staged this for the next release anyway, so I'll dive into it to make it available :)
Quick update!
As of today, here is the support status of Torchvision transformer architectures:
- maxvit
- swin
- swin_v2
- vit (so far I can't see a way to make this integration seamless, because of the concatenation on the channel dimension and the dimension swapping)
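To make the `vit` obstacle concrete: a ViT encoder block outputs tokens of shape (N, 1 + H·W, D), with the class token first, rather than a spatial map. The sketch below is a hypothetical reshape (the function name and dimensions are mine, not torchcam's), assuming a `vit_b_16`-style model where 224×224 inputs give H = W = 14 patches and D = 768:

```python
import torch

def tokens_to_spatial(tokens, h, w):
    """Hypothetical helper: convert ViT tokens (N, 1 + H*W, D) into the
    (N, C, H, W) layout that CAM methods expect, by dropping the class
    token and swapping the token and channel dimensions."""
    patch_tokens = tokens[:, 1:]  # drop the class token
    n, hw, d = patch_tokens.shape
    assert hw == h * w
    return patch_tokens.transpose(1, 2).reshape(n, d, h, w)

# Simulated output of a ViT encoder block (vit_b_16-style dimensions)
tokens = torch.rand(2, 1 + 14 * 14, 768)
spatial = tokens_to_spatial(tokens, 14, 14)
assert spatial.shape == (2, 768, 14, 14)
```

The reshape itself is mechanical; the seamlessness problem mentioned above is that the class-token concatenation and the axis swap happen inside the model, so the hooks can't rely on a uniform layout across architectures.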
Another update: ViT requires another method, called attention flow!
I'll try to investigate and implement it, but this is a bit more complex than just inverting the axis swap and slicing.
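For context, attention flow comes from Abnar & Zuidema (2020), which also proposes the simpler attention rollout; flow solves a max-flow problem over the attention graph, while rollout just multiplies per-layer attention matrices. Here is a hedged sketch of the rollout variant (my own implementation, not torchcam's):

```python
import torch

def attention_rollout(attentions):
    """Attention rollout (Abnar & Zuidema, 2020): fuse heads by averaging,
    add an identity matrix to account for residual connections, renormalize
    rows, then multiply the per-layer matrices so attention is tracked all
    the way back to the input tokens."""
    result = None
    for attn in attentions:  # each attn: (N, heads, T, T), rows sum to 1
        a = attn.mean(dim=1)                 # fuse attention heads
        a = a + torch.eye(a.shape[-1])       # residual connection
        a = a / a.sum(dim=-1, keepdim=True)  # renormalize rows
        result = a if result is None else a @ result
    return result  # (N, T, T)

# Toy check with random (row-stochastic) attention maps for 3 layers
attns = [torch.rand(1, 4, 10, 10).softmax(dim=-1) for _ in range(3)]
rollout = attention_rollout(attns)
assert rollout.shape == (1, 10, 10)
```

Since each per-layer matrix stays row-stochastic, the rolled-out result does too, which is what makes it usable as a relevance map over input patches.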
Your excellent work has helped me a lot, thank you! However, I have a question. I downloaded torchcam 0.4.0 and got good visualization results on CNN models, but it didn't work on a ViT model. Here's what happened: since I was working offline, I downloaded the ViT weight file and loaded the model with timm. The result was blue pixels covering the entire image, i.e. no heatmap region was found. What do I need to change in the code to make it work? Or, as you mentioned above, are you still working on it? Thank you for taking time out of your busy schedule.