edvardHua/PoseEstimationForMobile

Help: error when training the model on the COCO dataset

cncowboy opened this issue · 0 comments

Hi edvardHua,

I am retraining the model on the COCO dataset and made the following changes:
1. Changes to `mv2_cpm.cfg`: `n_kpoints: 17`
2. Changes to `src/dataset.py`:
   - `TRAIN_JSON = "person_keypoints_train2017.json"`
   - `VALID_JSON = "person_keypoints_val2017.json"`
3. Changes to `src/dataset_augment.py`:
```python
class CocoPart(Enum):
    Nose = 1
    LEye = 2
    REye = 3
    LEar = 4
    REar = 5
    LShoulder = 6
    RShoulder = 7
    LElbow = 8
    RElbow = 9
    LWrist = 10
    RWrist = 11
    LHip = 12
    RHip = 13
    LKnee = 14
    RKnee = 15
    LAnkle = 16
    RAnkle = 17


def set_network_input_wh(w, h):
    global _network_w, _network_h
    # ...


def pose_flip(meta):
    # ...
    img = cv2.flip(img, 1)

    # flip meta
    flip_list = [
        CocoPart.Nose,
        CocoPart.LEye,
        CocoPart.REye,
        CocoPart.LEar,
        CocoPart.REar,
        CocoPart.LShoulder,
        CocoPart.RShoulder,
        CocoPart.LElbow,
        CocoPart.RElbow,
        CocoPart.LWrist,
        CocoPart.RWrist,
        CocoPart.LHip,
        CocoPart.RHip,
        CocoPart.LKnee,
        CocoPart.RKnee,
        CocoPart.LAnkle,
        CocoPart.RAnkle
    ]
```
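For comparison, here is a minimal standalone sketch of horizontal-flip augmentation for COCO's 17 keypoints. It is not the repository's code: `COCO_KEYPOINTS`, `build_flip_perm` and `flip_keypoints` are hypothetical names. It assumes the raw COCO annotation layout, where each person's `keypoints` field is a flat `[x1, y1, v1, ..., x17, y17, v17]` list indexed from 0 (nose = 0) rather than from 1 as in the `CocoPart` enum above. It also shows the usual left/right swap: after the image is mirrored, each left joint takes the mirrored coordinates of its right counterpart and vice versa, whereas the `flip_list` above keeps the joints in their original order.

```python
from typing import List, Tuple

# COCO person keypoint order (0-based), as given in the "keypoints" field of
# the person category in person_keypoints_*.json.
COCO_KEYPOINTS = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle",
]


def build_flip_perm(names: List[str]) -> List[int]:
    """For each output slot i, return the source index of its left/right mirror."""
    index = {n: i for i, n in enumerate(names)}
    perm = []
    for name in names:
        if name.startswith("left_"):
            mirror = "right_" + name[len("left_"):]
        elif name.startswith("right_"):
            mirror = "left_" + name[len("right_"):]
        else:
            mirror = name  # the nose has no left/right counterpart
        perm.append(index[mirror])
    return perm


FLIP_PERM = build_flip_perm(COCO_KEYPOINTS)


def flip_keypoints(flat: List[float], width: int) -> List[Tuple[float, float, int]]:
    """Mirror a flat COCO keypoint list horizontally.

    x is mapped to width - 1 - x (pixel-index convention, an assumption here),
    and left/right joints are swapped so the labels stay anatomically correct.
    Unlabeled joints (v == 0) are kept at (0, 0, 0).
    """
    points = [(flat[i], flat[i + 1], int(flat[i + 2])) for i in range(0, len(flat), 3)]
    mirrored = [(0.0, 0.0, 0) if v == 0 else (width - 1 - x, y, v) for x, y, v in points]
    return [mirrored[src] for src in FLIP_PERM]
```

Deriving the permutation from the names, instead of hard-coding it, keeps the left/right pairing tied to the keypoint list it belongs to.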

4. Changes to `src/network_mv2_cpm.py`: `N_KPOINTS = 17`

Finally, I started training with:

`python3 src/train.py experiments/mv2_cpm.cfg`

and got the following error:
```
preparing annotation from: /data5/mscoco/annotations/person_keypoints_train2017.json
loading annotations into memory...
Done (t=10.02s)
creating index...
index created!
preparing annotation from: /data5/mscoco/annotations/person_keypoints_val2017.json
loading annotations into memory...
Done (t=0.32s)
creating index...
index created!
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/common_shapes.py", line 686, in _call_cpp_shape_fn_impl
    input_tensors_as_shapes, status)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimension 1 in both shapes must be equal, but are 46 and 48 for 'GPU_0/MobilenetV2/concat' (op: 'ConcatV2') with input shapes: [?,46,46,12], [?,46,46,18], [?,46,46,24], [?,46,46,48], [?,48,48,72], [] and with computed input tensors: input[5] = <3>.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "src/train.py", line 250, in <module>
    tf.app.run()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "src/train.py", line 149, in main
    loss, last_heat_loss, pred_heat = get_loss_and_output(params['model'], params['batchsize'], input_image, input_heat, reuse_variable)
  File "src/train.py", line 43, in get_loss_and_output
    _, pred_heatmaps_all = get_network(model, input_image, True)
  File "/workspace/src/networks.py", line 12, in get_network
    net, loss = network_mv2_cpm.build_network(input, trainable)
  File "/workspace/src/network_mv2_cpm.py", line 91, in build_network
    , axis=3)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/array_ops.py", line 1099, in concat
    return gen_array_ops._concat_v2(values=values, axis=axis, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 706, in _concat_v2
    "ConcatV2", values=values, axis=axis, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2958, in create_op
    set_shapes_for_outputs(ret)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2209, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2159, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/common_shapes.py", line 627, in call_cpp_shape_fn
    require_shape_fn)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/common_shapes.py", line 691, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Dimension 1 in both shapes must be equal, but are 46 and 48 for 'GPU_0/MobilenetV2/concat' (op: 'ConcatV2') with input shapes: [?,46,46,12], [?,46,46,18], [?,46,46,24], [?,46,46,48], [?,48,48,72], [] and with computed input tensors: input[5] = <3>.
```
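The 46 vs 48 mismatch can be reproduced with plain arithmetic. This is a sketch under assumptions inferred only from the five input shapes in the error, not from the network source: three branches are already at 46×46, one comes from the next stride-2 level and is upsampled ×2, and one comes from the level below that and is upsampled ×4, with downsampling using SAME padding (output size = ceil(n / 2)). The helper names `same_pad_out`, `concat_branch_sizes` and `shapes_compatible` are hypothetical.

```python
import math
from typing import List


def same_pad_out(size: int, stride: int) -> int:
    """Output size of a strided conv/pool with SAME padding: ceil(size / stride)."""
    return math.ceil(size / stride)


def concat_branch_sizes(base: int) -> List[int]:
    """Heights of the five tensors passed to the concat (hypothetical layout inferred
    from the error): three branches at `base`, one from the next stride-2 level
    upsampled x2, and one from the level below that upsampled x4."""
    half = same_pad_out(base, 2)     # 46 -> 23
    quarter = same_pad_out(half, 2)  # 23 -> 12, because SAME padding rounds up
    return [base, base, base, half * 2, quarter * 4]


def shapes_compatible(base: int) -> bool:
    """Concatenating along axis 3 requires all other dimensions to match."""
    return len(set(concat_branch_sizes(base))) == 1


print(concat_branch_sizes(46))  # [46, 46, 46, 46, 48]
print(shapes_compatible(46))    # False
print(shapes_compatible(48))    # True
```

Under these assumptions the five heights only agree when the size at the concat level is a multiple of 4 (46 is not), so the network's input resolution would need to produce such a size at that level.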