AndreaCodegoni/Tiny_model_4_CD

How to generate the attention masks and save them

nononoyou opened this issue · 1 comment

Hello! I have read your paper and I am interested in this work, but I cannot find the code that generates the attention masks (at resolutions 256/128/64). How can I generate the attention masks and save them?
Thanks!

Hello,

I am happy to know that our work interests you.
To print and save the attention masks generated by the MAMBs, you can modify the forward method as shown below.
Be careful to initialize an auxiliary counter in the network class (self.counter = 0): here I use and update it to give an incremental name to every mask the code generates and saves, so that earlier files are not overwritten.
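For reference, here is a minimal sketch of that initialization, assuming the usual nn.Module structure; the class name below is only a placeholder, not the actual class in the repo:

import torch.nn as nn


class ChangeDetectionNetwork(nn.Module):  # placeholder name for the repo's model class
    def __init__(self):
        super().__init__()
        # ... existing encoder / MAMB / decoder modules go here ...
        # auxiliary counter used only to build unique, incremental file names for the saved masks
        self.counter = 0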

I am closing the issue, but for anything else do not hesitate to comment again below, open other issues, or contact me directly by email.

Have a good time :)

    # needed at the top of the model file:
    # import matplotlib.pyplot as plt
    # from os.path import join
    # from torch import Tensor

    def forward(self, ref: Tensor, test: Tensor) -> Tensor:
        features = self._encode(ref, test)

        #### plot and save the intermediate attention masks #####
        plt.figure(figsize=(10, 10))
        # batch element 0, channel 0 of each mask; detach and move to CPU before plotting
        plt.subplot(1, 3, 1)
        plt.imshow(features[0][0, 0, :, :].detach().to('cpu'), cmap='jet')
        plt.xticks([])
        plt.yticks([])
        plt.title('Mask at resolution 256')
        plt.subplot(1, 3, 2)
        plt.imshow(features[1][0, 0, :, :].detach().to('cpu'), cmap='jet')
        plt.xticks([])
        plt.yticks([])
        plt.title('Mask at resolution 128')
        plt.subplot(1, 3, 3)
        plt.imshow(features[2][0, 0, :, :].detach().to('cpu'), cmap='jet')
        plt.xticks([])
        plt.yticks([])
        plt.title('Mask at resolution 64')
        # the incremental counter keeps each saved figure from overwriting the previous one
        plt.savefig(join("YOUR DESIRED PATH HERE", "intermediate_mask_{}.png".format(self.counter)), bbox_inches='tight')
        plt.close()
        self.counter += 1
        #########

        latents = self._decode(features)
        return self._classify(latents)
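If useful, here is a hedged sketch of how the saving then plays out at inference time; build_model below is a hypothetical stand-in for however you normally instantiate the network in this repo's scripts, and the 256x256 input size is only an assumption:

import torch

model = build_model()  # hypothetical helper: replace with your actual model construction
model.eval()

with torch.no_grad():  # no gradients are needed just to visualize and save the masks
    ref = torch.rand(1, 3, 256, 256)   # dummy reference image (assumed input size)
    test = torch.rand(1, 3, 256, 256)  # dummy test image (assumed input size)
    _ = model(ref, test)               # saves intermediate_mask_0.png and increments self.counter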