frgfm/torch-cam

How to get CAM for custom 3D model?

arpit1984 opened this issue · 19 comments

🚀 Feature

How can we use it for a multi-view CNN? I have a custom CNN that takes 2 images (different views of the same object), concatenates their features at the end, and connects them to a fully connected layer.
Following is the model:

import torch
import torch.nn as nn
from torchvision import models

class MVCNN(nn.Module):
    def __init__(self, num_classes=1000, pretrained=True):
        super(MVCNN, self).__init__()
        self.gradients = None
        self.tensorhook = []
        self.layerhook = []
        self.selected_out = None
        
        self.resnet = models.resnet50(pretrained = pretrained)
        
        self.layerhook.append(self.resnet.layer4.register_forward_hook(self.forward_hook()))
       
        fc_in_features = self.resnet.fc.in_features
        self.features = nn.Sequential(*list(self.resnet.children())[:-1])
        self.classifier = nn.Sequential(
            nn.Dropout(),
            ## multiplying by 2 to take care of two views
            nn.Linear(fc_in_features * 2, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes)
        )
        for p in self.resnet.parameters():
            p.requires_grad = True

    def activations_hook(self,grad):
        self.gradients = grad

    def get_act_grads(self):
        return self.gradients

    def forward_hook(self):
        def hook(module, inp, out):
            self.selected_out = out
            self.tensorhook.append(out.register_hook(self.activations_hook))
        return hook

    def forward(self, inputs): # inputs.shape = samples x views x height x width x channels
        inputs = inputs.transpose(0, 1)
        view_features = [] 
        for view_batch in inputs:
            view_batch = self.features(view_batch)
            view_batch = view_batch.view(view_batch.shape[0], view_batch.shape[1:].numel())
            view_features.append(view_batch)   
            
        concat_views = torch.cat(view_features,-1)
        #pooled_views, _ = torch.max(torch.stack(view_features), 0)
        #outputs = self.classifier(pooled_views)
        outputs = self.classifier(concat_views)
        return outputs, self.selected_out

How can I see which parts of the two images contributed to the classification?

Motivation & pitch

A multi-view CNN improves accuracy when you have multiple views of a given object, instead of using just one image as input.

Alternatives

No response

Additional context

No response

Any update or idea on this? In fact, is this even possible? I am expecting to see different activated regions in the two views. Right now I see the same activations in both, which I think is wrong (or at least I am doing it wrong).

frgfm commented

Hey @arpit1984 👋

The short answer is no because class activation methods weren't designed for this. Now, to give you more details:

  • class activation methods don't actually map the influence of the input image on the final output directly
  • they map the influence of intermediate feature maps on the final output
  • this means that if the view separation isn't present in those feature maps, this design cannot work
  • side note: the batch dimension (the first one for any given module in PyTorch) is handled differently, so if you mix up the axes, that can become troublesome

A few things on my mind also (helpful, I hope 😄):

  • in your code snippet above, your class constructor stores too much information/params (self.resnet isn't used after the constructor). I would suggest:
super(MVCNN, self).__init__()
self.gradients = None
self.tensorhook = []
self.layerhook = []
self.selected_out = None

resnet = models.resnet50(pretrained=pretrained)

self.features = nn.Sequential(*list(resnet.children())[:-1])
# layer4 is the second-to-last child here (just before the average pooling)
self.layerhook.append(self.features[-2].register_forward_hook(self.forward_hook()))
self.classifier = nn.Sequential(
    nn.Dropout(),
    ## multiplying by 2 to take care of two views
    nn.Linear(resnet.fc.in_features * 2, 2048),
    nn.ReLU(inplace=True),
    nn.Dropout(),
    nn.Linear(2048, 2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, num_classes)
)
for p in self.features.parameters():
    p.requires_grad_(True)
for p in self.classifier.parameters():
    p.requires_grad_(True)
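
As a quick sanity check (just a sketch, assuming this constructor is plugged back into your MVCNN class; pretrained=False only avoids the weight download), the feature extractor should behave exactly as before:

import torch

model = MVCNN(num_classes=10, pretrained=False)
single_view = torch.rand(8, 3, 224, 224)   # a batch of 8 images for a single view
feats = model.features(single_view)
print(feats.shape)                         # expected: torch.Size([8, 2048, 1, 1])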

Let me know if that's still unclear, cheers ✌️

Thanks for the reply.
Also, I might be wrong, but I was thinking: since I have two parallel branches, one per view, computing the features up to the last block of the ResNet, and I then concatenate these features into a fully connected layer feeding one classification layer to perform the prediction, I should still be able to use the feature sets from the last layer of these two branches to calculate the Grad-CAM for both images individually.

frgfm commented

Hey there :)
To be more specific, I'm not saying it's not possible; I'm saying that, as designed, it isn't, but we can work this out!
My suggestion is the following:

  • in your hooks, make sure to clone the tensor rather than storing a reference to it (out.clone())
  • I'm not sure, but considering your design, I think the comment about inputs.shape is wrong (the channel axis should come before height and width)
  • your forward method could use batch processing (but given your commented-out section, I'm not sure how you leverage the view axis, so I added 2 propositions):
def forward(self, inputs):
    b, v, c, h, w = inputs.shape
    # Flatten the batch and view dimension
    batch_feats = self.features(inputs.reshape(b * v, c, h, w))
    outputs = self.classifier(batch_feats)
    return outputs, self.selected_out


def forward(self, inputs):
    b, v, c, h, w = inputs.shape
    # Flatten the batch and view dimension
    batch_feats = self.features(inputs.reshape(b * v, c, h, w))
    # Pooled views
    batch_feats = batch_feats.reshape(b, v, -1).max(dim=1).values
    outputs = self.classifier(batch_feats)
    return outputs, self.selected_out
  • Once that is done, the CAM extracted at a given layer of the CNN will have the shape (B * V, H', W').
  • You only have to unroll the first axis and you'll be good to go (see the sketch below)!
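
For instance, something along these lines (just a rough sketch; cams is assumed to be the raw map returned by the extractor for the hooked layer, and b, v your batch size and number of views):

# cams: hypothetical tensor of shape (B * V, H', W') for the hooked layer
per_view_cams = cams.reshape(b, v, *cams.shape[1:])   # (B, V, H', W')
view0_cams, view1_cams = per_view_cams.unbind(dim=1)  # with V = 2: one (B, H', W') map per view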

Hope that helps 👍

Thanks for explaining. I am still learning "hooks" and PyTorch. I work with fastai and it tends to be simpler.

  • You are right, my comment about the input shape is wrong. It is indeed samples x views x channels x height x width
  • I am not using the "pooled view" commented code; I'm working with concatenating the features. If I understand correctly, the following

    # Flatten the batch and view dimension
    batch_feats = self.features(inputs.reshape(b * v, c, h, w))

instead of passing the two views to two parallel CNN branches and concatenating the features later, merges the views together and passes them through a single CNN branch, which gives batch_feats of shape (16, 2048, 1, 1) (batch size 8 and 2 views = 16).

However, when I pass it to self.classifier(batch_feats), I get a dimension mismatch error. So I reshaped batch_feats to (8, 4096) to make it comparable to the concat_views shape from my original forward function, but I am still getting a "'tuple' object has no attribute 'view'" error.

I think I am getting some idea of what you are proposing, but I am not able to get hold of it. Can you please help me with that?

frgfm commented

Sure!
Here are two snippets that replicate what I think you're trying to do:

def forward(self, inputs):
    b, v, c, h, w = inputs.shape
    # Flatten the batch and view dimension
    batch_feats = self.features(inputs.reshape(b * v, c, h, w))
    # I don't get the line below, as you mix view, channel & spatial dimensions. Even though we
    # didn't talk about it, I imagine self.features includes the average pooling, which reduces
    # the spatial dims to 1
    outputs = self.classifier(batch_feats.reshape(b, v, -1).flatten(start_dim=1))
    return outputs, self.selected_out


def forward(self, inputs):
    b, v, c, h, w = inputs.shape
    # Flatten the batch and view dimension
    batch_feats = self.features(inputs.reshape(b * v, c, h, w))
    # Pooled views (this makes more sense to me because, from a CNN perspective, it's the equivalent of max pooling across channels)
    batch_feats = batch_feats.reshape(b, v, -1).max(dim=1).values
    outputs = self.classifier(batch_feats)
    return outputs, self.selected_out

In both cases, using the CAM extractor at a given layer should give a CAM of shape (B * V, H', W').
I hope this helps 👍

Thanks for replying.
I am trying to replicate the CNN model described in https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0245230, specifically what is shown in Fig 2b) Late fusion (concatenation and fully connected). That's what I am doing in my forward function.
As I said, I get the idea of your code but it's still a little unclear to me.
flatten(dim=1) is throwing the error "flatten() got an unexpected keyword argument 'dim'".

Sorry for asking dumb questions. I have spent so much time on it that I am now missing obvious fixes.

frgfm commented

My bad, it's .flatten(start_dim=1)

I'll edit my previous answer 👍 Let me know if that fixes things!

Thanks for the help. I was able to run the model. Now I am using the following to calculate the CAM:

model = learn.model
from torchcam.methods import SmoothGradCAMpp
cam_extractor = SmoothGradCAMpp(model)

Then I get the following error:

Cell In [23], line 29, in MVCNN.forward(self, inputs)
28 def forward(self, inputs):
---> 29 b, v, c, h, w = inputs.shape
30 # Flatten the batch and view dimension
31 batch_feats = self.features(inputs.reshape(b * v, c, h, w))

ValueError: not enough values to unpack (expected 5, got 4)

Any idea how to fix this?

frgfm commented

Yup, specify input_shape when calling SmoothGradCAMpp (by default it's (3, 224, 224)).
Or, even better, directly specify the layer you want to target.
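
For instance (a rough sketch; the exact layer reference depends on how you wrapped the backbone and on your torchcam version):

from torchcam.methods import SmoothGradCAMpp

# Option 1: describe a dummy input as (views, channels, height, width) so the
# automatic layer search can run a 5-dimensional forward pass
cam_extractor = SmoothGradCAMpp(model, input_shape=(2, 3, 224, 224))

# Option 2: target the last ResNet block directly (assuming the nn.Sequential
# wrapper from earlier, where layer4 sits just before the average pooling)
cam_extractor = SmoothGradCAMpp(model, target_layer=model.features[-2])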

This worked great. Thank you very much for explaining the code too.

frgfm commented

Glad you worked it out 😁

Feel free to reopen the issue if you're still having problems!

Hi,
I was able to generate the Grad-CAMs for both views, and for most cases they make sense. However, I noticed that both views always show some highlighted region. I think that in cases where one view contributes nothing, nothing should be highlighted in it. But so far it seems like something is always highlighted in both views. Do you have any idea why this might be happening?

frgfm commented

Hey @arpit1984!
Yes, this is because of the normalization; try passing normalized=False when you call the CAM extractor 😉

Thanks, so I did
activation_map = cam_extractor(out.squeeze(0).argmax().item(), out, normalized=False)
and I don't see any activations anymore, as in nothing gets highlighted.

frgfm commented

Out of the box you can only do 2 things:

  • normalized=True will normalize each map relative to itself (each one will have a min value of 0 and a max of 1)
  • normalized=False lets you do your own normalization (across maps, for instance)

normalized=False lets you do your own normalization (across maps, for instance)

What does "lets you do your own normalization" mean? Will I have to do some sort of normalization after generating the maps and before plotting them? Can you provide a sample of how to do it?

frgfm commented

What does "lets you do your own normalization" mean?

I mean that you can process the raw class activation map however you see fit 😅 Not normalizing it, doing something else, whichever you prefer!

Regarding your second question, I'd have to dig deeper into your problem and I don't have that much time right now unfortunately :/ If you simply plot the raw version, I expect you'll see a difference!
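
For what it's worth, a minimal sketch of what normalizing across maps could look like (assuming raw_cams holds the raw maps you get from the extractor with normalized=False):

# raw_cams: hypothetical (B * V, H', W') tensor, e.g.
# raw_cams = cam_extractor(out.squeeze(0).argmax().item(), out, normalized=False)[0]
vmin, vmax = raw_cams.min(), raw_cams.max()
normed_cams = (raw_cams - vmin) / (vmax - vmin + 1e-8)
# All maps now share the same min/max, so a view that barely contributed stays
# close to 0 instead of being stretched to fill [0, 1] on its own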

Sure, I understand, and you have been very helpful.
I did some testing where I trained a multi-view model keeping the same image as the second view in every training instance, with the idea that the model won't learn anything from it, so there shouldn't be any activations on that view/image.
I noticed that that view still has a heatmap plotted, but always at the same place for class 1 and at a different place for class 2. So the heatmap is still drawn on it, but it is consistently the same within each class. Could it just be some artifact?