keras-team/keras

TypeError: Could not locate class 'adam'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()`

mpetteno opened this issue · 1 comment

Hi everyone,

I think there is a problem when loading a model that was compiled with the optimizer passed as a dict.
This does not happen with `optimizer="adam"` or `optimizer=keras.optimizers.Adam()`.

Here is the code that reproduces the issue:

```python
import keras

model = keras.models.Sequential()
model.add(keras.layers.Dense(64, input_dim=3, activation='relu'))
model.add(keras.layers.Dense(32, activation='relu'))
model.add(keras.layers.Dense(1, activation='linear'))

# Optimizer passed as a config dict instead of a string or an Optimizer instance
model.compile(
    optimizer={
        "class_name": "Adam",
        "config": {
            "learning_rate": 0.01,
            "beta_1": 0.9,
            "beta_2": 0.999,
            "epsilon": 1e-7
        }
    },
    loss='mse'
)

keras.saving.save_model(model, 'model.keras')
loaded_model = keras.saving.load_model("model.keras")  # raises TypeError
```

The full traceback is:

```
Traceback (most recent call last):
  File "/venv/lib/python3.11/site-packages/keras/src/saving/saving_lib.py", line 152, in load_model
    return _load_model_from_fileobj(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.11/site-packages/keras/src/saving/saving_lib.py", line 170, in _load_model_from_fileobj
    model = deserialize_keras_object(
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 734, in deserialize_keras_object
    instance.compile_from_config(compile_config)
  File "/venv/lib/python3.11/site-packages/keras/src/trainers/trainer.py", line 870, in compile_from_config
    config = serialization_lib.deserialize_keras_object(config)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 594, in deserialize_keras_object
    return {
           ^
  File "/venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 595, in <dictcomp>
    key: deserialize_keras_object(
         ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 694, in deserialize_keras_object
    cls = _retrieve_class_or_fn(
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 812, in _retrieve_class_or_fn
    raise TypeError(
TypeError: Could not locate class 'adam'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()`. Full object config: {'class_name': 'adam', 'config': {'learning_rate': 0.01, 'beta_1': 0.9, 'beta_2': 0.999, 'epsilon': 1e-07}}
```

I think this happens because, in this case, the `class_name` field is serialized as `"adam"` (lowercase) rather than `"Adam"` (capitalized), so `obj` is not resolved at line 803 of `serialization_lib.py`.
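To make the suspected mechanism concrete, here is a deliberately simplified sketch of a name-based class lookup. It is hypothetical and is not how Keras's `serialization_lib` is actually implemented: the `REGISTRY` dict, the stand-in `Adam` class and the `deserialize` helper are all made up for illustration. It only shows that a registry keyed on the exact string `"Adam"` cannot resolve a config whose `class_name` is `"adam"`.

```python
# Hypothetical, simplified model of a name-based class lookup.
# This is NOT Keras's implementation; it only illustrates a case-sensitive miss.


class Adam:
    """Stand-in optimizer class (hypothetical, not keras.optimizers.Adam)."""

    def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7):
        self.learning_rate = learning_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon


# Registry keyed by the exact class name, so lookups are case-sensitive.
REGISTRY = {"Adam": Adam}


def deserialize(config):
    """Rebuild an object from {"class_name": ..., "config": {...}}."""
    cls = REGISTRY.get(config["class_name"])
    if cls is None:
        raise TypeError(f"Could not locate class {config['class_name']!r}.")
    return cls(**config["config"])


params = {"learning_rate": 0.01, "beta_1": 0.9, "beta_2": 0.999, "epsilon": 1e-7}

# The capitalized name matches the registry key and resolves.
ok = deserialize({"class_name": "Adam", "config": params})

# The lowercase name has no registry entry, so the lookup raises TypeError.
try:
    deserialize({"class_name": "adam", "config": params})
    failed = False
except TypeError:
    failed = True
```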

Thanks for your help.

Thanks for the report; this is fixed at HEAD. Note that passing the optimizer as a dict isn't an officially supported API. The supported options are to pass it as a string or as an `Optimizer` instance.