bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets

Visualizing attention blocks

ajay-bhargava opened this issue · 1 comment

Hi,

Great job on implementing the models from Ozan Oktay. Unfortunately, I am having some difficulty interpreting how you're visualizing the attention gates (intermediate kernels/layers) in the Attention U-Net model. Could you please provide some documentation on your implementation?

In particular, you lose me here:

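class LayerActivations:
    """Store the output of a layer via a forward hook"""

    features = None

    def __init__(self, layer):
        # hook_fn runs every time `layer` finishes its forward pass
        self.hook = layer.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # keep a detached CPU copy of the layer's output (no autograd graph attached)
        self.features = output.detach().cpu()

    def remove(self):
        # unregister the hook once the activations have been collected
        self.hook.remove()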
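The class relies on PyTorch forward hooks: `register_forward_hook` calls the hook with `(module, input, output)` right after the hooked module's `forward`, so `features` ends up holding whatever that one layer produced on the most recent pass. A minimal self-contained sketch, using a hypothetical toy model rather than the repository's network, shows that the captured tensor depends only on which module you hook:

```python
import torch
import torch.nn as nn


class LayerActivations:
    """Store the output of `layer` every time the model runs a forward pass."""

    def __init__(self, layer: nn.Module):
        self.features = None
        # hook_fn(module, inputs, output) is called after layer.forward returns
        self.hook = layer.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, inputs, output):
        # detach so the stored tensor does not keep the autograd graph alive
        self.features = output.detach().cpu()

    def remove(self):
        self.hook.remove()


# Hypothetical toy stand-in for a segmentation network (not the repo's model)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=1),
)

acts = LayerActivations(model[0])  # hook the first conv instead of the last one
with torch.no_grad():
    out = model(torch.randn(1, 3, 32, 32))
acts.remove()

print(acts.features.shape)  # 8 channels from the first conv, not the 1-channel output
print(out.shape)
```

Swapping `model[0]` for any other submodule is all it takes to visualize a different layer.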

which is referenced here:

    x1 = torch.nn.ModuleList(model_test.children())
    # x2 = torch.nn.ModuleList(x1[16].children())
    # x3 = torch.nn.ModuleList(x2[0].children())

    # To get filters in the layers
    # plot_kernels(x1.weight.detach().cpu(), 7)

    #####################################
    # for images
    #####################################
    dr = LayerActivations(x1[-1])  # hook the last top-level child (the final Conv layer)

    img = Image.open(test_image)
    s_tb = data_transform(img)

    model_test.eval()
    with torch.no_grad():
        # the forward pass triggers the hook and fills dr.features
        pred_tb = model_test(s_tb.unsqueeze(0).to(device)).cpu()
    pred_tb = torch.sigmoid(pred_tb).numpy()

    plot_kernels(dr.features, n_iter, 7, cmap="rainbow")
    dr.remove()
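Note that `children()` yields only the top-level modules in registration order, so `x1[-1]` is whatever module was registered last in `__init__`; it never reaches inside nested blocks such as attention gates. A minimal sketch with a hypothetical model (the names `Net`, `Conv1`, `Att1` and `Conv` are illustrative only) contrasts it with `named_modules()`, which walks the whole module tree:

```python
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical model tree; forward is omitted because only the structure matters here."""

    def __init__(self):
        super().__init__()
        self.Conv1 = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.Att1 = nn.Sequential(nn.Conv2d(8, 1, 1), nn.Sigmoid())
        self.Conv = nn.Conv2d(8, 1, kernel_size=1)


model = Net()

# children() gives the top-level modules only, in the order they were assigned
top = list(model.children())
print(len(top), type(top[-1]).__name__)

# named_modules() recurses, so nested layers can be selected by their dotted name
names = [name for name, _ in model.named_modules()]
print(names)

gate = dict(model.named_modules())["Att1.1"]  # the Sigmoid inside the nested block
print(gate)
```

Passing a module found this way (instead of `x1[-1]`) to `LayerActivations` is what would hook an interior layer.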

How is this code grabbing the attention gates in the interior of the model?

bigmb commented

@ajay-bhargava Hey, so are you able to visualize any layers? (The hook just captures the output of the hooked layer during the forward pass, and we then visualize that output.)
The layer hooked in the code is the last Conv layer, so the plot shows its output, i.e. the activated areas of the image.
(To be honest, I have never tried to visualize any other layer, so I need to check on that.)
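Since the hooked output is what marks the activated areas, one common way to display any hooked feature map is to collapse its channels, resize it to the input image, and rescale it to [0, 1]. The sketch below assumes that approach; channel averaging and min-max scaling are illustrative choices, not necessarily what `plot_kernels` in the repository does, and `to_heatmap` is a hypothetical helper:

```python
import torch
import torch.nn.functional as F


def to_heatmap(features: torch.Tensor, image_size: tuple) -> torch.Tensor:
    """Turn a hooked (N, C, h, w) feature map into an (N, H, W) map in [0, 1].

    Channels are averaged, the result is resized to the input image size
    and min-max normalised per sample so it can be drawn over the image.
    """
    fmap = features.mean(dim=1, keepdim=True)  # (N, 1, h, w)
    fmap = F.interpolate(fmap, size=image_size, mode="bilinear", align_corners=False)
    fmap = fmap.squeeze(1)  # (N, H, W)
    flat = fmap.flatten(1)
    lo = flat.min(dim=1).values.view(-1, 1, 1)
    hi = flat.max(dim=1).values.view(-1, 1, 1)
    # clamp avoids division by zero for a constant map
    return (fmap - lo) / (hi - lo).clamp_min(1e-8)


feats = torch.randn(1, 8, 16, 16)  # stand-in for dr.features from the hook
heat = to_heatmap(feats, (64, 64))
print(heat.shape)
```

The resulting map can be shown with any colormap (e.g. `rainbow`, as in the original call) on top of the input image.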

Did you inspect the layers in the model, pick the attention block, and pass that layer to the visualization?
Check this one out: https://www.kaggle.com/sironghuang/understanding-pytorch-hooks
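Picking the attention block can be sketched as follows. In Oktay et al.'s additive attention gate the coefficients are, in simplified form, alpha = sigmoid(psi(ReLU(W_g g + W_x x))), so hooking the submodule that applies the final sigmoid yields the attention map directly. This is a toy, hypothetical model (`TinyAttNet`, `AttentionGate` and `att.psi` are illustrative names); the repository's Attention U-Net may register its gates under different names, so `print(model_test)` or `model_test.named_modules()` should be used to find the right submodule:

```python
import torch
import torch.nn as nn


class AttentionGate(nn.Module):
    """Toy additive attention gate: g is the gating signal, x the skip features."""

    def __init__(self, f_g: int, f_l: int, f_int: int):
        super().__init__()
        self.W_g = nn.Conv2d(f_g, f_int, kernel_size=1)
        self.W_x = nn.Conv2d(f_l, f_int, kernel_size=1)
        self.psi = nn.Sequential(nn.Conv2d(f_int, 1, kernel_size=1), nn.Sigmoid())
        self.relu = nn.ReLU()

    def forward(self, g, x):
        alpha = self.psi(self.relu(self.W_g(g) + self.W_x(x)))  # (N, 1, H, W)
        return x * alpha


class TinyAttNet(nn.Module):
    """Hypothetical two-level network with a single attention gate."""

    def __init__(self):
        super().__init__()
        self.enc = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.down = nn.MaxPool2d(2)
        self.bottom = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.up = nn.Upsample(scale_factor=2)
        self.att = AttentionGate(f_g=8, f_l=8, f_int=4)
        self.out = nn.Conv2d(16, 1, kernel_size=1)

    def forward(self, img):
        skip = self.enc(img)
        g = self.up(self.bottom(self.down(skip)))
        gated = self.att(g, skip)
        return self.out(torch.cat([gated, g], dim=1))


captured = {}


def save_alpha(module, inputs, output):
    # the output of att.psi is the attention-coefficient map alpha
    captured["alpha"] = output.detach().cpu()


model = TinyAttNet().eval()
handle = model.att.psi.register_forward_hook(save_alpha)
with torch.no_grad():
    model(torch.randn(1, 3, 32, 32))
handle.remove()

alpha = captured["alpha"]
print(alpha.shape)
```

Plotting `alpha[0, 0]` over the input image then shows where the gate lets skip-connection features through; for deeper gates the map has a lower resolution than the image and needs upsampling first.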

Let me know if you need some help. I will try to look into this too.