[BUG] Chapter 10: Syntax for saving/restoring keras model has changed
thomas-haslwanter commented
Describe the bug
The syntax for saving/restoring Keras models does not work as shown in 10_neural_nets_with_keras.ipynb
To Reproduce
Run the IPYNB up to cell 83 ('Saving and restoring a model')
tf.keras.models.load_model("my_keras_model.keras")
now seems to be required; the following no longer works:
model.save("my_keras_model", save_format="tf")
And
model = tf.keras.models.load_model("my_keras_model.keras")
leads to the following error message:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[86], line 1
----> 1 model = tf.keras.models.load_model("my_keras_model.keras")
2 #y_pred_main, y_pred_aux = model.predict((X_new_wide, X_new_deep))
File C:\Programs\WPy64-31220\python-3.12.2.amd64\Lib\site-packages\keras\src\saving\saving_api.py:176, in load_model(filepath, custom_objects, compile, safe_mode)
173 is_keras_zip = True
175 if is_keras_zip:
--> 176 return saving_lib.load_model(
177 filepath,
178 custom_objects=custom_objects,
179 compile=compile,
180 safe_mode=safe_mode,
181 )
182 if str(filepath).endswith((".h5", ".hdf5")):
183 return legacy_h5_format.load_model_from_hdf5(
184 filepath, custom_objects=custom_objects, compile=compile
185 )
File C:\Programs\WPy64-31220\python-3.12.2.amd64\Lib\site-packages\keras\src\saving\saving_lib.py:152, in load_model(filepath, custom_objects, compile, safe_mode)
147 raise ValueError(
148 "Invalid filename: expected a `.keras` extension. "
149 f"Received: filepath={filepath}"
150 )
151 with open(filepath, "rb") as f:
--> 152 return _load_model_from_fileobj(
153 f, custom_objects, compile, safe_mode
154 )
File C:\Programs\WPy64-31220\python-3.12.2.amd64\Lib\site-packages\keras\src\saving\saving_lib.py:170, in _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode)
168 # Construct the model from the configuration file in the archive.
169 with ObjectSharingScope():
--> 170 model = deserialize_keras_object(
171 config_dict, custom_objects, safe_mode=safe_mode
172 )
174 all_filenames = zf.namelist()
175 if _VARS_FNAME + ".h5" in all_filenames:
File C:\Programs\WPy64-31220\python-3.12.2.amd64\Lib\site-packages\keras\src\saving\serialization_lib.py:694, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
691 if obj is not None:
692 return obj
--> 694 cls = _retrieve_class_or_fn(
695 class_name,
696 registered_name,
697 module,
698 obj_type="class",
699 full_config=config,
700 custom_objects=custom_objects,
701 )
703 if isinstance(cls, types.FunctionType):
704 return cls
File C:\Programs\WPy64-31220\python-3.12.2.amd64\Lib\site-packages\keras\src\saving\serialization_lib.py:812, in _retrieve_class_or_fn(name, registered_name, module, obj_type, full_config, custom_objects)
809 if obj is not None:
810 return obj
--> 812 raise TypeError(
813 f"Could not locate {obj_type} '{name}'. "
814 "Make sure custom classes are decorated with "
815 "`@keras.saving.register_keras_serializable()`. "
816 f"Full object config: {full_config}"
817 )
TypeError: Could not locate class 'WideAndDeepModel'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()`. Full object config: {'module': None, 'class_name': 'WideAndDeepModel', 'config': {'name': 'my_cool_model', 'trainable': True, 'dtype': 'float32'}, 'registered_name': 'WideAndDeepModel', 'build_config': {'input_shape': [[None, 5], [None, 6]]}, 'compile_config': {'optimizer': {'module': 'keras.optimizers', 'class_name': 'Adam', 'config': {'name': 'adam', 'learning_rate': 0.0010000000474974513, 'weight_decay': None, 'clipnorm': None, 'global_clipnorm': None, 'clipvalue': None, 'use_ema': False, 'ema_momentum': 0.99, 'ema_overwrite_frequency': None, 'loss_scale_factor': None, 'gradient_accumulation_steps': None, 'beta_1': 0.9, 'beta_2': 0.999, 'epsilon': 1e-07, 'amsgrad': False}, 'registered_name': None}, 'loss': 'mse', 'loss_weights': [0.9, 0.1], 'metrics': ['RootMeanSquaredError', 'RootMeanSquaredError'], 'weighted_metrics': None, 'run_eagerly': False, 'steps_per_execution': 1, 'jit_compile': False}}
See the comments on
https://www.tensorflow.org/api_docs/python/tf/keras/utils/register_keras_serializable
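To illustrate why the lookup in the traceback fails, here is a toy sketch of name-based class resolution in plain Python. It is not Keras's implementation: `_REGISTRY`, `register_serializable`, `load_from_config`, and the two example classes are hypothetical names made up for this illustration. It only mirrors the idea that a saved config stores a class name, and that loading needs some way to map that name back to a class.

```python
# Toy model of name-based class lookup; NOT Keras's actual implementation.
_REGISTRY = {}  # hypothetical registry: class name -> class


def register_serializable(cls):
    """Hypothetical stand-in for a registration decorator."""
    _REGISTRY[cls.__name__] = cls
    return cls


def load_from_config(config, custom_objects=None):
    """Rebuild an object from {'class_name': ..., 'config': {...}}."""
    name = config["class_name"]
    # Explicitly passed classes are consulted alongside registered ones.
    lookup = {**_REGISTRY, **(custom_objects or {})}
    if name not in lookup:
        raise TypeError(f"Could not locate class '{name}'.")
    return lookup[name](**config["config"])


class Unregistered:
    def __init__(self, name):
        self.name = name


@register_serializable
class Registered:
    def __init__(self, name):
        self.name = name


# A registered class can be rebuilt from its saved name.
saved = {"class_name": "Registered", "config": {"name": "my_cool_model"}}
restored = load_from_config(saved)

# An unregistered class cannot be located by name alone.
unregistered_saved = {"class_name": "Unregistered", "config": {"name": "my_cool_model"}}
try:
    load_from_config(unregistered_saved)
    lookup_failed = False
except TypeError:
    lookup_failed = True

# Passing the class explicitly also resolves the name.
restored_explicit = load_from_config(
    unregistered_saved, custom_objects={"Unregistered": Unregistered}
)
```

In the notebook, the corresponding change would presumably be to decorate WideAndDeepModel with `@keras.saving.register_keras_serializable()`, as the error message suggests, or to pass the class through the `custom_objects` argument that appears in the `load_model` signature in the traceback. Neither option has been verified here against TensorFlow 2.16.1.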
Versions (please complete the following information):
- OS: Win 11
- Python: 3.12.2
- TensorFlow: 2.16.1