keras-team/tf-keras

Fail to load model with == op (operator.eq)


System information.

  • Have I written custom code (as opposed to using a stock example script provided in Keras): NO
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Colab
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.13 / 2.15.0-dev20230903

Describe the problem.

When the model contains the built-in equality operator (== or operator.eq), loading it fails,
but using tf.math.equal instead works.

Describe the current behavior.

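Loading fails when the model uses == or operator.eq, while models using the other built-in operators (+, -, *, /) load fine. Expected: a model using == should load the same way.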

  • Do you want to contribute a PR? (yes/no): no

Standalone code to reproduce the issue.

gist here

Source code / logs.

TypeError                                 Traceback (most recent call last)
<ipython-input-5-db50d76d7b09> in <cell line: 2>()
      1 # fails load model
----> 2 create_model_save_predict_load(operator.eq)

<ipython-input-3-c5cb0ac5e529> in create_model_save_predict_load(eq_func)
     12     model.predict([data1, data2])
     13 
---> 14     tf.keras.models.load_model("model.keras")  # fails here when is ==  or operator.eq

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_api.py in load_model(filepath, custom_objects, compile, safe_mode, **kwargs)
    252                 f"with the native Keras format: {list(kwargs.keys())}"
    253             )
--> 254         return saving_lib.load_model(
    255             filepath,
    256             custom_objects=custom_objects,

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py in load_model(filepath, custom_objects, compile, safe_mode)
    279 
    280     except Exception as e:
--> 281         raise e
    282     else:
    283         return model

/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py in load_model(filepath, custom_objects, compile, safe_mode)
    244             # Construct the model from the configuration file in the archive.
    245             with ObjectSharingScope():
--> 246                 model = deserialize_keras_object(
    247                     config_dict, custom_objects, safe_mode=safe_mode
    248                 )

/usr/local/lib/python3.10/dist-packages/keras/src/saving/serialization_lib.py in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    726     safe_mode_scope = SafeModeScope(safe_mode)
    727     with custom_obj_scope, safe_mode_scope:
--> 728         instance = cls.from_config(inner_config)
    729         build_config = config.get("build_config", None)
    730         if build_config:

/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py in from_config(cls, config, custom_objects)
   3328                 # Revive Functional model
   3329                 # (but not Functional subclasses with a custom __init__)
-> 3330                 inputs, outputs, layers = functional.reconstruct_from_config(
   3331                     config, custom_objects
   3332                 )

/usr/local/lib/python3.10/dist-packages/keras/src/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
   1503                 while layer_nodes:
   1504                     node_data = layer_nodes[0]
-> 1505                     if process_node(layer, node_data):
   1506                         layer_nodes.pop(0)
   1507                     else:

/usr/local/lib/python3.10/dist-packages/keras/src/engine/functional.py in process_node(layer, node_data)
   1443                     input_tensors
   1444                 )
-> 1445             output_tensors = layer(input_tensors, **kwargs)
   1446 
   1447             # Update node index map.

/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
     68             # To get the full stack trace, call:
     69             # `tf.debugging.disable_traceback_filtering()`
---> 70             raise e.with_traceback(filtered_tb) from None
     71         finally:
     72             del filtered_tb

/usr/local/lib/python3.10/dist-packages/tensorflow/python/util/dispatch.py in op_dispatch_handler(*args, **kwargs)
   1252         if iterable_params is not None:
   1253           args, kwargs = replace_iterable_params(args, kwargs, iterable_params)
-> 1254         result = api_dispatcher.Dispatch(args, kwargs)
   1255         if result is not NotImplemented:
   1256           return result

TypeError: Missing required positional argument

@Chizkiyahu,
Thank you for the request. Could you please explain whether there is a specific reason to use operator.eq? From the code, we can see that you are importing the operator module, which we support. Please provide more information on the use case. Thank you!

gist here

I understand that certain TensorFlow math operations are equivalent to Python's built-in operators and functions from the operator module. For instance:

  • tf.math.add is equivalent to the + operator and operator.add.
  • tf.math.subtract is equivalent to the - operator and operator.sub.
  • tf.math.multiply is equivalent to the * operator and operator.mul.

To be clear, == and operator.eq are the same.

Similarly, I expected that tf.math.equal would be equivalent to the == operator and operator.eq. While this seems to work as expected, I encounter issues when I save and reload the model from disk.
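At the Python level, the equivalence described above (+ with operator.add, == with operator.eq, and so on) can be checked with the standard library alone. The sketch below is illustrative only: Traced, OPERATOR_PAIRS, trace and compare_forms are hypothetical names, not part of TensorFlow or Keras. It shows that each infix operator and its operator-module function invoke the same special method (__eq__ for both == and operator.eq) and return the same result. It does not reproduce the Keras save/load failure, which happens only when the saved model is deserialized.

```python
import operator

# Hypothetical table: infix spelling -> (lambda using the infix operator,
# equivalent function from the standard `operator` module).
OPERATOR_PAIRS = {
    "+": (lambda a, b: a + b, operator.add),
    "-": (lambda a, b: a - b, operator.sub),
    "*": (lambda a, b: a * b, operator.mul),
    "/": (lambda a, b: a / b, operator.truediv),
    "==": (lambda a, b: a == b, operator.eq),
}


class Traced:
    """Toy operand that records which special method an expression calls."""

    def __init__(self, value, log):
        self.value = value
        self.log = log

    def _record(self, name):
        self.log.append(name)

    def __add__(self, other):
        self._record("__add__")
        return self.value + other.value

    def __sub__(self, other):
        self._record("__sub__")
        return self.value - other.value

    def __mul__(self, other):
        self._record("__mul__")
        return self.value * other.value

    def __truediv__(self, other):
        self._record("__truediv__")
        return self.value / other.value

    def __eq__(self, other):
        self._record("__eq__")
        return self.value == other.value


def trace(fn, x, y):
    """Apply fn to two Traced operands; return (result, special methods called)."""
    log = []
    result = fn(Traced(x, log), Traced(y, log))
    return result, log


def compare_forms(symbol, x, y):
    """Trace the infix form and the operator-module form of the same operation."""
    infix_fn, module_fn = OPERATOR_PAIRS[symbol]
    return trace(infix_fn, x, y), trace(module_fn, x, y)


if __name__ == "__main__":
    for symbol in OPERATOR_PAIRS:
        infix, functional = compare_forms(symbol, 6, 3)
        print(f"{symbol:>2}: infix={infix} operator={functional}")
```

Since both spellings reach the same __eq__ and produce the same value, this is consistent with the report: the comparison itself behaves correctly (the model predicts), and the difference only shows up when the saved model is loaded back.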

@tilakrayal, I hope I've captured your points accurately. If you need more information or find any part unclear, please feel free to ask.

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.