microsoft/FocalNet

Visualization of gated outputs

Tajamul21 opened this issue · 2 comments

RuntimeError Traceback (most recent call last)
/tmp/ipykernel_63586/1379561863.py in
23 fig.add_subplot(1, 5, i+2)
24 gates_i = (upsampler(gates[:, i:i+1])).cpu().detach()
---> 25 plt.imshow(gates_i.permute(1,2,0).numpy())
26 plt.axis('off')
27 x.axes.get_xaxis().set_visible(False)

RuntimeError: number of dims don't match in permute

visualize gating maps

# Visualize the gating maps of the last modulation block next to each input image.
# Relies on names defined earlier in the notebook: model, eval_transforms,
# display_transforms, Image, plt, nn and os.

# Enlarges each gate map 4x with bilinear interpolation before display.
upsampler = nn.Upsample(scale_factor=4, mode='bilinear')

img_folder = "/home/tajamul/scratch/FocalNet/FocalNet/demo_fig/"
img_paths = os.listdir(img_folder)

# One panel for the input image plus one panel per gate map.
num_gates = 4
num_panels = num_gates + 1

for img_path in img_paths:
    img = Image.open(os.path.join(img_folder, img_path))
    img_t = eval_transforms(img)
    img_d = display_transforms(img)

    # The output itself is not used below; the forward pass is kept because the
    # gates are read from the model afterwards.
    # NOTE(review): presumably the forward pass stores `gates` on the modulation
    # module -- confirm against the model code.
    out = model(img_t.unsqueeze(0).cuda())

    fig = plt.figure(figsize=(16, 8))

    # Panel 1: the input image, moved from (C, H, W) order to (H, W, C) for imshow.
    ax = fig.add_subplot(1, num_panels, 1)
    ax.imshow(img_d.permute(1, 2, 0).cpu().detach().contiguous().numpy())
    ax.axis('off')

    gates = model.layers[-1].blocks[-1].modulation.gates
    for i in range(num_gates):
        # Use the axes of the panel being drawn, so axis hiding applies to this
        # panel rather than to the first one.
        ax = fig.add_subplot(1, num_panels, i + 2)
        # The i:i+1 slice keeps the second dimension, so the upsampled result still
        # has two leading dimensions (presumably batch and channel, both of size 1
        # here). Both are squeezed away to obtain the 2-D array imshow expects;
        # permute(1, 2, 0) fails on this tensor because it names only three dims.
        gates_i = upsampler(gates[:, i:i + 1]).cpu().detach()
        ax.imshow(gates_i.squeeze(0).squeeze(0).numpy())
        ax.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

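This updated code works: the upsampled gate slice is a 4-D tensor, so `permute(1,2,0)` fails with "number of dims don't match in permute", whereas squeezing the two leading singleton dimensions yields the 2-D array that `plt.imshow` accepts. Thanks!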