Visualization of gated outputs
Tajamul21 opened this issue · 2 comments
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_63586/1379561863.py in <module>
23 fig.add_subplot(1, 5, i+2)
24 gates_i = (upsampler(gates[:, i:i+1])).cpu().detach()
---> 25 plt.imshow(gates_i.permute(1,2,0).numpy())
26 plt.axis('off')
27 x.axes.get_xaxis().set_visible(False)
RuntimeError: number of dims don't match in permute
# Visualize gating maps: for every image in `img_folder`, show the image next
# to the upsampled gate maps read from the last block's modulation layer.
import torch  # for torch.no_grad(); presumably already imported by the notebook

upsampler = nn.Upsample(scale_factor=4, mode='bilinear')
img_folder = "/home/tajamul/scratch/FocalNet/FocalNet/demo_fig/"
# sorted() makes the display order deterministic; os.listdir order is arbitrary.
img_paths = sorted(os.listdir(img_folder))
for img_path in img_paths:
    # os.path.join works whether or not img_folder ends with a separator.
    # convert("RGB") forces the pixels to load (so the file can be closed
    # here) and turns grayscale / RGBA files into 3-channel images.
    with Image.open(os.path.join(img_folder, img_path)) as raw_img:
        img = raw_img.convert("RGB")
    img_t = eval_transforms(img)
    img_d = display_transforms(img)

    # The forward pass is needed because `.gates` is read from the model
    # afterwards; no autograd graph is required just to visualize.
    # NOTE(review): assumes the model was put in eval mode earlier — confirm.
    with torch.no_grad():
        out = model(img_t.unsqueeze(0).cuda())
    gates = model.layers[-1].blocks[-1].modulation.gates
    # Derive the number of gate maps from the tensor rather than assuming 4;
    # assumes gates is (N, C, H, W), since bilinear Upsample needs 4-D input.
    num_gates = gates.shape[1]
    num_cols = num_gates + 1  # one panel for the image, one per gate map

    fig = plt.figure(figsize=(16, 8))
    # Display image: CHW tensor -> HWC array for imshow.
    img2d = img_d.permute(1, 2, 0).cpu().detach().contiguous().numpy()
    ax = fig.add_subplot(1, num_cols, 1)
    ax.imshow(img2d)
    ax.set_axis_off()

    for i in range(num_gates):
        # i:i+1 keeps the channel dim so the upsampler gets a 4-D batch;
        # [0, 0] then drops batch and channel, giving the 2-D map imshow
        # expects (permuting the 4-D tensor with 3 indices was the bug).
        gates_i = upsampler(gates[:, i:i+1]).detach().cpu()
        # Use this subplot's own Axes so axis hiding applies to this panel.
        ax = fig.add_subplot(1, num_cols, i + 2)
        ax.imshow(gates_i[0, 0].numpy())
        ax.set_axis_off()

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
    plt.close(fig)  # a figure is created per image; release it once shown
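This updated code works: it reduces the upsampled 4-D gate tensor to a 2-D map with `squeeze` instead of calling `permute(1, 2, 0)`, which expects exactly 3 dimensions. Thanks!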