FLming/CRNN.tf2

训练时出错

arixlin opened this issue · 1 comment

Traceback (most recent call last):
  File "train.py", line 71, in <module>
    validation_data=val_ds)
  File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1102, in fit
    tmp_logs = self.train_function(iterator)
  File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 796, in __call__
    result = self._call(*args, **kwds)
  File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 839, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 712, in _initialize
    *args, **kwds))
  File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2948, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3319, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3181, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 614, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:

    /mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:809 train_function  *
        return step_function(self, iterator)
    /mnt/d/github/CRNN.tf2/metrics.py:26 update_state  *
        values = tf.math.reduce_any(tf.math.not_equal(y_true, y_pred), axis=1)
    /mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201 wrapper  **
        return target(*args, **kwargs)
    /mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:1674 not_equal
        return gen_math_ops.not_equal(x, y, name=name)
    /mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py:6517 not_equal
        name=name)
    /mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:537 _apply_op_helper
        inferred_from[input_arg.type_attr]))

    TypeError: Input 'y' of 'NotEqual' Op has type float32 that does not match type int64 of argument 'x'.

麻烦帮忙看一下这个错误。我 debug 后看到 y_true 是 None，谢谢！

修改 metrics.py 如下：

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate sequence-level accuracy statistics for one batch.

        Greedy-decodes ``y_pred`` with CTC and counts a sample as correct only
        when every position of the decoded sequence equals ``y_true``.
        ``self.total`` grows by the batch size and ``self.count`` by the number
        of fully-correct samples.

        Args:
            y_true: SparseTensor of label ids (``tf.sparse.reset_shape`` below
                requires a SparseTensor), batch-major.
            y_pred: Dense logits; assumed (batch, time, num_classes), since
                they are transposed to time-major for the decoder below.
            sample_weight: Accepted for Keras API compatibility; ignored.
        """
        batch_size = tf.shape(y_true)[0]
        # Pad labels and predictions to a common width so they can be
        # compared element-wise after densifying.
        max_width = tf.maximum(tf.shape(y_true)[1], tf.shape(y_pred)[1])
        # Every sample uses the full time axis as its logit length.
        logit_length = tf.fill([tf.shape(y_pred)[0]], tf.shape(y_pred)[1])
        decoded, _ = tf.nn.ctc_greedy_decoder(
            inputs=tf.transpose(y_pred, perm=[1, 0, 2]),
            sequence_length=logit_length)
        y_true = tf.sparse.reset_shape(y_true, [batch_size, max_width])
        y_pred = tf.sparse.reset_shape(decoded[0], [batch_size, max_width])
        # -1 fills padding on both sides so trailing padding compares equal
        # (assumes real label ids are non-negative).
        y_true = tf.sparse.to_dense(y_true, default_value=-1)
        y_pred = tf.sparse.to_dense(y_pred, default_value=-1)
        # Compare label ids as integers. The decoder emits int64; bringing
        # y_true to the same dtype fixes the NotEqual dtype mismatch exactly,
        # whereas a float32 round-trip is lossy for ids above 2**24.
        y_true = tf.cast(y_true, tf.int64)
        y_pred = tf.cast(y_pred, tf.int64)
        # A sample is wrong if any position differs.
        mismatched = tf.math.reduce_any(tf.math.not_equal(y_true, y_pred), axis=1)
        num_wrong = tf.reduce_sum(tf.cast(mismatched, tf.float32))
        batch_size = tf.cast(batch_size, tf.float32)
        self.total.assign_add(batch_size)
        self.count.assign_add(batch_size - num_wrong)