czczup/Captcha-Recognition

生成pb模型

xiaoran-xr opened this issue · 8 comments

在level2的例子中,我想用ckpt文件生成pb模型,但是需要指定输出节点的名字。我使用ArgMax时提示该节点不存在;使用FC2/wx_plus_b时,结果好像对不上。

在level2的例子中,我想用ckpt文件生成pb模型,但是需要指定输出节点的名字。我使用ArgMax时提示该节点不存在;使用FC2/wx_plus_b时,结果好像对不上。

model函数里没有用argmax,所以这个节点不存在。你可以用FC2/wx_plus_b拿到概率以后,另外再求个argmax。结果对不上,是不是没有加载训练好的模型?可以贴一下转pb的代码,我看一下。

from model import Model
from tensorflow.python.framework.graph_util import convert_variables_to_constants


def compile_graph(checkpoint_path='./model/model_level2.ckpt-71000',
                  output_path='compile_model.pb',
                  output_node_names=('FC2/wx_plus_b',)):
    """Freeze a trained checkpoint into a single constant-only ``.pb`` file.

    Builds the model in a fresh graph, restores its variables from a
    checkpoint, converts those variables to constants and writes the
    resulting GraphDef to disk.

    Args:
        checkpoint_path: Checkpoint prefix passed to ``Saver.restore``.
        output_path: File the serialized frozen GraphDef is written to.
        output_node_names: Names of the graph nodes to keep as outputs.
            ``FC2/wx_plus_b`` is the only output exported by default.
    """
    input_graph = tf.Graph()

    # The context manager closes the session once freezing is done; the
    # original left the session open.
    with tf.Session(graph=input_graph) as sess:
        with input_graph.as_default():
            # Model comes from model.py (not shown here); instantiating it is
            # presumably what adds the network's ops and variables to
            # input_graph -- confirm against model.py.
            Model()
            input_graph_def = input_graph.as_graph_def()

            # No initializer run is needed: restoring assigns a value to
            # every variable handed to the Saver.
            saver = tf.train.Saver(var_list=tf.global_variables())
            saver.restore(sess, checkpoint_path)

        output_graph_def = convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names=list(output_node_names),
        )

    with tf.gfile.GFile(output_path, mode='wb') as gf:
        gf.write(output_graph_def.SerializeToString())

# Run the export only when this file is executed as a script.
if __name__ == '__main__':
    compile_graph()

这个是预测的代码

# Inference with the frozen graph: load compile_model.pb, feed one grayscale
# image and print the raw FC2/wx_plus_b value for every output index.
# NOTE(review): `graph` (and the tf / np / cv2 / Image imports) are not part
# of this snippet; `graph` is presumably a tf.Graph created earlier -- confirm.
sess = tf.compat.v1.Session(
    graph=graph,
    config=tf.compat.v1.ConfigProto(allow_soft_placement=True),
)

# Read the serialized GraphDef written by the export step and parse it into
# the GraphDef message obtained from `graph`.
graph_def = graph.as_graph_def()
with tf.io.gfile.GFile('compile_model.pb', "rb") as f:
    graph_def.ParseFromString(f.read())

with graph.as_default():
    # NOTE(review): a frozen graph stores its weights as constants, so this
    # initializer presumably has nothing to initialize -- confirm.
    sess.run(tf.compat.v1.global_variables_initializer())
    # Import the graph from graph_def into the current default Graph.
    tf.import_graph_def(graph_def, name="")

# Look up the output and the two input tensors by name.
dense_decoded = sess.graph.get_tensor_by_name("FC2/wx_plus_b:0")
x = sess.graph.get_tensor_by_name('input/X:0')
keep_prob = sess.graph.get_tensor_by_name('input/keep_prob:0')
sess.graph.finalize()

# Load the image as a float32 grayscale array.
pil_image = Image.open('111.png').convert('L')
im = np.asarray(pil_image).astype(np.float32)

# cv2.resize takes dsize as (width, height), so the result has 60 rows and
# 44 columns; swapping the axes then gives shape (44, 60).
# NOTE(review): confirm this layout matches what the model was trained on.
image = cv2.resize(im, (44, 60))
image = image.swapaxes(0, 1)

# Add a trailing channel axis and scale pixel values to [0, 1].
image_batch = image[:, :, np.newaxis] / 255

# Dropout keep probability of 1, i.e. nothing is dropped at inference time.
keep_prob_result = np.array([1], dtype=float)
print(type(keep_prob_result))

dense_decoded_code = sess.run(
    dense_decoded,
    feed_dict={x: [image_batch], keep_prob: keep_prob_result},
)
for index, i in enumerate(dense_decoded_code[0]):
    print(index, i)

# max() prints the largest output value, not its index; np.argmax on the same
# row would give the predicted class index.
print(max(dense_decoded_code[0]))

sess.close()

111
这个是图片

<class 'numpy.ndarray'>
0 -2.8832908
1 -2.8340662
2 -3.760985
3 -5.361003
4 -3.6272085
5 -3.7321653
6 -4.2132726
7 -3.9730148
8 0.25813654
9 -2.8246596
10 -0.84458536
11 -1.7111077
12 -1.319597
13 -3.3274467
14 -1.9860088
15 -3.5664222
16 -2.1158872
17 -5.252863
18 -3.2950745
19 -1.6228321
20 -1.2509553
21 -3.2692134
22 -1.3156378
23 -3.1125956
24 -2.5402946
25 -2.462352
26 0.5903716
27 -0.9135409
28 -0.12570877
29 -2.043648
30 -2.5192776
31 -2.9804707
0.5903716

Process finished with exit code 0
from model import Model
from tensorflow.python.framework.graph_util import convert_variables_to_constants

def compile_graph(checkpoint_path='./model/model_level2.ckpt-71000',
                  output_path='compile_model.pb',
                  output_node_names=('FC2/wx_plus_b',)):
    """Freeze a trained checkpoint into a single constant-only ``.pb`` file.

    Builds the model in a fresh graph, restores its variables from a
    checkpoint, converts those variables to constants and writes the
    resulting GraphDef to disk.

    Args:
        checkpoint_path: Checkpoint prefix passed to ``Saver.restore``.
        output_path: File the serialized frozen GraphDef is written to.
        output_node_names: Names of the graph nodes to keep as outputs.
            ``FC2/wx_plus_b`` is the only output exported by default.
    """
    input_graph = tf.Graph()

    # The context manager closes the session once freezing is done; the
    # original left the session open.
    with tf.Session(graph=input_graph) as sess:
        with input_graph.as_default():
            # Model comes from model.py (not shown here); instantiating it is
            # presumably what adds the network's ops and variables to
            # input_graph -- confirm against model.py.
            Model()
            input_graph_def = input_graph.as_graph_def()

            # No initializer run is needed: restoring assigns a value to
            # every variable handed to the Saver.
            saver = tf.train.Saver(var_list=tf.global_variables())
            saver.restore(sess, checkpoint_path)

        output_graph_def = convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names=list(output_node_names),
        )

    with tf.gfile.GFile(output_path, mode='wb') as gf:
        gf.write(output_graph_def.SerializeToString())


# Run the export only when this file is executed as a script.
if __name__ == '__main__':
    compile_graph()

这一段没问题,我在别的地方也是这样写

这个是预测的代码

# Inference with the frozen graph: load compile_model.pb, feed one grayscale
# image and print the raw FC2/wx_plus_b value for every output index.
# NOTE(review): `graph` (and the tf / np / cv2 / Image imports) are not part
# of this snippet; `graph` is presumably a tf.Graph created earlier -- confirm.
sess = tf.compat.v1.Session(
    graph=graph,
    config=tf.compat.v1.ConfigProto(allow_soft_placement=True),
)

# Read the serialized GraphDef written by the export step and parse it into
# the GraphDef message obtained from `graph`.
graph_def = graph.as_graph_def()
with tf.io.gfile.GFile('compile_model.pb', "rb") as f:
    graph_def.ParseFromString(f.read())

with graph.as_default():
    # NOTE(review): a frozen graph stores its weights as constants, so this
    # initializer presumably has nothing to initialize -- confirm.
    sess.run(tf.compat.v1.global_variables_initializer())
    # Import the graph from graph_def into the current default Graph.
    tf.import_graph_def(graph_def, name="")

# Look up the output and the two input tensors by name.
dense_decoded = sess.graph.get_tensor_by_name("FC2/wx_plus_b:0")
x = sess.graph.get_tensor_by_name('input/X:0')
keep_prob = sess.graph.get_tensor_by_name('input/keep_prob:0')
sess.graph.finalize()

# Load the image as a float32 grayscale array.
pil_image = Image.open('111.png').convert('L')
im = np.asarray(pil_image).astype(np.float32)

# cv2.resize takes dsize as (width, height), so the result has 60 rows and
# 44 columns; swapping the axes then gives shape (44, 60).
# NOTE(review): confirm this layout matches what the model was trained on.
image = cv2.resize(im, (44, 60))
image = image.swapaxes(0, 1)

# Add a trailing channel axis and scale pixel values to [0, 1].
image_batch = image[:, :, np.newaxis] / 255

# Dropout keep probability of 1, i.e. nothing is dropped at inference time.
keep_prob_result = np.array([1], dtype=float)
print(type(keep_prob_result))

dense_decoded_code = sess.run(
    dense_decoded,
    feed_dict={x: [image_batch], keep_prob: keep_prob_result},
)
for index, i in enumerate(dense_decoded_code[0]):
    print(index, i)

# max() prints the largest output value, not its index; np.argmax on the same
# row would give the predicted class index.
print(max(dense_decoded_code[0]))

sess.close()
with graph.as_default():
    # Run the global-variables initializer op in this session before the import.
    sess.run(tf.compat.v1.global_variables_initializer())
    # Import the graph from graph_def into the current default Graph.
    _ = tf.import_graph_def(graph_def, name="")

你把这里的sess.run(tf.compat.v1.global_variables_initializer())删掉再试试吧

我试了一下,把sess.run(tf.compat.v1.global_variables_initializer())删掉之后,结果跟之前还是一样的。