生成pb模型
xiaoran-xr opened this issue · 8 comments
xiaoran-xr commented
level2的例子中,我想用ckpt文件生成pb模型,但是需要指定输出节点的名字。我使用ArgMax时显示该节点不存在;使用FC2/wx_plus_b时,结果好像对不上。
czczup commented
level2的例子中,我想用ckpt文件生成pb模型,但是需要指定输出节点的名字。我使用ArgMax时显示该节点不存在;使用FC2/wx_plus_b时,结果好像对不上。
model函数里没有用argmax,所以该节点不存在。你可以用FC2/wx_plus_b拿到概率以后,另外再求个argmax。结果对不上,是不是没有加载训练好的模型?可以贴一下转pb的代码,我看一下。
xiaoran-xr commented
from model import Model
from tensorflow.python.framework.graph_util import convert_variables_to_constants
def compile_graph():
    """Freeze the level-2 checkpoint into a standalone ``compile_model.pb``.

    Builds the network in a fresh graph, restores the weights from
    ``./model/model_level2.ckpt-71000``, folds every variable into a
    constant, and writes the serialized GraphDef to ``compile_model.pb``
    in the current working directory.

    The exported output node is ``FC2/wx_plus_b``. The graph presumably has
    no ArgMax op, so callers must apply ``argmax`` to that output
    themselves -- TODO confirm against ``Model``.
    """
    input_graph = tf.Graph()
    # Context-managed so the session is closed even if restore/export fails
    # (the session was previously never closed).
    with tf.Session(graph=input_graph) as sess:
        with input_graph.as_default():
            # Instantiating Model presumably adds the network ops to the
            # default graph; the instance itself is not needed afterwards.
            Model()
            sess.run(tf.global_variables_initializer())
            # Snapshot the GraphDef before the Saver adds save/restore ops.
            input_graph_def = sess.graph.as_graph_def()
            saver = tf.train.Saver(var_list=tf.global_variables())
            # Overwrites the freshly initialized values with trained weights.
            saver.restore(sess, './model/model_level2.ckpt-71000')
            output_graph_def = convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names=['FC2/wx_plus_b']
            )
    with tf.gfile.GFile('compile_model.pb', mode='wb') as gf:
        gf.write(output_graph_def.SerializeToString())
# Script entry point: export the frozen graph when run directly.
if __name__ == '__main__':
    compile_graph()
xiaoran-xr commented
这个是预测的代码
# Inference script: load the frozen graph from compile_model.pb, run a single
# image through it, and print the raw FC2/wx_plus_b output per class index.
sess = tf.compat.v1.Session(
    graph=graph,
    config=tf.compat.v1.ConfigProto(
        allow_soft_placement=True,
    ))
graph_def = graph.as_graph_def()
with tf.io.gfile.GFile('compile_model.pb', "rb") as f:
    graph_def_file = f.read()
# ParseFromString replaces the message contents with the frozen graph.
graph_def.ParseFromString(graph_def_file)

with graph.as_default():
    # Import the graph from graph_def into the current default Graph.
    # No variable initializer is run: a frozen graph holds only constants.
    _ = tf.import_graph_def(graph_def, name="")

dense_decoded = sess.graph.get_tensor_by_name("FC2/wx_plus_b:0")
x = sess.graph.get_tensor_by_name('input/X:0')
keep_prob = sess.graph.get_tensor_by_name('input/keep_prob:0')
sess.graph.finalize()

pil_image = Image.open('111.png')
pil_image = pil_image.convert('L')
im = np.asarray(pil_image)
im = im.astype(np.float32)
# cv2.resize takes dsize as (width, height), so this yields 60 rows x 44 cols;
# the swapaxes below turns it into 44 x 60.
# NOTE(review): confirm this size, axis order and the /255 scaling match the
# preprocessing used during training -- a mismatch would explain wrong scores.
image = cv2.resize(im, (44, 60),)
image = image.swapaxes(0, 1)
image_batch = image[:, :, np.newaxis] / 255
# keep_prob of 1 disables dropout at inference time.
# NOTE(review): this feeds a float64 array of shape (1,); verify it matches
# the dtype/shape of the 'input/keep_prob' placeholder.
keep_prob_result = np.array([1], dtype=float)
print(type(keep_prob_result))

dense_decoded_code = sess.run(dense_decoded, feed_dict={
    x: [image_batch],
    keep_prob: keep_prob_result
})
for index, i in enumerate(dense_decoded_code[0]):
    print(index, i)
# Largest raw score; the predicted class is the index where it occurs.
print(max(dense_decoded_code[0]))
xiaoran-xr commented
xiaoran-xr commented
<class 'numpy.ndarray'>
0 -2.8832908
1 -2.8340662
2 -3.760985
3 -5.361003
4 -3.6272085
5 -3.7321653
6 -4.2132726
7 -3.9730148
8 0.25813654
9 -2.8246596
10 -0.84458536
11 -1.7111077
12 -1.319597
13 -3.3274467
14 -1.9860088
15 -3.5664222
16 -2.1158872
17 -5.252863
18 -3.2950745
19 -1.6228321
20 -1.2509553
21 -3.2692134
22 -1.3156378
23 -3.1125956
24 -2.5402946
25 -2.462352
26 0.5903716
27 -0.9135409
28 -0.12570877
29 -2.043648
30 -2.5192776
31 -2.9804707
0.5903716
Process finished with exit code 0
czczup commented
from model import Model from tensorflow.python.framework.graph_util import convert_variables_to_constants def compile_graph(): input_graph = tf.Graph() sess = tf.Session(graph=input_graph) with sess.graph.as_default(): model = Model() sess.run(tf.global_variables_initializer()) input_graph_def = sess.graph.as_graph_def() saver = tf.train.Saver(var_list=tf.global_variables()) saver.restore(sess, './model/model_level2.ckpt-71000') output_graph_def = convert_variables_to_constants( sess, input_graph_def, output_node_names=['FC2/wx_plus_b'] ) with tf.gfile.GFile('compile_model.pb', mode='wb') as gf: gf.write(output_graph_def.SerializeToString()) if __name__ == '__main__': compile_graph()`
这一段没问题,我在别的地方也是这样写
czczup commented
这个是预测的代码
sess = tf.compat.v1.Session( graph=graph, config=tf.compat.v1.ConfigProto( allow_soft_placement=True, )) graph_def = graph.as_graph_def() with tf.io.gfile.GFile('compile_model.pb', "rb") as f: graph_def_file = f.read() graph_def.ParseFromString(graph_def_file) with graph.as_default(): sess.run(tf.compat.v1.global_variables_initializer()) # 将图形从 graph_def 导入当前的默认 Graph. _ = tf.import_graph_def(graph_def, name="") dense_decoded = sess.graph.get_tensor_by_name("FC2/wx_plus_b:0") x = sess.graph.get_tensor_by_name('input/X:0') keep_prob = sess.graph.get_tensor_by_name('input/keep_prob:0') sess.graph.finalize() all_tensor = graph_def.node pil_image = Image.open('111.png') pil_image = pil_image.convert('L') im = np.asarray(pil_image) im = im.astype(np.float32) image = cv2.resize(im,(44,60),) image = image.swapaxes(0, 1) image_batch = image[:, :, np.newaxis]/255 keep_prob_result = np.array([1], dtype=float) print(type(keep_prob_result)) dense_decoded_code = sess.run(dense_decoded, feed_dict={ x: [image_batch], keep_prob:keep_prob_result }) for index, i in enumerate(dense_decoded_code[0]): print(index, i) print(max(dense_decoded_code[0]))`
with graph.as_default():
sess.run(tf.compat.v1.global_variables_initializer())
# 将图形从 graph_def 导入当前的默认 Graph.
_ = tf.import_graph_def(graph_def, name="")
你把这里的sess.run(tf.compat.v1.global_variables_initializer())删掉再试试吧
xiaoran-xr commented
我试了下,把sess.run(tf.compat.v1.global_variables_initializer())删掉跟之前的结果还是一样的。