训练时出错
arixlin opened this issue · 1 comment
arixlin commented
Traceback (most recent call last):
File "train.py", line 71, in <module>
validation_data=val_ds)
File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1102, in fit
tmp_logs = self.train_function(iterator)
File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 796, in __call__
result = self._call(*args, **kwds)
File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 839, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 712, in _initialize
*args, **kwds))
File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2948, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3319, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3181, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 614, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:
/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:809 train_function *
return step_function(self, iterator)
/mnt/d/github/CRNN.tf2/metrics.py:26 update_state *
values = tf.math.reduce_any(tf.math.not_equal(y_true, y_pred), axis=1)
/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201 wrapper **
return target(*args, **kwargs)
/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:1674 not_equal
return gen_math_ops.not_equal(x, y, name=name)
/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py:6517 not_equal
name=name)
/mnt/e/ubuntu/anaconda3/envs/tflite/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:537 _apply_op_helper
inferred_from[input_arg.type_attr]))
TypeError: Input 'y' of 'NotEqual' Op has type float32 that does not match type int64 of argument 'x'.
麻烦帮忙看一下这个错误。我 debug 后发现 y_true 是 None，谢谢！
arixlin commented
解决方法：修改 metrics.py 中的 update_state。tf.math.not_equal 要求两个输入的 dtype 一致，所以在比较前要先把 y_true 和 y_pred 转换为相同的数据类型：
def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulate sequence-level accuracy for one batch of CTC outputs.

    ``y_pred`` is greedy-decoded with ``tf.nn.ctc_greedy_decoder``. A sample
    counts as correct only when its decoded label sequence matches the
    ground-truth sequence at every position.

    Args:
        y_true: Ground-truth labels as a sparse tensor; ``tf.shape(y_true)``
            is read as (batch, label_length). Presumably a ``tf.SparseTensor``
            of label ids -- confirm against the dataset pipeline.
        y_pred: Logits in batch-major layout (batch, time, num_classes). They
            are transposed to time-major before decoding.
        sample_weight: Accepted for Keras API compatibility; it is ignored.
    """
    batch_size = tf.shape(y_true)[0]
    max_width = tf.maximum(tf.shape(y_true)[1], tf.shape(y_pred)[1])
    # Every sample is decoded over the full time dimension of the logits.
    logit_length = tf.fill([tf.shape(y_pred)[0]], tf.shape(y_pred)[1])
    decoded, _ = tf.nn.ctc_greedy_decoder(
        inputs=tf.transpose(y_pred, perm=[1, 0, 2]),  # decoder wants time-major
        sequence_length=logit_length)
    # Give both sparse tensors the same dense shape so they can be compared
    # element-wise. Positions past a sequence's end become -1.
    y_true = tf.sparse.reset_shape(y_true, [batch_size, max_width])
    y_pred = tf.sparse.reset_shape(decoded[0], [batch_size, max_width])
    y_true = tf.sparse.to_dense(y_true, default_value=-1)
    y_pred = tf.sparse.to_dense(y_pred, default_value=-1)
    # tf.math.not_equal requires both operands to share a dtype, and these two
    # may differ (an int64/float32 mismatch raised a TypeError here). They hold
    # integer label ids, so compare them as int64. Casting to float would be
    # lossy for ids above 2**24.
    y_true = tf.cast(y_true, tf.int64)
    y_pred = tf.cast(y_pred, tf.int64)
    # A sample is wrong if any position differs.
    mismatched = tf.math.reduce_any(tf.math.not_equal(y_true, y_pred), axis=1)
    num_wrong = tf.reduce_sum(tf.cast(mismatched, tf.float32))
    batch_size = tf.cast(batch_size, tf.float32)
    self.total.assign_add(batch_size)
    self.count.assign_add(batch_size - num_wrong)