Error running tutorial script (1-generation_visualisation.py) in newer versions of tensorflow
Hi @BBillot
While the officially recommended TensorFlow version (according to requirements.txt) is 2.2, I am trying to run SynthSeg with a more recent version of TensorFlow (a constraint imposed by nobrainer), and it throws the following error. Admittedly, this happens with almost every version from 2.3 to 2.15. Versions 2.16 and above throw a different error, which is a problem for another day.
Would you be kind enough to spare some time and help me fix this bug? I am happy to do any groundwork (environment setup, colab if you prefer) to make it easy for you. I'll be looking forward to your thoughts.
Note: The following error was thrown when running SynthSeg/scripts/tutorials/1-generation_visualisation.py with TF 2.15:
Traceback (most recent call last):
File "/net/vast-storage/scratch/vast/gablab/hgazula/SynthSeg/scripts/tutorials/1-generation_visualisation.py", line 28, in <module>
im, lab = brain_generator.generate_brain()
File "/om2/user/hgazula/SynthSeg/SynthSeg/brain_generator.py", line 324, in generate_brain
(image, labels) = next(self.brain_generator)
File "/om2/user/hgazula/SynthSeg/SynthSeg/brain_generator.py", line 319, in _build_brain_generator
[image, labels] = self.labels_to_image_model.predict(model_inputs)
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/tmp/__autograph_generated_file9w0hokd_.py", line 15, in tf__predict_function
retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
File "/tmp/__autograph_generated_filec3y9yjkg.py", line 38, in tf__call
ag__.if_stmt(ag__.not_(ag__.ld(self).add_batchsize), if_body, else_body, get_state, set_state, ('mask', 'self.min_res_tens', 'shape'), 3)
File "/tmp/__autograph_generated_filec3y9yjkg.py", line 28, in else_body
ag__.ld(self).min_res_tens = ag__.converted_call(ag__.ld(tf).tile, (ag__.converted_call(ag__.ld(tf).expand_dims, (ag__.ld(self).min_res_tens, 0), None, fscope), ag__.ld(tile_shape)), None, fscope)
ValueError: in user code:
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/training.py", line 2440, in predict_function *
return step_function(self, iterator)
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/training.py", line 2425, in step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/training.py", line 2413, in run_step **
outputs = model.predict_step(data)
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/training.py", line 2381, in predict_step
return self(x, training=False)
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/tmp/__autograph_generated_filec3y9yjkg.py", line 38, in tf__call
ag__.if_stmt(ag__.not_(ag__.ld(self).add_batchsize), if_body, else_body, get_state, set_state, ('mask', 'self.min_res_tens', 'shape'), 3)
File "/tmp/__autograph_generated_filec3y9yjkg.py", line 28, in else_body
ag__.ld(self).min_res_tens = ag__.converted_call(ag__.ld(tf).tile, (ag__.converted_call(ag__.ld(tf).expand_dims, (ag__.ld(self).min_res_tens, 0), None, fscope), ag__.ld(tile_shape)), None, fscope)
ValueError: Exception encountered when calling layer 'sample_resolution' (type SampleResolution).
in user code:
File "/om2/user/hgazula/SynthSeg/ext/lab2im/layers.py", line 608, in call *
self.min_res_tens = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape)
ValueError: Shape must be rank 3 but is rank 2 for '{{node model/sample_resolution/Tile}} = Tile[T=DT_FLOAT, Tmultiples=DT_INT32](model/sample_resolution/ExpandDims, model/sample_resolution/concat)' with input shapes: [1,?,3], [2].
Call arguments received by layer 'sample_resolution' (type SampleResolution):
• inputs=tf.Tensor(shape=(None, 54, 1), dtype=float32)
• kwargs={'training': 'False'}
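One possible reading of the shape pair `[1,?,3], [2]` in the error above: the input to `tf.tile` already carries a batch dimension before `expand_dims` is applied, which would happen if `call` runs (is traced) more than once while overwriting `self.min_res_tens` each time. The sketch below only tracks shapes in plain Python to illustrate that reasoning. The helper functions and the `StatefulCall` class are hypothetical stand-ins, not SynthSeg or TensorFlow code, and `None` stands for the unknown batch size.

```python
# Pure-Python shape bookkeeping (no TensorFlow needed).
# Assumption: call() is traced twice and overwrites its own attribute.

def expand_dims(shape, axis=0):
    """Return the shape with a new axis of size 1 inserted."""
    shape = list(shape)
    shape.insert(axis, 1)
    return tuple(shape)

def tile(shape, multiples):
    """Return the tiled shape; the multiples need one entry per input axis."""
    if len(shape) != len(multiples):
        raise ValueError(
            f"tile: input rank {len(shape)} != len(multiples) {len(multiples)}"
        )
    return tuple(
        None if s is None or m is None else s * m
        for s, m in zip(shape, multiples)
    )

class StatefulCall:
    """Mimics a layer whose call() reassigns self.min_res_shape."""

    def __init__(self):
        self.min_res_shape = (3,)  # one resolution value per spatial axis

    def call(self):
        tile_shape = (None, 1)  # [batch_size, 1]
        self.min_res_shape = tile(expand_dims(self.min_res_shape), tile_shape)
        return self.min_res_shape

layer = StatefulCall()
first = layer.call()  # (3,) -> (1, 3) -> (None, 3)
try:
    layer.call()      # (None, 3) -> (1, None, 3), but tile_shape has 2 entries
    second_error = None
except ValueError as err:
    second_error = str(err)
```

Under this assumption, the first call succeeds and leaves a rank-2 shape behind, and the second call fails because the expanded input is rank 3 while the multiples vector still has two entries.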
I managed to get past this step by modifying build(...) and call(...) in SampleResolution. In build(...), I replaced
    if input_shape:
        self.add_batchsize = True
    self.min_res_tens = tf.convert_to_tensor(self.min_res, dtype='float32')
with
    self.min_res_tens = tf.convert_to_tensor(self.min_res, dtype='float32')
    if input_shape:
        self.add_batchsize = True
        self.min_res_tens = tf.expand_dims(self.min_res_tens, 0)
In call(...), I then replaced
    self.min_res_tens = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape)
with
    self.min_res_tens = tf.tile(self.min_res_tens, tile_shape)
However, it now throws a new error along the lines of:
Traceback (most recent call last):
File "/net/vast-storage/scratch/vast/gablab/hgazula/SynthSeg/scripts/misc/synth_202408.py", line 291, in <module>
[image, labels] = lab_to_im_model.predict(model_inputs)
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 53, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
TypeError: <tf.Tensor 'sample_resolution/Tile:0' shape=(None, 3) dtype=float32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.
<tf.Tensor 'sample_resolution/Tile:0' shape=(None, 3) dtype=float32> was defined here:
File "/net/vast-storage/scratch/vast/gablab/hgazula/SynthSeg/scripts/misc/synth_202408.py", line 247, in <module>
File "/net/vast-storage/scratch/vast/gablab/hgazula/SynthSeg/scripts/misc/synth_202408.py", line 137, in labels_to_image_model
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/base_layer.py", line 1063, in __call__
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/base_layer.py", line 2593, in _functional_construction_call
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/base_layer.py", line 2439, in _keras_tensor_symbolic_call
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/base_layer.py", line 2498, in _infer_output_signature
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler
File "/om2/user/hgazula/SynthSeg/ext/lab2im/layers.py", line 715, in call
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1217, in if_stmt
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1270, in _py_if_stmt
File "/om2/user/hgazula/SynthSeg/ext/lab2im/layers.py", line 733, in call
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/ops/gen_array_ops.py", line 12045, in tile
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 670, in _create_op_internal
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 2652, in _create_op_internal
File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 1160, in from_node_def
The tensor <tf.Tensor 'sample_resolution/Tile:0' shape=(None, 3) dtype=float32> cannot be accessed from here, because it was defined in FuncGraph(name=sample_resolution_scratch_graph, id=47607497100160), which is out of scope.
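The message above says that `sample_resolution/Tile:0` was created in a scratch graph and then accessed from a different graph. That pattern can arise when a tensor produced inside `call` is stored on `self` and reused during a later trace. A minimal sketch of a layer that avoids this by keeping the tiled tensor in a local variable is shown below. `TileMinRes` is a hypothetical illustration, not the SynthSeg `SampleResolution` layer, and it is not necessarily the fix referenced later in this thread.

```python
import tensorflow as tf

class TileMinRes(tf.keras.layers.Layer):
    """Hypothetical layer: repeats a resolution vector once per batch item."""

    def __init__(self, min_res, **kwargs):
        super().__init__(**kwargs)
        # Plain Python list; no graph tensor is stored on the layer.
        self.min_res = [float(r) for r in min_res]

    def call(self, inputs):
        # Rebuild the tensor on every call so nothing leaks between graphs.
        min_res = tf.convert_to_tensor(self.min_res, dtype=tf.float32)
        batch_size = tf.shape(inputs)[0]
        tile_shape = tf.stack([batch_size, 1])
        # Local variable only: self is never reassigned inside call().
        return tf.tile(tf.expand_dims(min_res, 0), tile_shape)

layer = TileMinRes([1.0, 1.0, 1.0])

@tf.function
def run(x):
    # Each new input shape triggers a fresh trace of layer.call().
    return layer(x)

out_a = run(tf.zeros((4, 54, 1)))  # shape (4, 3)
out_b = run(tf.zeros((2, 54, 1)))  # shape (2, 3)
```

Because `call` does not modify `self`, tracing it repeatedly with different batch sizes yields the same rank every time.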
I was wondering if you happen to have any ideas. I will also reach out to other Synth users to see if anyone's using the latest version of TF for their work.
Fixed. Thank you. See neuronets/nobrainer#338 (comment) for the resolution. FYI, it is backward compatible as well.