satoshiiizuka/siggraphasia2019_remastering

Can it work with a single image?

aligoglos opened this issue · 4 comments

I wrote a simple script to run the model on a single image, but the result is still gray!
Minimal demo:

 import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
import cv2
from PIL import Image
import numpy as np
from tqdm import tqdm
import os
import argparse
import subprocess
import utils
import glob


def main():
	"""Restore and colorize every readable image found in ``./inputs``.

	Loads the restoration network (``NetworkR``) and, unless colorization is
	disabled, the colorization network (``NetworkC``) from ``remasternet.pth``.
	Each input image is converted to grayscale, duplicated along the time axis
	to form a short clip, passed through the network(s), and the first output
	frame is written to ``./results/<file name>``.

	If ``./references/<file name>`` exists and is readable, it is passed to the
	colorization network as a reference image; otherwise colorization runs
	without a reference.
	"""
	from model import remasternet

	device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
	disable_colorization = False
	# The networks use temporal convolutions; a lone frame comes out gray, so
	# the still image is repeated this many times to emulate a short clip.
	num_frames = 5

	# Load remaster network. map_location lets the checkpoint load on the
	# selected device (e.g. a CPU-only machine).
	state_dict = torch.load( 'remasternet.pth', map_location=device )
	modelR = remasternet.NetworkR()
	modelR.load_state_dict( state_dict['modelR'] )
	modelR = modelR.to(device)
	modelR.eval()

	modelC = None
	if not disable_colorization:
		modelC = remasternet.NetworkC()
		modelC.load_state_dict( state_dict['modelC'] )
		modelC = modelC.to(device)
		modelC.eval()

	os.makedirs('./results', exist_ok=True)

	paths = sorted(glob.glob('./inputs' + '/*'))
	for path in paths:
		image = cv2.imread(path)
		if image is None:
			# Not a readable image (directory, unsupported format, ...).
			continue

		# basename copes with both '/' and '\\' separators.
		name = os.path.basename(path)
		print(name)

		# Per-image reference; imread returns None when the file is missing.
		ref_bgr = cv2.imread(F"./references/{name}")

		with torch.no_grad():
			# cv2.imread yields BGR channel order.
			gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
			frame_l = torch.from_numpy(gray).float() / 255.
			# [H, W] -> [B=1, C=1, T=1, H, W]
			frame_l = frame_l.view(1, 1, 1, gray.shape[0], gray.shape[1])
			# Duplicate the frame along the time axis.
			frames = frame_l.repeat(1, 1, num_frames, 1, 1).to( device )

			output_l = modelR( frames )

			if modelC is None:
				# Restoration only: save the first luminance frame as-is.
				# NOTE(review): assumes output_l is in [0, 1] - confirm.
				save_image( output_l[0,:,0,:,:].detach().cpu(), F"./results/{name}" )
				continue

			if ref_bgr is None:
				output_ab = modelC( output_l )
			else:
				# BGR -> RGB, HWC -> CHW, then add two leading singleton axes.
				ref_rgb = cv2.cvtColor(ref_bgr, cv2.COLOR_BGR2RGB)
				refimgs = torch.from_numpy(ref_rgb).permute(2, 0, 1).float()
				refimgs = refimgs.unsqueeze(0).unsqueeze(0) / 255.
				output_ab = modelC( output_l, refimgs.to( device ) )

			output_l = output_l.detach().cpu()
			output_ab = output_ab.detach().cpu()
			# All frames are copies of the same image; keep the first one.
			out_l = output_l[0,:,0,:,:]
			out_c = output_ab[0,:,0,:,:]
			output = torch.cat((out_l, out_c), dim=0).numpy().transpose((1, 2, 0))
			output = Image.fromarray( np.uint8( utils.convertLAB2RGB( output )*255 ) )
			output.save( F"./results/{name}" )

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
	main()

Input image:
1

Output:
1

**Note:** the reference image is identical to the input image.

I also encountered the same issue (grayscale output when a single frame is used as input). Have you resolved this issue?

I encountered this issue, too. Has anyone managed to make it work?

I also encountered the same issue (grayscale output when a single frame is used as input). Have you resolved this issue?

Dr. Zhao, have you overcome this problem?

You have to emulate multiple frames by duplicating the image to make the temporal convolutions work:

input = torch.tile(input, (1, 1, 5, 1, 1))

The network still isn't able to use colors from the reference images if they are significantly different from the gray image.