Attempting to test pretrained model on my own image data pairs - weird results
9thDimension opened this issue · 2 comments
9thDimension commented
I modified demo.py to allow a user to run inference on an image pair of their choosing:
from __future__ import print_function, division
import os
import argparse
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from model.cnn_geometric_model import CNNGeometric
from data import pf_dataset
from data.download_datasets import download_PF_willow
from image.normalization import NormalizeImageDict, normalize_image
from util.torch_util import BatchTensorToVars, str_to_bool
from geotnf.transformation import GeometricTnf
from geotnf.point_tnf import *
import numpy as np                    # used in get_image(); otherwise only reaches scope via the star import above
from torch.autograd import Variable  # likewise used in get_image()
import matplotlib.pyplot as plt
from skimage import io
from collections import OrderedDict
import cv2
# for compatibility with Python 2
try:
    input = raw_input
except NameError:
    pass
"""
Script to demonstrate evaluation on a trained model as presented in the CNNGeometric CVPR'17 paper
on the ProposalFlow dataset
"""
print('CNNGeometric PF demo script')
# Argument parsing
parser = argparse.ArgumentParser(description='CNNGeometric PyTorch implementation')
# Paths
parser.add_argument('--model-aff', type=str,
                    default='trained_models/best_pascal_checkpoint_adam_affine_grid_loss_resnet_random.pth.tar',
                    help='Trained affine model filename')
parser.add_argument('--model-tps', type=str,
                    default='trained_models/best_pascal_checkpoint_adam_tps_grid_loss_resnet_random.pth.tar',
                    help='Trained TPS model filename')
parser.add_argument('--feature-extraction-cnn', type=str, default='resnet101',
                    help='Feature extraction architecture: vgg/resnet101')
parser.add_argument('--pf-path', type=str, default='datasets/PF-dataset', help='Path to PF dataset')
parser.add_argument('imgpath_source', help='Path to the source image')
parser.add_argument('imgpath_target', help='Path to the target image')
args = parser.parse_args()
use_cuda = torch.cuda.is_available()
do_aff = args.model_aff != ''
do_tps = args.model_tps != ''
# Download dataset if needed
download_PF_willow('datasets/')
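# note: with the dataset/dataloader block commented out below, this download
# (and the --pf-path argument) is vestigial; only the CLI image pair is used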
# Create model
print('Creating CNN model...')
if do_aff:
    model_aff = CNNGeometric(use_cuda=use_cuda, geometric_model='affine',
                             feature_extraction_cnn=args.feature_extraction_cnn)
if do_tps:
    model_tps = CNNGeometric(use_cuda=use_cuda, geometric_model='tps',
                             feature_extraction_cnn=args.feature_extraction_cnn)
# Load trained weights
print('Loading trained model weights...')
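# the 'vgg' -> 'model' key rename below presumably maps older checkpoints,
# whose feature extractor was saved under a 'vgg' attribute, onto the current class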
if do_aff:
    checkpoint = torch.load(args.model_aff, map_location=lambda storage, loc: storage)
    checkpoint['state_dict'] = OrderedDict(
        [(k.replace('vgg', 'model'), v) for k, v in checkpoint['state_dict'].items()])
    model_aff.load_state_dict(checkpoint['state_dict'])
if do_tps:
    checkpoint = torch.load(args.model_tps, map_location=lambda storage, loc: storage)
    checkpoint['state_dict'] = OrderedDict(
        [(k.replace('vgg', 'model'), v) for k, v in checkpoint['state_dict'].items()])
    model_tps.load_state_dict(checkpoint['state_dict'])
# Dataset and dataloader
# dataset = PFDataset(csv_file=os.path.join(args.pf_path, 'test_pairs_pf.csv'),
# training_image_path=args.pf_path,
# transform=NormalizeImageDict(['source_image', 'target_image']))
# dataloader = DataLoader(dataset, batch_size=1,
# shuffle=True, num_workers=4)
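# BatchTensorToVars presumably wraps each batch entry as an autograd Variable
# (moving it to the GPU when use_cuda is set)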
batchTensorToVars = BatchTensorToVars(use_cuda=use_cuda)
affineTnf = GeometricTnf(out_h=240, out_w=240, use_cuda=False)
def get_image(img_path):
    image = io.imread(img_path)
    # get image size
    im_size = np.asarray(image.shape)
    # convert to torch Variable
    image = np.expand_dims(image.transpose((2, 0, 1)), 0)
    image = torch.Tensor(image.astype(np.float32))
    image_var = Variable(image, requires_grad=False)
    # Resize image using bilinear sampling with identity affine tnf
    # image = affineTnf(image_var).data.squeeze(0)
    image = affineTnf(image_var).data
    im_size = torch.Tensor(im_size.astype(np.float32))
    return (image, im_size)
# Instantiate point transformer
pt = PointTnf(use_cuda=use_cuda)
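# note: pt is unused in this modified script, since the keypoint evaluation was removed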
# Instantiate image transformers
tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)
IMPATH_SOURCE = args.imgpath_source
IMPATH_TARGET = args.imgpath_target
src_image, src_im_size = get_image(IMPATH_SOURCE)
tgt_image, tgt_im_size = get_image(IMPATH_TARGET)
src_image = normalize_image(src_image, forward=True)
tgt_image = normalize_image(tgt_image, forward=True)
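# assumption worth verifying: normalize_image() may expect pixel values already
# scaled to [0, 1] (as the dataset transform would produce), while get_image()
# returns raw 0-255 values; the np.uint8 recasting needed for imshow below
# hints at values outside the expected range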
batch = {'source_image': src_image,
         'target_image': tgt_image,
         'source_im_size': src_im_size,
         'target_im_size': tgt_im_size}
batch = batchTensorToVars(batch)
source_im_size = batch['source_im_size']
target_im_size = batch['target_im_size']
# these are not needed, right?
# source_points = batch['source_points']
# target_points = batch['target_points']
# warp points with estimated transformations
# target_points_norm = PointsToUnitCoords(target_points, target_im_size)
if do_aff:
    model_aff.eval()
if do_tps:
    model_tps.eval()
# Evaluate models
if do_aff:
    theta_aff = model_aff(batch)
    warped_image_aff = affTnf(batch['source_image'], theta_aff.view(-1, 2, 3))
if do_tps:
    theta_tps = model_tps(batch)
    warped_image_tps = tpsTnf(batch['source_image'], theta_tps)
if do_aff and do_tps:
    theta_aff_tps = model_tps({'source_image': warped_image_aff, 'target_image': batch['target_image']})
    warped_image_aff_tps = tpsTnf(warped_image_aff, theta_aff_tps)
# Un-normalize images and convert to numpy
source_image = normalize_image(batch['source_image'], forward=False)
source_image = source_image.data.squeeze(0).transpose(0, 1).transpose(1, 2).cpu().numpy()
target_image = normalize_image(batch['target_image'], forward=False)
target_image = target_image.data.squeeze(0).transpose(0, 1).transpose(1, 2).cpu().numpy()
if do_aff:
    warped_image_aff = normalize_image(warped_image_aff, forward=False)
    warped_image_aff = warped_image_aff.data.squeeze(0).transpose(0, 1).transpose(1, 2).cpu().numpy()
if do_tps:
    warped_image_tps = normalize_image(warped_image_tps, forward=False)
    warped_image_tps = warped_image_tps.data.squeeze(0).transpose(0, 1).transpose(1, 2).cpu().numpy()
if do_aff and do_tps:
    warped_image_aff_tps = normalize_image(warped_image_aff_tps, forward=False)
    warped_image_aff_tps = warped_image_aff_tps.data.squeeze(0).transpose(0, 1).transpose(1, 2).cpu().numpy()
# check if display is available
exit_val = os.system('python -c "import matplotlib.pyplot as plt;plt.figure()" > /dev/null 2>&1')
display_avail = exit_val == 0
# display_avail = False
if display_avail:
    N_subplots = 2 + int(do_aff) + int(do_tps) + int(do_aff and do_tps)
    fig, axs = plt.subplots(1, N_subplots)
    # N.B. I had to recast all of these to np.uint8 because it was failing otherwise for some reason...
    axs[0].imshow(np.uint8(source_image))
    axs[0].set_title('src')
    axs[1].imshow(np.uint8(target_image))
    axs[1].set_title('tgt')
    subplot_idx = 2
    if do_aff:
        axs[subplot_idx].imshow(np.uint8(warped_image_aff))
        axs[subplot_idx].set_title('aff')
        subplot_idx += 1
    if do_tps:
        axs[subplot_idx].imshow(np.uint8(warped_image_tps))
        axs[subplot_idx].set_title('tps')
        subplot_idx += 1
    if do_aff and do_tps:
        axs[subplot_idx].imshow(np.uint8(warped_image_aff_tps))
        axs[subplot_idx].set_title('aff+tps')
    for i in range(N_subplots):
        axs[i].axis('off')
    print('Showing results. Close figure window to continue')
    plt.show()
else:
    print('No display found. Writing results to:')
    fn_src = 'source.png'
    print(fn_src)
    io.imsave(fn_src, source_image)
    fn_tgt = 'target.png'
    print(fn_tgt)
    io.imsave(fn_tgt, target_image)
    if do_aff:
        fn_aff = 'result_aff.png'
        print(fn_aff)
        io.imsave(fn_aff, warped_image_aff)
    if do_tps:
        fn_tps = 'result_tps.png'
        print(fn_tps)
        io.imsave(fn_tps, warped_image_tps)
    if do_aff and do_tps:
        fn_aff_tps = 'result_aff_tps.png'
        print(fn_aff_tps)
        io.imsave(fn_aff_tps, warped_image_aff_tps)
The user can then invoke it as follows:
$ python demo2.py /path/to/source/image.jpg /path/to/target/image.jpg
But I'm getting weird results: even when the source and target images are identical, the model predicts a noticeably non-identity transform. Perhaps I've messed up the modifications in some way...
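A minimal sanity check that could narrow this down (a sketch, assuming the row-major [A11, A12, tx, A21, A22, ty] layout implied by the theta_aff.view(-1, 2, 3) reshape above):

import torch

def affine_identity_deviation(theta_aff):
    # with source == target, the regressed affine 6-vector should be close to
    # the identity matrix [[1, 0, 0], [0, 1, 0]], i.e. [1, 0, 0, 0, 1, 0]
    identity = torch.Tensor([1., 0., 0., 0., 1., 0.])
    return (theta_aff.data.view(-1).cpu() - identity).abs().max()

Printing affine_identity_deviation(theta_aff) right after theta_aff = model_aff(batch) should give a value near zero for identical source and target; a large value would point at the preprocessing (see the normalization comment in the script) rather than at the plotting code.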
lixiaolusunshine commented
I want to know what A11, A12, A21, A22 in the train.csv file represent, respectively.
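My guess, inferred from the theta_aff.view(-1, 2, 3) reshape in the script above (not confirmed against the repo's CSV format), is that Aij is row i, column j of the linear part of the 2x3 affine matrix, with tx/ty (names assumed) as the translation column:

import numpy as np

A11, A12, A21, A22, tx, ty = 1.0, 0.0, 0.0, 1.0, 0.0, 0.0  # identity example
theta = np.array([[A11, A12, tx],
                  [A21, A22, ty]])
# a point (x, y) maps to theta @ [x, y, 1]
x_new, y_new = theta @ np.array([0.5, -0.25, 1.0])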
richieYT-wan commented
Hi, where did you get the pre-trained models (best_pascal) for the PASCAL dataset?
Only the pre-trained VGG StreetView models are available in the readme.md :(