Can it work with a single image?
aligoglos opened this issue · 4 comments
aligoglos commented
I wrote simple code to run model on a single image but result is gray still !!
minimal demo :
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
import cv2
from PIL import Image
import numpy as np
from tqdm import tqdm
import os
import argparse
import subprocess
import utils
import glob
def main():
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
refimgs = None
disable_colorization = False
# Load remaster network
modelR = __import__( 'model.remasternet', fromlist=['NetworkR'] ).NetworkR()
state_dict = torch.load( 'remasternet.pth' )
modelR.load_state_dict( state_dict['modelR'] )
modelR = modelR.to(device)
modelR.eval()
if not disable_colorization:
modelC = __import__( 'model.remasternet', fromlist=['NetworkC'] ).NetworkC()
modelC.load_state_dict( state_dict['modelC'] )
modelC = modelC.to(device)
modelC.eval()
paths = sorted(glob.glob('./inputs' + '/*'))
for path in paths:
image = cv2.imread(path)
if ~(image is None):
name = path.split('\\')[-1]
print(name)
refimgs = cv2.imread(F"./references/{name}")
with torch.no_grad():
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
frame_l = torch.from_numpy(gray).view( gray.shape[0], gray.shape[1], 1 )
frame_l = frame_l.permute(2, 0, 1).float() # HWC to CHW
frame_l /= 255.
frame_l = frame_l.view(1, frame_l.size(0), 1, frame_l.size(1), frame_l.size(2))
input = frame_l.to( device )
output_l = modelR( input )
if refimgs is None:
output_ab = modelC( output_l )
else:
refimgs = torch.from_numpy(refimgs)
refimgs = refimgs.permute(2, 0, 1).float().unsqueeze(axis = 0).unsqueeze(axis = 0)
refimgs /= 255.
refimgs = refimgs.to( device )
output_ab = modelC( output_l, refimgs )
output_l = output_l.detach().cpu()
output_ab = output_ab.detach().cpu()
out_l = output_l[0,:,0,:,:]
out_c = output_ab[0,:,0,:,:]
output = torch.cat((out_l, out_c), dim=0).numpy().transpose((1, 2, 0))
output = Image.fromarray( np.uint8( utils.convertLAB2RGB( output )*255 ) )
output.save( F"./results/{name}" )
if __name__ == "__main__":
main()
** Note : reference image is equal to input
zhaoyuzhi commented
I also encounter the same issue (output grayscale when single frame as input). Have you addressed this issue?
hermosayhl commented
I encounter this issue, too. Is there anyone make it ?
hermosayhl commented
I also encounter the same issue (output grayscale when single frame as input). Have you addressed this issue?
Dr zhao, have you overcome this problem?
Dawars commented
You have to emulate multiple frames by duplicating the image to make the temporal convolutions work:
input = torch.tile(input, (1, 1, 5, 1, 1))
The network still isn't able to use colors from the reference images if they are significantly different from the gray image.