microsoft/Recursive-Cascaded-Networks

Problem with training on my own dataset

Shubham0209 opened this issue · 2 comments

@zsyzzsoft
I am trying to run your code on my own liver dataset, which consists of pairs of moving and fixed images. However, I have set the resolution to 288×288×96 rather than 128×128×128 as in your code, and I have made the necessary changes accordingly. But I am getting the following error:

Traceback (most recent call last):
  File "C:\Users\shubh\anaconda3\envs\myenv\lib\site-packages\tensorflow\python\framework\ops.py", line 1853, in _create_c_op
    c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimension 1 in both shapes must be equal, but are 9 and 10. Shapes are [?,9,9,3] and [?,10,10,4]. for '{{node gaffdfrm/deform_stem_0/concat5}} = ConcatV2[N=3, T=DT_FLOAT, Tidx=DT_INT32](gaffdfrm/deform_stem_0/conv5_1_leakilyrectified, gaffdfrm/deform_stem_0/deconv5_rectified, gaffdfrm/deform_stem_0/upsamp6to5/conv3d_transpose, gaffdfrm/deform_stem_0/concat5/axis)' with input shapes: [?,9,9,3,256], [?,10,10,4,256], [?,10,10,4,3], [] and with computed input tensors: input[3] = <4>.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "train.py", line 244, in <module>
    main()
  File "train.py", line 78, in main
    framework = Framework(devices=gpus, image_size=image_size, segmentation_class_value=cfg.get('segmentation_class_value', None), fast_reconstruction = args.fast_reconstruction)
  File "C:\Users\shubh\Desktop\mir_test2\network\framework.py", line 93, in __init__
    self.predictions = self.network(*net_pls)
  File "C:\Users\shubh\Desktop\mir_test2\network\utils.py", line 111, in __call__
    return self.build(*args, **kwargs)
  File "C:\Users\shubh\Desktop\mir_test2\network\recursive_cascaded_networks.py", line 87, in build
    stem_result = stem(img1, stem_results[-1]['warped'])
  File "C:\Users\shubh\Desktop\mir_test2\network\utils.py", line 111, in __call__
    return self.build(*args, **kwargs)
  File "C:\Users\shubh\Desktop\mir_test2\network\base_networks.py", line 92, in build
    concat5 = tf.concat([conv5_1, deconv5, upsamp6to5], 4, 'concat5')
  File "C:\Users\shubh\anaconda3\envs\myenv\lib\site-packages\tensorflow\python\util\dispatch.py", line 201, in wrapper
    return target(*args, **kwargs)
  File "C:\Users\shubh\anaconda3\envs\myenv\lib\site-packages\tensorflow\python\ops\array_ops.py", line 1677, in concat
    return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
  File "C:\Users\shubh\anaconda3\envs\myenv\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 1206, in concat_v2
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
  File "C:\Users\shubh\anaconda3\envs\myenv\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 748, in _apply_op_helper
    op = g._create_op_internal(op_type_name, inputs, dtypes=None,
  File "C:\Users\shubh\anaconda3\envs\myenv\lib\site-packages\tensorflow\python\framework\ops.py", line 3528, in _create_op_internal
    ret = Operation(
  File "C:\Users\shubh\anaconda3\envs\myenv\lib\site-packages\tensorflow\python\framework\ops.py", line 2015, in __init__
    self._c_op = _create_c_op(self._graph, node_def, inputs,
  File "C:\Users\shubh\anaconda3\envs\myenv\lib\site-packages\tensorflow\python\framework\ops.py", line 1856, in _create_c_op
    raise ValueError(str(e))
ValueError: Dimension 1 in both shapes must be equal, but are 9 and 10. Shapes are [?,9,9,3] and [?,10,10,4]. for '{{node gaffdfrm/deform_stem_0/concat5}} = ConcatV2[N=3, T=DT_FLOAT, Tidx=DT_INT32](gaffdfrm/deform_stem_0/conv5_1_leakilyrectified, gaffdfrm/deform_stem_0/deconv5_rectified, gaffdfrm/deform_stem_0/upsamp6to5/conv3d_transpose, gaffdfrm/deform_stem_0/concat5/axis)' with input shapes: [?,9,9,3,256], [?,10,10,4,256], [?,10,10,4,3], [] and with computed input tensors: input[3] = <4>.

Please help me with this.

The network assumes that the input resolution along each axis is a multiple of 64. In your case, as a workaround, you may remove the 6th-level layers (conv6, conv6_1, pred6, and upsamp6to5) from the network.
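The mismatched shapes in the traceback come from simple size arithmetic, shown in the minimal Python sketch below. It assumes that each of the six encoder levels (conv1 to conv6) halves the spatial size with a stride-2, SAME-padded convolution, so odd sizes round up. It also assumes that each decoder step doubles the size again before it is concatenated with the skip connection from the same level. The helper names `encoder_sizes` and `find_mismatches` are hypothetical and not part of the repository. The sketch only checks the decoder concatenations.

```python
import math


def encoder_sizes(size, levels):
    """Spatial size after each stride-2 level (assumes SAME padding, i.e. ceil division)."""
    sizes = [size]
    for _ in range(levels):
        sizes.append(math.ceil(sizes[-1] / 2))
    return sizes


def find_mismatches(image_size, levels=6):
    """Return (level, skip_shape, upsampled_shape) for every decoder concat
    where the skip connection and the 2x-upsampled coarser map disagree."""
    per_axis = [encoder_sizes(s, levels) for s in image_size]
    mismatches = []
    for lvl in range(1, levels):
        skip = tuple(axis[lvl] for axis in per_axis)                # e.g. conv5_1
        upsampled = tuple(2 * axis[lvl + 1] for axis in per_axis)  # e.g. deconv5
        if skip != upsampled:
            mismatches.append((lvl, skip, upsampled))
    return mismatches


if __name__ == "__main__":
    print(find_mismatches((288, 288, 96)))            # six levels, as in the issue
    print(find_mismatches((288, 288, 96), levels=5))  # with the 6th level removed
```

With six levels, 288 shrinks to 9 at level 5 but to 5 at level 6. Upsampling 5 gives 10, so the sketch reports the same `[9,9,3]` vs `[10,10,4]` clash seen at `concat5`. With the 6th level removed, the constraint becomes a multiple of 32, which 288 and 96 both satisfy, so no mismatch is reported.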

Thank you so much for your help!