samuelyu2002/ImVisible

what is input of model?

Closed this issue · 0 comments

import torch
import torch.nn as nn
from LYTNet import LYTNet
from LYTNetV2 import LYTNetV2

from torch.utils.data import DataLoader
from dataset import TrafficLightDataset

MODEL_PATH = './LytNetV1_weights'
device = torch.device('cpu')
model=LYTNet()
model.load_state_dict(torch.load(MODEL_PATH,map_location=device))
model.eval()

test_file_loc = './traffic/testing_file.csv'
test_image_directory = './traffic/PTL_Dataset_768x576'

import numpy as np
from PIL import Image
size=(768,576)
im = Image.open('./traffic/PTL_Dataset_768x576/john_IMG_0671.jpg' )
#im = pilimg.open('./traffic/PTL_Dataset_768x576/heon_IMG_0776.jpg' )

im=im.resize(size)
im.show()

pix = np.array(im)
pix=torch.Tensor(pix).type(torch.FloatTensor)
#print(pix.shape)

pix=pix.unsqueeze(0)
pix=pix.view([1,-1,576,768])
#print(pix.shape)

pred_classes, pred_direc = model(pix)
_, predicted = torch.max(pred_classes, 1)
print(predicted)

It works and output was "tensor([4])" .
But when I put green light image, it says it's "tensor([4])" almost every green light images.
I think it had problem on input parameter.
Please help..