tensorflow/tensorrt

tftrt error at sess.run for input

nyanmn opened this issue · 0 comments

The following is sample from tftrt link for tftrt code for frozen graph.

import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt
with tf.Session() as sess:
    # First deserialize your frozen graph:
    with tf.gfile.GFile(“/path/to/your/frozen/graph.pb”, ‘rb’) as f:
        frozen_graph = tf.GraphDef()
        frozen_graph.ParseFromString(f.read())
    # Now you can create a TensorRT inference graph from your
    # frozen graph:
    converter = trt.TrtGraphConverter(
	    input_graph_def=frozen_graph,
	    nodes_blacklist=['logits', 'classes']) #output nodes
    trt_graph = converter.convert()
    # Import the TensorRT graph into a new graph and run:
    output_node = tf.import_graph_def(
        trt_graph,
        return_elements=['logits', 'classes'])
    sess.run(output_node)

For my application, I have error at input image.

with tf.Session(config=config) as sess:
    # First deserialize your frozen graph:
    with tf.gfile.GFile(args.input_graph_def, 'rb') as f:
        frozen_graph = tf.GraphDef()
        frozen_graph.ParseFromString(f.read())
    # Now you can create a TensorRT inference graph from your
    # frozen graph:
    converter = trt.TrtGraphConverter(input_graph_def=frozen_graph, max_workspace_size_bytes=(1<<10), precision_mode=args.precision, maximum_cached_engines=100, minimum_segment_size = 100, nodes_blacklist=['input:0', 'd_predictions:0']) 
    trt_graph = converter.convert()
    # Import the TensorRT graph into a new graph and run:
    output_node = tf.import_graph_def(trt_graph, return_elements=['input:0', 'd_predictions:0'])
    for jpeg_file in jpeg_files:
      img = cv2.imread(args.data_dir+jpeg_file) 
      img = cv2.resize(img, (args.input_w,args.input_h))
      output = sess.run(output_node, feed_dict={input:[img]})

If input is fed above, I have error as TypeError: Cannot interpret feed_dict key as Tensor: Can not convert a builtin_function_or_method into a Tensor.

If input is fed as below
output = sess.run(output_node, feed_dict={[img]})
I have error as TypeError: unhashable type: 'list'

What could be correct format?