Wrong result
sjf18 opened this issue · 0 comments
sjf18 commented
Hi, thanks for your great work. I wrote a demo that uses your model to predict images, but something seems wrong with the result, as shown here: why is the refined output gray?
Here is my demo code. Could you please help me?
# Load the trained generator and set up input/output locations for the demo.
model_path = './model_logs/offical/latest_ckpt.pth.tar'
# map_location lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; the model is moved to cpu0 below anyway.
# NOTE(review): assumes cpu0 is a torch.device defined outside this snippet.
nets = torch.load(model_path, map_location=cpu0)
netG_state_dict, netD_state_dict = nets['netG_state_dict'], nets['netD_state_dict']
netG = InpaintSANet()
# NOTE(review): a flat gray refine output can mean some layers never received
# trained weights. List model entries with no same-named checkpoint entry
# (e.g. a 'module.' prefix left by DataParallel) so a mismatch is not silent.
_missing_keys = sorted(set(netG.state_dict()) - set(netG_state_dict))
if _missing_keys:
    print(f'Warning: {len(_missing_keys)} generator entries not found in '
          f'checkpoint, e.g. {_missing_keys[:5]}')
load_consistent_state_dict(netG_state_dict, netG)
netG.to(cpu0)
netG.eval()
# Inference only: no autograd graph is needed for anything below.
torch.set_grad_enabled(False)
save_img_dir = 'results/'
os.makedirs(save_img_dir, exist_ok=True)
test_img_dir = 'testdata/'
# Sorted so that images are processed in a reproducible order.
imgs_list = sorted(os.listdir(test_img_dir))
# Network input size as (height, width); square here, so the order does not
# matter for cv2.resize.
input_shape = (256, 256)
def _tensor_to_uint8_hwc(tensor):
    """Convert a batch-of-one image tensor in [-1, 1] to an H x W x C uint8 array.

    The tensor is laid out as (1, C, H, W), matching the permute(0, 2, 3, 1)
    and [0] indexing applied to every network output in this script.
    """
    arr = (127.5 * (tensor + 1)).permute(0, 2, 3, 1).detach().cpu().numpy()[0]
    # Round and clip before the cast: astype(np.uint8) alone would wrap
    # out-of-range values instead of saturating them.
    return np.clip(np.rint(arr), 0, 255).astype(np.uint8)


for imgname in tqdm(imgs_list):
    img_path = os.path.join(test_img_dir, imgname)
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None (instead of raising) for files it cannot decode.
        print(f'Skipping unreadable file: {img_path}')
        continue
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # cv2.resize expects dsize as (width, height); input_shape is treated as
    # (height, width) so the image agrees with the mask built from it.
    # NOTE(review): confirm random_ff_mask interprets its argument as (H, W).
    img_resize = cv2.resize(img_rgb, (input_shape[1], input_shape[0]))
    # Presumably an (H, W, 1) array of 0/1 values, given how it is expanded to
    # 3 channels and scaled by 255 below -- TODO confirm.
    mask = random_ff_mask(input_shape)

    # HWC numpy -> (1, C, H, W) float tensors.
    img_tensor = torch.from_numpy(img_resize.astype(np.float32)[np.newaxis]).permute(0, 3, 1, 2)
    mask_tensor = torch.from_numpy(mask.astype(np.float32)[np.newaxis]).permute(0, 3, 1, 2)
    used_img, used_mask = img_tensor.to(cpu0), mask_tensor.to(cpu0)
    # Scale pixel values from [0, 255] to [-1, 1].
    used_img = used_img / 127.5 - 1
    corse_img, refine_img = netG(used_img, used_mask)

    ## network output
    cor_img_np = _tensor_to_uint8_hwc(corse_img)
    ref_img_np = _tensor_to_uint8_hwc(refine_img)

    ## complete output: prediction where mask == 1, original pixels elsewhere
    cor_complete_img_np = _tensor_to_uint8_hwc(corse_img * used_mask + used_img * (1 - used_mask))
    ref_complete_img_np = _tensor_to_uint8_hwc(refine_img * used_mask + used_img * (1 - used_mask))

    ## save images (all panels are uint8, so cvtColor/imwrite get a supported dtype)
    mask_rgb = np.clip(255 * np.repeat(mask, 3, axis=-1), 0, 255).astype(np.uint8)
    first = np.concatenate((img_resize, mask_rgb), axis=0)
    second = np.concatenate((ref_img_np, cor_img_np), axis=0)
    third = np.concatenate((ref_complete_img_np, cor_complete_img_np), axis=0)
    out_img = np.concatenate((first, second, third), axis=1)
    out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(os.path.join(save_img_dir, imgname), out_img)