Ask for Model Visualization Code
YangYangGirl opened this issue · 3 comments
YangYangGirl commented
Thank you for your great work. Could you share your visualization code? I'd like to do some analysis based on your model :)
YJHMITWEB commented
Hi, thanks for your interest. The visualization code consists of command-line snippets and is mostly based on OpenCV (heatmaps). To reproduce the visualizations, you can manually save the tensors you want and load them with NumPy.
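A minimal sketch of that save-then-load workflow, assuming the tensor of interest (for example an attention map) has already been converted to a NumPy array; for a PyTorch tensor that would be `t.detach().cpu().numpy()`. The helper names `dump_array` and `load_as_uint8` and the file name are hypothetical, not taken from the repository.

```python
import os
import tempfile

import numpy as np


def dump_array(path: str, array: np.ndarray) -> None:
    """Save an intermediate array (e.g. an attention map) to a .npy file."""
    np.save(path, array)


def load_as_uint8(path: str) -> np.ndarray:
    """Load a saved array and min-max scale it to uint8 in [0, 255]."""
    arr = np.load(path).astype(np.float32)
    span = arr.max() - arr.min()
    if span == 0:
        # A constant map has no contrast; return zeros instead of dividing by 0.
        return np.zeros(arr.shape, dtype=np.uint8)
    scaled = (arr - arr.min()) / span * 255.0
    return np.round(scaled).astype(np.uint8)


if __name__ == "__main__":
    # Hypothetical file location, used only for this demo.
    path = os.path.join(tempfile.gettempdir(), "attn_example.npy")
    dump_array(path, np.array([[0.1, 0.5], [0.3, 0.9]]))
    print(load_as_uint8(path))
```

The returned 2-D uint8 array can then be passed to `cv2.applyColorMap(img, cv2.COLORMAP_JET)` to obtain the kind of OpenCV heatmap described above. Keeping the dump step separate means the model only has to run once, and the colormap can be tweaked offline.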
xiao2mo commented
Any update here? Could you provide the visualization code? Thanks a lot.
VictorLlu commented
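```python
import cv2
import numpy as np
import torch

# Assumed to run inside the attention module's forward(), where kernel_1_, kernel_2_,
# kernel_3_, self.newton_inv, b (batch size), nhead, h, w and save_path already exist.
attn_matrix = torch.matmul(torch.matmul(kernel_1_, self.newton_inv(kernel_2_)), kernel_3_)  # batch, head, n, n
# Take the center of the feature map as the reference token; change the index at will.
# Parentheses matter: h // 2 * w // 2 evaluates as ((h // 2) * w) // 2, not the center.
attn_index = (h // 2) * w + w // 2
attn = attn_matrix[..., attn_index, :]  # batch, head, n
attn = attn.reshape(b, nhead, h, w)  # batch, head, h, w
for batch_idx in range(b):
    for head_idx in range(nhead):
        attn_draw = attn[batch_idx, head_idx]  # h, w
        # Min-max normalize to [0, 255]; the epsilon guards against a constant map.
        attn_draw = (attn_draw - attn_draw.min()) / (attn_draw.max() - attn_draw.min() + 1e-8)
        attn_draw = (attn_draw * 255.).detach().cpu().numpy().astype(np.uint8)
        attn_draw = cv2.applyColorMap(attn_draw, cv2.COLORMAP_JET)
        # save_path is used as a filename prefix so each (batch, head) map gets its own file
        # instead of being overwritten on every iteration.
        cv2.imwrite(f"{save_path}_b{batch_idx}_h{head_idx}.png", attn_draw)
```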
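To look at a reference pixel other than the center, here is a small NumPy-only sketch that works on a dumped `(batch, head, n, n)` attention matrix with `n = h * w`. It assumes row-major flattening of the feature map, so pixel `(row, col)` has index `row * w + col`. The function name `reference_maps` is hypothetical; its output is one uint8 map per head, ready for `cv2.applyColorMap`.

```python
import numpy as np


def reference_maps(attn_matrix: np.ndarray, h: int, w: int, row: int, col: int) -> np.ndarray:
    """Return uint8 maps of shape (batch, head, h, w) showing how the token at
    (row, col) attends to every position of the h x w feature map."""
    b, nhead, n, n2 = attn_matrix.shape
    assert n == n2 == h * w, "attention matrix must be (batch, head, h*w, h*w)"
    index = row * w + col  # row-major flattening of (row, col)
    maps = attn_matrix[..., index, :].reshape(b, nhead, h, w).astype(np.float64)
    # Min-max normalize each (batch, head) map independently over its spatial axes.
    lo = maps.min(axis=(2, 3), keepdims=True)
    hi = maps.max(axis=(2, 3), keepdims=True)
    maps = (maps - lo) / np.maximum(hi - lo, 1e-12)
    return (maps * 255.0).astype(np.uint8)


h, w = 4, 6
print((h // 2) * w + w // 2)  # flattened center index: 2 * 6 + 3 = 15
print(h // 2 * w // 2)        # unparenthesized: ((4 // 2) * 6) // 2 = 6, not the center

rng = np.random.default_rng(0)
maps = reference_maps(rng.random((1, 2, h * w, h * w)), h, w, h // 2, w // 2)
print(maps.shape)  # (1, 2, 4, 6)
```

Normalizing per map (rather than over the whole batch) mirrors the per-head loop in the snippet above, so weakly attending heads still get full color contrast; the trade-off is that colors are not comparable across heads.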