bojone/seq2seq

ValueError: Dimension must be 5 but is 4 for '{{node attention_11/transpose_4}} = Transpose[T=DT_FLOAT, Tperm=DT_INT32](attention_11/truediv, attention_11/transpose_4/perm)' with input shapes: [?,8,?,8,?], [4].

Yumeka999 opened this issue · 1 comment

tf=2.6.0
keras=2.6.0

The error is raised when executing xy = Attention(8, 16)([y, x, x, x_mask]).
Stepping into the layer's call function, the failure happens at a = K.permute_dimensions(a, (0, 3, 2, 1)):

ValueError: in user code:

<ipython-input-106-d9f1e622f23a>:50 call  *
    a = K.permute_dimensions(a, (0, 3, 2, 1))
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:206 wrapper  **
    return target(*args, **kwargs)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/keras/backend.py:3133 permute_dimensions
    return tf.compat.v1.transpose(x, perm=pattern)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
    return target(*args, **kwargs)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py:2309 transpose
    return transpose_fn(a, perm, name=name)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py:11659 transpose
    "Transpose", x=x, perm=perm, name=name)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:750 _apply_op_helper
    attrs=attr_protos, op_def=op_def)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:3569 _create_op_internal
    op_def=op_def)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:2042 __init__
    control_input_ops, op_def)
/root/anaconda3/envs/py364/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:1883 _create_c_op
    raise ValueError(str(e))

ValueError: Dimension must be 5 but is 4 for '{{node attention_11/transpose_4}} = Transpose[T=DT_FLOAT, Tperm=DT_INT32](attention_11/truediv, attention_11/transpose_4/perm)' with input shapes: [?,8,?,8,?], [4].

There is a workaround. In the source code,

change a = K.batch_dot(qw, kw, [3, 3]) / self.key_size**0.5
to a = tf.einsum('bjhd,bkhd->bhjk', qw, kw) / self.key_size**0.5

and change o = K.batch_dot(a, vw, [3, 2])
to o = tf.einsum('bhjk,bkhd->bjhd', a, vw)

The root cause is that the broadcasting behaviour of K.batch_dot changed between Keras versions.
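
For reference, here is a minimal standalone sketch of the two einsum contractions and the shapes they produce. The tensor names qw/kw/vw and the layout [batch, seq_len, heads, dim] are assumptions based on a typical multi-head attention layer, not copied from the repository source; it only illustrates what the two replacement lines compute.

```python
import tensorflow as tf

# Assumed multi-head layout (hypothetical sizes, heads=8 and key_size=16 as in Attention(8, 16)):
# qw: [batch, q_len, heads, key_size]
# kw: [batch, k_len, heads, key_size]
# vw: [batch, k_len, heads, head_size]
batch, q_len, k_len, heads, key_size, head_size = 2, 5, 7, 8, 16, 16
qw = tf.random.normal([batch, q_len, heads, key_size])
kw = tf.random.normal([batch, k_len, heads, key_size])
vw = tf.random.normal([batch, k_len, heads, head_size])

# Attention scores: contract the key_size axis, keeping heads as an explicit axis.
a = tf.einsum('bjhd,bkhd->bhjk', qw, kw) / key_size**0.5
print(a.shape)  # (2, 8, 5, 7) -> [batch, heads, q_len, k_len]

a = tf.nn.softmax(a, axis=-1)

# Weighted sum over values: contract the k_len axis back to [batch, q_len, heads, head_size].
o = tf.einsum('bhjk,bkhd->bjhd', a, vw)
print(o.shape)  # (2, 5, 8, 16)
```

Because einsum spells out every axis explicitly, it does not depend on how K.batch_dot broadcasts extra dimensions, which is what changed across Keras versions.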