ray075hl/attention-ocr-toy-example

Is attention fairly sensitive to sequence length?

Closed this issue · 5 comments

Hi, I tested the code you provided: accuracy is quite high on 3-digit images, but noticeably worse on 4- and 5-digit ones.
Is there a good way to solve this?

@fendaq I ran into the same problem when doing Chinese text recognition, and I'm still looking for the cause.
Let me share a bit of what I've learned from my experiments so far:
1. The initial state passed into the attention decoder should be the state produced by a bidirectional RNN encoder, not an all-zero vector; an all-zero state tends to make the first decoding step wrong (see the sketch below).
2. Attention seems to be sensitive to noise (blur): the alignment is often off, so characters get missed, whereas a CTC branch trained jointly does not have this problem.

As for the length you mention, I've found that a clear image with a dozen or so characters in a line can still be recognized completely correctly, so I think it's noise sensitivity rather than length. Just my personal take, for reference only.
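
For point 1, here is a minimal sketch of the wiring I mean, under my own assumptions: NUM_UNITS, BATCH_SIZE, and the shape of cnn_out (the CNN feature sequence) are hypothetical placeholders, not values from this repo.

import tensorflow as tf

NUM_UNITS = 256   # hypothetical hidden size
BATCH_SIZE = 32   # hypothetical batch size

# Placeholder standing in for the CNN feature sequence, [batch, time, channels].
cnn_out = tf.placeholder(tf.float32, [BATCH_SIZE, None, 128])

# Encoder: bidirectional GRU over the CNN feature sequence.
cell_fw = tf.contrib.rnn.GRUCell(NUM_UNITS)
cell_bw = tf.contrib.rnn.GRUCell(NUM_UNITS)
enc_outputs, enc_state = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=cell_fw, cell_bw=cell_bw, inputs=cnn_out, dtype=tf.float32)

# Attention memory: concatenated forward/backward outputs.
memory = tf.concat(enc_outputs, axis=-1)
attention = tf.contrib.seq2seq.BahdanauAttention(NUM_UNITS, memory)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.contrib.rnn.GRUCell(NUM_UNITS), attention,
    attention_layer_size=NUM_UNITS)

# Key point: initialize the decoder from the encoder's final state
# (forward direction here) instead of from zeros.
initial_state = attn_cell.zero_state(BATCH_SIZE, tf.float32).clone(
    cell_state=enc_state[0])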

1. The initial state passed into the attention decoder should be the state produced by a bidirectional RNN encoder, not an all-zero vector; an all-zero state tends to make the first decoding step wrong

enc_outputs, encoder_state = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell,
                                                             cell_bw=cell,
                                                             inputs=cnn_out,
                                                             dtype=tf.float32)
...

initial_state = attn_cell.zero_state(BATCH_SIZE, tf.float32).clone(cell_state=encoder_state)

decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=attn_cell, helper=helper,
    initial_state=initial_state,
    output_layer=output_layer)

Is this the right way to change it? I'm getting an error here and it won't run:

Traceback (most recent call last):
  File "/data/attention-ocr-toy-example/attention_model.py", line 195, in <module>
    main()
  File "/data/attention-ocr-toy-example/attention_model.py", line 190, in main
    loss, train_one_step, train_decode_result, pred_decode_result = build_compute_graph()
  File "/data/attention-ocr-toy-example/attention_model.py", line 115, in build_compute_graph
    train_outputs = decode(train_helper, train_output_embed,enc_state, 'decode')
  File "/data/attention-ocr-toy-example/attention_model.py", line 99, in decode
    impute_finished=True, maximum_iterations=MAXIMUM__DECODE_ITERATIONS)
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 286, in dynamic_decode
    swap_memory=swap_memory)
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2816, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2640, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2590, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 234, in body
    decoder_finished) = decoder.step(time, inputs, state)
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py", line 138, in step
    cell_outputs, cell_state = self._cell(inputs, state)
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 183, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 575, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py", line 1295, in call
    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 183, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 575, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 320, in call
    kernel_initializer=self._kernel_initializer)
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1154, in __init__
    shapes = [a.get_shape() for a in args]
  File "/home/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1154, in <listcomp>
    shapes = [a.get_shape() for a in args]
AttributeError: 'tuple' object has no attribute 'get_shape'

@fendaq Use clone(cell_state=encoder_state[0]). bidirectional_dynamic_rnn returns a (forward, backward) tuple of final states, and the GRU cell inside the AttentionWrapper cannot take that tuple as its state, which is exactly what raises the 'tuple' object has no attribute 'get_shape' error. Pass one direction's state instead:

enc_outputs, enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell,
                                                         cell_bw=cell,
                                                         inputs=cnn_out,
                                                         dtype=tf.float32)

decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=attn_cell, helper=helper,
    initial_state=attn_cell.zero_state(dtype=tf.float32, batch_size=batch_size).clone(cell_state=enc_state[0]),
    output_layer=output_layer)

Also, during training use

train_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(output_embed, att_train_length,
                                                                   embeddings, sample_rate)

See issue #3 for details.
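
For reference, the same call with keyword arguments, so it is clear what each positional argument above means (the variable names are taken from the snippet; sample_rate is the probability of feeding back a sampled prediction instead of the ground truth):

train_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
    inputs=output_embed,               # embedded ground-truth targets, [batch, time, emb_dim]
    sequence_length=att_train_length,  # true target lengths, [batch]
    embedding=embeddings,              # embedding matrix used to embed sampled ids
    sampling_probability=sample_rate)  # chance of sampling from the decoder output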

The nmt tutorial does not use ScheduledEmbeddingTrainingHelper and still predicts correctly, so I think ScheduledEmbeddingTrainingHelper is not strictly necessary.
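
In other words, plain teacher forcing is enough for training, with greedy feedback only at inference time. A minimal sketch, reusing the names from the snippets above (GO_ID and EOS_ID are hypothetical start/end token ids):

# Training: always feed the ground-truth symbol (teacher forcing), as nmt does.
train_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=output_embed, sequence_length=att_train_length)

# Inference: feed back the greedy prediction from the previous step.
pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    embeddings,
    start_tokens=tf.fill([BATCH_SIZE], GO_ID),
    end_token=EOS_ID)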