fudan-zvg/SOFT

Ask for Model Visualization Code

YangYangGirl opened this issue · 3 comments

Thank you for your great work. Can you share your visualization code? I'd like to do some analysis based on your model :)

Hi, thanks for your interest. The visualization code is just a few command lines, mostly based on OpenCV (heatmaps), rather than a standalone script. To reproduce it, save the tensors you want to inspect manually and load them with NumPy.
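A minimal sketch of the save-then-load workflow, assuming each tensor is dumped to a `.npy` file from inside the model's forward pass. `dump_tensor`, `load_tensor`, and the file name are hypothetical helpers, not part of the SOFT repository. A PyTorch tensor has to be converted with `.detach().cpu().numpy()` before saving.

```python
import os
import tempfile

import numpy as np


def dump_tensor(array, path):
    """Save an array to a .npy file (hypothetical helper).

    For a PyTorch tensor, pass tensor.detach().cpu().numpy().
    """
    np.save(path, np.asarray(array))


def load_tensor(path):
    """Load a previously dumped array for offline analysis (hypothetical helper)."""
    return np.load(path)


if __name__ == "__main__":
    # Dummy attention matrix: batch, head, n, n
    attn = np.random.rand(2, 4, 49, 49).astype(np.float32)
    path = os.path.join(tempfile.gettempdir(), "attn_matrix.npy")
    dump_tensor(attn, path)
    restored = load_tensor(path)
    print(restored.shape)
```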

Any update here? Can you provide the visualization code? Thanks a lot.

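import cv2
import numpy as np
import torch

# Assumed to run inside the attention module's forward(): kernel_1_, kernel_2_, kernel_3_,
# self.newton_inv, b (batch), nhead, h, w and save_path come from that scope.
attn_matrix = torch.matmul(torch.matmul(kernel_1_, self.newton_inv(kernel_2_)), kernel_3_)  # batch, head, n, n
# Take the center of the feature map as the reference query; change the index at will.
# Row-major flattening: center = (h // 2) * w + w // 2 (h // 2 * w // 2 would give the wrong token).
attn_index = (h // 2) * w + w // 2
attn = attn_matrix[..., attn_index, :]  # batch, head, n
attn = attn.reshape(b, nhead, h, w)  # batch, head, h, w
for batch_idx in range(b):
    for head_idx in range(nhead):
        attn_draw = attn[batch_idx, head_idx]  # h, w
        # Min-max normalize to [0, 1]; the epsilon guards against a constant map.
        attn_draw = (attn_draw - attn_draw.min()) / (attn_draw.max() - attn_draw.min() + 1e-8)
        attn_draw = (attn_draw * 255.).detach().cpu().numpy().astype(np.uint8)
        attn_draw = cv2.applyColorMap(attn_draw, cv2.COLORMAP_JET)  # uint8 h x w -> BGR heatmap
        # One file per (batch, head); a fixed save_path would overwrite the earlier maps.
        cv2.imwrite(f"{save_path}_b{batch_idx}_h{head_idx}.png", attn_draw)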
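For offline analysis of a saved attention matrix, the same steps can be reproduced with NumPy alone. This sketch assumes the matrix was saved with shape (batch, head, n, n), that n = h * w, and that tokens are flattened in row-major order; `center_index` and `attention_heatmaps` are hypothetical names. Each (batch, head) map is normalized separately, as in the loop above.

```python
import numpy as np


def center_index(h, w):
    """Flattened (row-major) index of the center cell of an h x w feature map."""
    return (h // 2) * w + w // 2


def attention_heatmaps(attn_matrix, h, w, query_index=None, eps=1e-8):
    """Turn a (batch, head, n, n) attention matrix into uint8 maps of shape (batch, head, h, w).

    Assumes n == h * w with row-major token order; query_index defaults to the center token.
    """
    attn_matrix = np.asarray(attn_matrix, dtype=np.float64)
    b, nhead, n, _ = attn_matrix.shape
    if n != h * w:
        raise ValueError(f"n={n} does not match h*w={h * w}")
    if query_index is None:
        query_index = center_index(h, w)
    # Attention from the chosen query token to every key token, laid out spatially.
    attn = attn_matrix[..., query_index, :].reshape(b, nhead, h, w)
    # Per-(batch, head) min-max normalization to [0, 255].
    lo = attn.min(axis=(2, 3), keepdims=True)
    hi = attn.max(axis=(2, 3), keepdims=True)
    attn = (attn - lo) / (hi - lo + eps)
    return (attn * 255.0).round().astype(np.uint8)
```

Each returned map is a single-channel `uint8` array, so it can be passed directly to `cv2.applyColorMap(maps[0, 0], cv2.COLORMAP_JET)` and written with `cv2.imwrite`.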