philipperemy/keras-tcn

Keras 3 support

Martin-Molinero opened this issue · 4 comments

Running the following script fails with Keras 3.0.5:

from tcn import TCN, tcn_full_summary
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

# if time_steps > tcn_layer.receptive_field, then we should not
# be able to solve this task.
batch_size, time_steps, input_dim = None, 20, 1


def get_x_y(size=1000):
    import numpy as np
    pos_indices = np.random.choice(size, size=int(size // 2), replace=False)
    x_train = np.zeros(shape=(size, time_steps, 1))
    y_train = np.zeros(shape=(size, 1))
    x_train[pos_indices, 0] = 1.0  # we introduce the target in the first timestep of the sequence.
    y_train[pos_indices, 0] = 1.0  # the task is to see if the TCN can go back in time to find it.
    return x_train, y_train


tcn_layer = TCN(input_shape=(time_steps, input_dim))
# The receptive field tells you how far the model can see in terms of timesteps.
print('Receptive field size =', tcn_layer.receptive_field)

m = Sequential([
    tcn_layer,
    Dense(1)
])

m.compile(optimizer='adam', loss='mse')

tcn_full_summary(m, expand_residual_blocks=False)

x, y = get_x_y()
m.fit(x, y, epochs=10, validation_split=0.2)
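The receptive-field check described in the script's comments can be sketched without Keras. This is a minimal sketch under stated assumptions: it assumes the formula and default arguments (kernel_size=3, nb_stacks=1, dilations=(1, 2, 4, 8, 16, 32)) used by recent keras-tcn releases, and tcn_receptive_field is a hypothetical name, not a library function. Check tcn.py in your installed version before relying on the numbers.

```python
# Assumption: mirrors the receptive-field formula of recent keras-tcn releases
# (hypothetical helper, not part of the library).
def tcn_receptive_field(kernel_size=3, nb_stacks=1, dilations=(1, 2, 4, 8, 16, 32)):
    # Each residual block holds two dilated causal convolutions, and each one
    # widens the view by (kernel_size - 1) * dilation timesteps.
    return 1 + 2 * (kernel_size - 1) * nb_stacks * sum(dilations)


time_steps = 20
rf = tcn_receptive_field()
print("Receptive field size =", rf)        # 253 with the assumed defaults
print("Task solvable:", time_steps <= rf)  # True: 20 timesteps fit in the field
```

With these defaults the model sees 253 timesteps, so the 20-step toy task above should be solvable once the layer builds.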

Error

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/miniconda3/lib/python3.11/site-packages/keras/src/models/sequential.py", line 71, in __init__
    self._maybe_rebuild()
  File "/opt/miniconda3/lib/python3.11/site-packages/keras/src/models/sequential.py", line 136, in _maybe_rebuild
    self.build(input_shape)
  File "/opt/miniconda3/lib/python3.11/site-packages/keras/src/layers/layer.py", line 224, in build_wrapper
    original_build_method(*args, **kwargs)
  File "/opt/miniconda3/lib/python3.11/site-packages/keras/src/models/sequential.py", line 177, in build
    x = layer(x)
        ^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/opt/miniconda3/lib/python3.11/site-packages/tcn/tcn.py", line 316, in build
    self.slicer_layer.build(self.build_output_shape.as_list())
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'tuple' object has no attribute 'as_list'

Replacing build_output_shape.as_list() with list(build_output_shape) wherever as_list is used works as a fix; no other issues with Keras 3 were observed after the change.
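The fix can be sketched as a small shape-normalizing helper. This is a minimal sketch, not the patch applied to keras-tcn: shape_to_list and FakeTensorShape are hypothetical names, and it assumes (as the traceback indicates) that under Keras 3 build() receives shapes as plain tuples, while tf.keras 2.x typically passed tf.TensorShape objects, which expose as_list().

```python
# Hypothetical helper (not part of keras-tcn): return a shape as a plain list
# whether it is a tf.TensorShape-like object (tf.keras 2.x) or a tuple (Keras 3).
def shape_to_list(shape):
    if hasattr(shape, "as_list"):
        return shape.as_list()  # TensorShape-style object
    return list(shape)          # Keras 3 plain tuple, e.g. (None, 20, 64)


class FakeTensorShape:
    """Stand-in for tf.TensorShape so this sketch runs without TensorFlow."""

    def __init__(self, dims):
        self._dims = tuple(dims)

    def as_list(self):
        return list(self._dims)


# In tcn.py the failing call would become, for example:
#     self.slicer_layer.build(shape_to_list(self.build_output_shape))
print(shape_to_list((None, 20, 64)))                   # [None, 20, 64]
print(shape_to_list(FakeTensorShape([None, 20, 64])))  # [None, 20, 64]
```

In TF 2.x, list() on a real tf.TensorShape should also yield plain ints or None, so the bare list(build_output_shape) replacement likely covers both versions; the hasattr check only makes the intent explicit.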

The modification solved my problem.