ValueError: Dimension must be 5 but is 4 for '{{node attention_11/transpose_4}} = Transpose[T=DT_FLOAT, Tperm=DT_INT32](attention_11/truediv, attention_11/transpose_4/perm)' with input shapes: [?,8,?,8,?], [4].
Yumeka999 opened this issue · 1 comments
Yumeka999 commented
tf=2.6.0
keras=2.6.0
The error is raised when executing xy = Attention(8, 16)([y, x, x, x_mask]).
Stepping into the function shows that it fails at a = K.permute_dimensions(a, (0, 3, 2, 1)):
ValueError: in user code:
<ipython-input-106-d9f1e622f23a>:50 call *
a = K.permute_dimensions(a, (0, 3, 2, 1))
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:206 wrapper **
return target(*args, **kwargs)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/keras/backend.py:3133 permute_dimensions
return tf.compat.v1.transpose(x, perm=pattern)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py:2309 transpose
return transpose_fn(a, perm, name=name)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py:11659 transpose
"Transpose", x=x, perm=perm, name=name)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:750 _apply_op_helper
attrs=attr_protos, op_def=op_def)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:3569 _create_op_internal
op_def=op_def)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:2042 __init__
control_input_ops, op_def)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:1883 _create_c_op
raise ValueError(str(e))
ValueError: Dimension must be 5 but is 4 for '{{node attention_11/transpose_4}} = Transpose[T=DT_FLOAT, Tperm=DT_INT32](attention_11/truediv, attention_11/transpose_4/perm)' with input shapes: [?,8,?,8,?], [4].
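The shape mismatch in the message above can be reproduced in a small sketch. This is an illustration only, not the original layer code: NumPy stands in for the Keras backend, the sizes are hypothetical, and the tensor layout (batch, heads, length, key_size) is an assumption inferred from the reported input shape [?,8,?,8,?]. The einsum subscripts mimic a batch dot that shares only axis 0 as the batch axis, which yields a rank-5 result that a 4-element permutation cannot transpose.

```python
import numpy as np

# Hypothetical sizes; heads=8 and key_size=16 mirror Attention(8, 16).
batch, heads, q_len, k_len, key_size = 2, 8, 5, 7, 16

# Assumed layout: (batch, heads, length, key_size).
qw = np.zeros((batch, heads, q_len, key_size))
kw = np.zeros((batch, heads, k_len, key_size))

# Contract the last axis while sharing only axis 0 as the batch axis.
# The heads axis of each input survives separately, so the result is rank 5.
a = np.einsum('bhjd,bgkd->bhjgk', qw, kw)
print(a.shape)  # (2, 8, 5, 8, 7), the same pattern as [?,8,?,8,?]

# A 4-element permutation cannot be applied to a rank-5 array.
try:
    np.transpose(a, (0, 3, 2, 1))
except ValueError as err:
    print("transpose failed:", err)
```

The rank-5 shape is what makes the 4-element permutation in K.permute_dimensions invalid.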
Kalafinaian commented
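There is a fix. In the source code,
change a = K.batch_dot(qw, kw, [3, 3]) / self.key_size**0.5
to a = tf.einsum('bjhd,bkhd->bhjk', qw, kw) / self.key_size**0.5
and change o = K.batch_dot(a, vw, [3, 2])
to o = tf.einsum('bhjk,bkhd->bjhd', a, vw)
The main cause is that a change between Keras versions altered the broadcasting behavior.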
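A minimal sketch of the two einsum contractions suggested above, useful for checking the shapes. Several things here are assumptions: NumPy's np.einsum stands in for tf.einsum (both use the same subscript notation), the helper names and sizes are hypothetical, and the inputs are laid out as (batch, length, heads, dim), which is what the subscripts 'bjhd' and 'bkhd' imply.

```python
import numpy as np

def attention_scores(qw, kw, key_size):
    # qw: (batch, q_len, heads, key_size), kw: (batch, k_len, heads, key_size)
    # Result: (batch, heads, q_len, k_len), scaled by sqrt(key_size).
    return np.einsum('bjhd,bkhd->bhjk', qw, kw) / key_size ** 0.5

def attention_output(a, vw):
    # a: (batch, heads, q_len, k_len), vw: (batch, k_len, heads, head_size)
    # Result: (batch, q_len, heads, head_size).
    return np.einsum('bhjk,bkhd->bjhd', a, vw)

# Hypothetical sizes; heads=8 and key_size=16 mirror Attention(8, 16).
batch, q_len, k_len, heads, key_size = 2, 5, 7, 8, 16
rng = np.random.default_rng(0)
qw = rng.standard_normal((batch, q_len, heads, key_size))
kw = rng.standard_normal((batch, k_len, heads, key_size))
vw = rng.standard_normal((batch, k_len, heads, key_size))

a = attention_scores(qw, kw, key_size)
o = attention_output(a, vw)
print(a.shape)  # (2, 8, 5, 7)
print(o.shape)  # (2, 5, 8, 16)
```

If the tensors have already been permuted to (batch, heads, length, dim) before this point, the matching subscripts would instead be 'bhjd,bhkd->bhjk' and 'bhjk,bhkd->bhjd', and any later permute_dimensions calls would need to agree with whichever layout is used.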