TF-TRT 1.x Workflow With A Frozen Graph
chennuo0125-HIT opened this issue · 1 comments
chennuo0125-HIT commented
TF-TRT 1.x Workflow With A Frozen Graph
chennuo0125-HIT commented
==old code : ==
74 def __init__(self, checkpoint_filename, input_name="images",
75 output_name="features"):
76 self.session = tf.Session()
77 with tf.gfile.GFile(checkpoint_filename, "rb") as file_handle:
78 graph_def = tf.GraphDef()
79 graph_def.ParseFromString(file_handle.read())
80 tf.import_graph_def(graph_def, name="net")
81 self.input_var = tf.get_default_graph().get_tensor_by_name(
82 "net/%s:0" % input_name)
83 self.output_var = tf.get_default_graph().get_tensor_by_name(
84 "net/%s:0" % output_name)
85
86 assert len(self.output_var.get_shape()) == 2
87 assert len(self.input_var.get_shape()) == 4
88 self.feature_dim = self.output_var.get_shape().as_list()[-1]
89 self.image_shape = self.input_var.get_shape().as_list()[1:]
==new code with tf_trt : ==
91 def __init__(self, checkpoint_filename, input_name="images",
92 output_name="features"):
93 self.session = tf.Session()
94 with tf.gfile.GFile(checkpoint_filename, "rb") as file_handle:
95 graph_def = tf.GraphDef()
96 graph_def.ParseFromString(file_handle.read())
97 converter = trt.TrtGraphConverter(input_graph_def=graph_def, nodes_blacklist=["images:0", "features:0"])
98 trt_graph = converter.convert()
99 tf.import_graph_def(trt_graph, name="net")
100 self.input_var = tf.get_default_graph().get_tensor_by_name(
101 "net/%s:0" % input_name)
102 self.output_var = tf.get_default_graph().get_tensor_by_name(
103 "net/%s:0" % output_name)
104
105 assert len(self.output_var.get_shape()) == 2
106 assert len(self.input_var.get_shape()) == 4
107 self.feature_dim = self.output_var.get_shape().as_list()[-1]
108 self.image_shape = self.input_var.get_shape().as_list()[1:]
error is follow :
Traceback (most recent call last):
File "tools/generate_detections.py", line 235, in <module>
main()
File "tools/generate_detections.py", line 229, in main
encoder = create_box_encoder(args.model, batch_size=32)
File "tools/generate_detections.py", line 120, in create_box_encoder
image_encoder = ImageEncoder(model_filename, input_name, output_name)
File "tools/generate_detections.py", line 105, in __init__
assert len(self.output_var.get_shape()) == 2
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/tensor_shape.py", line 827, in __len__
raise ValueError("Cannot take the length of shape with unknown rank.")
ValueError: Cannot take the length of shape with unknown rank.
then i print input and output tensor :
print(self.input_var)
print(self.output_var)
the result is :
old code :
Tensor("net/images:0", shape=(?, 128, 64, 3), dtype=uint8)
Tensor("net/features:0", shape=(?, 128), dtype=float32)
new code :
Tensor("net/images:0", shape=(?, 128, 64, 3), dtype=uint8)
Tensor("net/features:0", dtype=float32)
the code source repository : https://github.com/nwojke/deep_sort/blob/280b8bdb255f223813ff4a8679f3e1321b08cdfc/tools/generate_detections.py#L71
please give me some advise how to use tf_trt properly, thank you