materialsvirtuallab/megnet

predict_structure triggering retracing

Ash-Pera opened this issue · 2 comments

When trying to run predictions, I get a lot of warnings about function retracing. Is this a warning that can safely be ignored, or should I be doing something different to prevent it?

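import functools

def predict(pred_model, structure):
    return pred_model.predict_structure(structure)

# `model` is the trained MEGNet model; rename the result so it doesn't collide with the 'structure' column on join
predict_verify = functools.partial(predict, model)
test_results = test_data.join(test_data['structure'].apply(predict_verify).rename('prediction'))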

WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f273dab04c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
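The warning says every recent call re-traced the prediction function. One plausible cause, not confirmed in this thread, is that each crystal graph has a different number of atoms and bonds, so the input arrays change shape from one predict_structure call to the next. The toy below is a hypothetical sketch of a trace cache keyed on input shapes; it is not TensorFlow's implementation, and ToyTracedFunction, its relax_shapes flag and the feature sizes are made up for illustration. It shows why exact-shape keys cause one trace per new shape, and why relaxing shapes (roughly what experimental_relax_shapes=True is for) reduces the number of traces.

```python
from typing import Callable, Dict, Optional, Tuple

# A shape is a tuple of dimensions; None marks an unknown ("relaxed") dimension.
Shape = Tuple[Optional[int], ...]


class ToyTracedFunction:
    """Hypothetical toy: caches one 'trace' per input-shape signature."""

    def __init__(self, fn: Callable[..., Optional[int]], relax_shapes: bool = False) -> None:
        self.fn = fn
        self.relax_shapes = relax_shapes
        self.cache: Dict[Tuple[Shape, ...], Callable[..., Optional[int]]] = {}
        self.trace_count = 0

    def _signature(self, shapes: Tuple[Shape, ...]) -> Tuple[Shape, ...]:
        if self.relax_shapes:
            # Simplification: every dimension becomes None, so one trace
            # serves all inputs that have the same ranks.
            return tuple(tuple(None for _ in shape) for shape in shapes)
        # Exact shapes: any new combination of shapes is a cache miss.
        return tuple(tuple(shape) for shape in shapes)

    def __call__(self, *shapes: Shape) -> Optional[int]:
        key = self._signature(shapes)
        if key not in self.cache:
            # Stand-in for the expensive tracing step the warning complains about.
            self.trace_count += 1
            self.cache[key] = self.fn
        return self.cache[key](*shapes)


def count_atoms(atoms: Shape, bonds: Shape) -> Optional[int]:
    """Stand-in 'model' that just reports the number of atoms."""
    return atoms[1]


# Hypothetical (atom_features, bond_features) shapes for three structures;
# the first and third structures happen to be the same size.
inputs = [((1, 4, 16), (1, 12, 100)),
          ((1, 8, 16), (1, 30, 100)),
          ((1, 4, 16), (1, 12, 100))]

strict = ToyTracedFunction(count_atoms)
relaxed = ToyTracedFunction(count_atoms, relax_shapes=True)
for atom_shape, bond_shape in inputs:
    strict(atom_shape, bond_shape)
    relaxed(atom_shape, bond_shape)

print(strict.trace_count)   # 2: one trace per distinct shape combination
print(relaxed.trace_count)  # 1: a single shape-agnostic trace
```

In real TensorFlow, relaxed tracing generalizes only the dimensions that actually differ between calls, so treat the toy as an illustration of the caching idea rather than a model of the exact behavior.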

@Ash-Pera This warning comes from TensorFlow 2; our code was primarily developed using TensorFlow 1.x.

A quick fix for this would be putting

import tensorflow as tf
# Fall back to TF1-style graph execution; this must run before any model, op or tensor is created
tf.compat.v1.disable_eager_execution()

at the beginning of your script.

@Ash-Pera Hi, I just confirmed that TensorFlow 2.3.0 fixes this issue. Please update TensorFlow to 2.3.0.
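To confirm which TensorFlow version is installed before relying on the 2.3.0 fix mentioned above, a small check along these lines may help. It is a sketch with several assumptions: the package is installed under the distribution name "tensorflow" (variants such as tensorflow-gpu would need a different name), Python is 3.8 or newer for importlib.metadata, and version parsing is simplified so that a pre-release such as 2.3.0rc1 counts as 2.3.0. The helper names are hypothetical.

```python
import re
from importlib import metadata
from typing import Optional, Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Parse the leading 'X.Y[.Z]' of a version string (simplified)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def installed_tensorflow_version() -> Optional[str]:
    """Return the installed 'tensorflow' distribution version, or None."""
    try:
        return metadata.version("tensorflow")
    except metadata.PackageNotFoundError:
        return None


def is_at_least(version: str, minimum: str = "2.3.0") -> bool:
    """Compare (major, minor, patch) tuples lexicographically."""
    return parse_version(version) >= parse_version(minimum)


if __name__ == "__main__":
    installed = installed_tensorflow_version()
    if installed is None:
        print("TensorFlow is not installed under the name 'tensorflow'")
    elif is_at_least(installed):
        print(f"TensorFlow {installed} is 2.3.0 or newer")
    else:
        print(f"TensorFlow {installed} is older than 2.3.0; consider upgrading")
```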

Closing it now.