tensorflow/tensorrt

TF-TRT 1.x Workflow With A Frozen Graph

chennuo0125-HIT opened this issue · 1 comment

TF-TRT 1.x Workflow With A Frozen Graph

**Old code:**

74    def __init__(self, checkpoint_filename, input_name="images",  # load a frozen GraphDef file and import it under the "net" scope
75                output_name="features"):
76        self.session = tf.Session()
77        with tf.gfile.GFile(checkpoint_filename, "rb") as file_handle:
78             graph_def = tf.GraphDef()
79             graph_def.ParseFromString(file_handle.read())  # deserialize the binary GraphDef
80         tf.import_graph_def(graph_def, name="net")  # imported names get the "net/" prefix used below
81         self.input_var = tf.get_default_graph().get_tensor_by_name(
82             "net/%s:0" % input_name)
83         self.output_var = tf.get_default_graph().get_tensor_by_name(
84             "net/%s:0" % output_name)
85 
86         assert len(self.output_var.get_shape()) == 2  # needs a known static rank; printed shape is (?, 128). NOTE(review): asserts are skipped under python -O
87         assert len(self.input_var.get_shape()) == 4  # printed shape is (?, 128, 64, 3)
88         self.feature_dim = self.output_var.get_shape().as_list()[-1]  # last output dimension, e.g. 128
89         self.image_shape = self.input_var.get_shape().as_list()[1:]  # input shape minus the leading (batch) dim, e.g. [128, 64, 3]

**New code with TF-TRT:**

91    def __init__(self, checkpoint_filename, input_name="images",  # like the old code, but runs the frozen graph through TF-TRT before importing it
92                 output_name="features"):
93        self.session = tf.Session()
94        with tf.gfile.GFile(checkpoint_filename, "rb") as file_handle:
95            graph_def = tf.GraphDef()
96            graph_def.ParseFromString(file_handle.read())
97        converter = trt.TrtGraphConverter(input_graph_def=graph_def, nodes_blacklist=["images:0", "features:0"])  # NOTE(review): TF-TRT examples pass output *node* names here (e.g. "features"), not tensor names with ":0" — confirm
98        trt_graph = converter.convert()  # the converted GraphDef that is imported below
99        tf.import_graph_def(trt_graph, name="net")  # same "net" scope, so the tensor lookups below are unchanged
100       self.input_var = tf.get_default_graph().get_tensor_by_name(
101           "net/%s:0" % input_name)
102       self.output_var = tf.get_default_graph().get_tensor_by_name(
103           "net/%s:0" % output_name)
104  # TODO(review): restore the static output shape before the checks, e.g. self.output_var.set_shape([None, 128]) (the shape printed for the old code) — verify
105       assert len(self.output_var.get_shape()) == 2  # raises ValueError (see traceback): after conversion features:0 prints with no shape, i.e. unknown rank
106       assert len(self.input_var.get_shape()) == 4  # input keeps its shape (?, 128, 64, 3) after conversion, per the printed tensors
107       self.feature_dim = self.output_var.get_shape().as_list()[-1]  # likewise depends on the static output shape lost after conversion
108        self.image_shape = self.input_var.get_shape().as_list()[1:]  # NOTE(review): indented one column further than the lines above in this paste

The error is as follows:

Traceback (most recent call last):
  File "tools/generate_detections.py", line 235, in <module>
    main()
  File "tools/generate_detections.py", line 229, in main
    encoder = create_box_encoder(args.model, batch_size=32)
  File "tools/generate_detections.py", line 120, in create_box_encoder
    image_encoder = ImageEncoder(model_filename, input_name, output_name)
  File "tools/generate_detections.py", line 105, in __init__
    assert len(self.output_var.get_shape()) == 2
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/tensor_shape.py", line 827, in __len__
    raise ValueError("Cannot take the length of shape with unknown rank.")
ValueError: Cannot take the length of shape with unknown rank.

Then I printed the input and output tensors:

print(self.input_var)
print(self.output_var)

The result is:
Old code:

Tensor("net/images:0",  shape=(?, 128, 64, 3), dtype=uint8)
Tensor("net/features:0", shape=(?, 128), dtype=float32)

New code (note that `features:0` no longer has a static shape after conversion):

Tensor("net/images:0",  shape=(?, 128, 64, 3), dtype=uint8)
Tensor("net/features:0", dtype=float32)

The code comes from this repository: https://github.com/nwojke/deep_sort/blob/280b8bdb255f223813ff4a8679f3e1321b08cdfc/tools/generate_detections.py#L71
Please give me some advice on how to use TF-TRT properly. Thank you.