philipperemy/keras-tcn

use_weight_norm=True -> Model fails to load when using *.tf model format.

Luux opened this issue · 2 comments

Luux commented

Describe the bug
This one took me a while to isolate. If you set use_weight_norm=True and save the model in the newer TensorFlow SavedModel format (instead of the h5 format used in the loading example), TensorFlow can no longer load the model.
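For context, the "tf-format" here is the SavedModel format. Below is a minimal sketch, assuming TF 2.x tf.keras (Keras 2) behaviour, of how Model.save picks a format from the path. infer_save_format and HDF5_EXTENSIONS are hypothetical names for illustration, not part of the TensorFlow API, and later Keras releases changed these rules. The point is that "test.tf" is not special: any path without an HDF5 extension ends up as a SavedModel directory.

```python
# Hypothetical helper approximating the TF 2.x tf.keras rule (assumes TF2
# behaviour is enabled, where the default save format is "tf").
HDF5_EXTENSIONS = (".h5", ".hdf5", ".keras")


def infer_save_format(filepath, save_format=None):
    """Return "h5" or "tf" for a given model path, mimicking Model.save."""
    # An explicit "h5" request or an HDF5 extension selects the HDF5 format.
    if save_format == "h5" or filepath.endswith(HDF5_EXTENSIONS):
        return "h5"
    # Everything else, including "test.tf" and "test", becomes a SavedModel.
    return "tf"


for path in ["test.h5", "test.tf", "test"]:
    print(path, "->", infer_save_format(path))
# test.h5 -> h5
# test.tf -> tf
# test -> tf
```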

Paste a snippet

import tensorflow as tf
import numpy as np
from tcn import TCN

model_fn = "test.tf"  # saved as a SavedModel directory; "test" (no extension) works the same way

# Dimensions must match the dummy input passed to predict() below
max_features = 64
max_len = 128

inputs = tf.keras.layers.Input(shape=(max_features, max_len))
tcn = TCN(
    nb_filters=64,
    return_sequences=True,
    use_skip_connections=True,
    padding="same",
    dilations=[1, 2, 4, 8, 16, 32, 64],
    use_weight_norm=True,
    name="tcn1",
)(inputs)
dense = tf.keras.layers.Dense(units=1, activation="sigmoid")(tcn)
model = tf.keras.Model(inputs=inputs, outputs=dense)

model.compile(optimizer="adam")

print("Save model...")
model.save(model_fn)

print("Load model...")
model = tf.keras.models.load_model(
    model_fn, compile=False
)

model.compile(optimizer="adam")

model.summary()

inputs = np.ones(shape=(1, 64, 128))
out = model.predict(inputs)
print(out)

This results in errors such as

Traceback (most recent call last):
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/luux/.vscode/extensions/ms-python.python-2021.3.658691958/pythonFiles/lib/python/debugpy/__main__.py", line 45, in <module>
    cli.main()
  File "/home/luux/.vscode/extensions/ms-python.python-2021.3.658691958/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 444, in main
    run()
  File "/home/luux/.vscode/extensions/ms-python.python-2021.3.658691958/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 285, in run_file
    runpy.run_path(target_as_str, run_name=compat.force_str("__main__"))
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/runpy.py", line 265, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/runpy.py", line 97, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/luux/models/aisign/load_TCN.py", line 50, in <module>
    model = tf.keras.models.load_model(
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py", line 187, in load_model
    return saved_model_load.load(filepath, compile, options)
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 120, in load
    model = tf_load.load_internal(
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 632, in load_internal
    loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 194, in __init__
    super(KerasObjectLoader, self).__init__(*args, **kwargs)
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 130, in __init__
    self._load_all()
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 218, in _load_all
    super(KerasObjectLoader, self)._load_all()
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 141, in _load_all
    self._load_nodes()
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 283, in _load_nodes
    node, setter = self._recreate(proto, node_id)
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 239, in _recreate
    obj, setter = super(KerasObjectLoader, self)._recreate(proto, node_id)
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 393, in _recreate
    return factory[kind]()
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 382, in <lambda>
    "function": lambda: self._recreate_function(proto.function),
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 420, in _recreate_function
    return function_deserialization.recreate_function(
  File "/home/luux/miniconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/saved_model/function_deserialization.py", line 261, in recreate_function
    concrete_function_objects.append(concrete_functions[concrete_function_name])
KeyError: '__inference_conv1D_0_layer_call_fn_26401'

I also tried passing TCN as a custom object, as is usually done with h5 files, but that didn't work either. use_batch_norm and use_layer_norm work as expected out of the box.

Dependencies
tensorflow==2.3 (I also tested 2.4 without success)

@Luux Tough one. What I can say right now is that weight_norm depends on TensorFlow Addons (tf-addons), so more than just TensorFlow is involved here. All the other norms are implemented in pure TensorFlow.
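For readers unfamiliar with the option: weight normalization reparameterizes each kernel as w = g * v / ||v||, with one scale g per output channel. In TensorFlow Addons this is provided as a wrapper layer (tfa.layers.WeightNormalization) that holds the extra variables, rather than as a standalone normalization layer like batch or layer norm. The sketch below is a plain NumPy illustration of the reparameterization only, not the tf-addons implementation; weight_norm is a hypothetical helper name.

```python
import numpy as np


def weight_norm(v, g, axis):
    """Illustrative weight-norm reparameterization: w = g * v / ||v||.

    v: direction tensor (e.g. a Conv1D kernel of shape (k, in, out)),
    g: one scale per output channel, axis: axes reduced when taking the norm.
    """
    norm = np.sqrt(np.sum(np.square(v), axis=axis, keepdims=True))
    return g * v / norm


# Conv1D-style kernel: (kernel_size, in_channels, out_channels)
rng = np.random.default_rng(0)
v = rng.normal(size=(3, 4, 2))
g = np.array([1.0, 2.0])  # one scale per output channel

# Norm over every axis except the output-channel axis.
w = weight_norm(v, g, axis=(0, 1))

# Each output channel's kernel now has an L2 norm equal to its entry in g.
print(np.sqrt(np.sum(w ** 2, axis=(0, 1))))
```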

TL;DR: Upgrade your TensorFlow version.

Okay, so this was fixed in recent versions of TensorFlow.

I can confirm that the following code runs fine on TensorFlow 2.7.

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model

from tcn import TCN

model_fn = "test.tf"  # saved as a SavedModel directory; "test" (no extension) works the same way

max_features = 10
max_len = 5

inputs = tf.keras.layers.Input(shape=(max_features, max_len))
tcn = TCN(
    nb_filters=64,
    return_sequences=True,
    use_skip_connections=True,
    padding="same",
    dilations=[1, 2, 4, 8, 16, 32, 64],
    use_weight_norm=True,
    name="tcn1",
)(inputs)
dense = tf.keras.layers.Dense(units=1, activation="sigmoid")(tcn)
model = tf.keras.Model(inputs=inputs, outputs=dense)

model.compile(optimizer="adam")

print("Save model...")
model.save(model_fn)

print("Load model...")
model = load_model(
    filepath=model_fn,
    compile=False,
)

model.compile(optimizer="adam")

model.summary()

inputs = np.ones(shape=(1, max_features, max_len))
out = model.predict(inputs)
print(out)
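If you cannot control which TensorFlow version your users run, a guard along these lines can fail early instead of silently producing a model that cannot be reloaded. This is a sketch under stated assumptions: parse_version, weight_norm_savedmodel_supported and the (2, 7) threshold are hypothetical; 2.7 is only the version confirmed to work in this thread, not necessarily the first release that contains the fix.

```python
def parse_version(version):
    """Turn a version string such as "2.3.0" or "2.7.0-rc1" into a tuple of ints."""
    parts = []
    for piece in version.split(".")[:3]:
        # Keep only the leading digits of each component ("0-rc1" -> "0").
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


# Assumption: the version confirmed in this thread, used as a conservative floor.
MIN_TF_VERSION = (2, 7)


def weight_norm_savedmodel_supported(tf_version):
    """Return True if this TF version is at or above the confirmed-working version."""
    return parse_version(tf_version)[:2] >= MIN_TF_VERSION


print(weight_norm_savedmodel_supported("2.3.0"))  # False
print(weight_norm_savedmodel_supported("2.7.0"))  # True
```

In a real script you would pass tf.__version__ and, for example, fall back to use_weight_norm=False or raise an error before calling model.save on older versions.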