Keras 3 support
Martin-Molinero opened this issue · 4 comments
Martin-Molinero commented
Running the following script fails with Keras 3.0.5:
from tcn import TCN, tcn_full_summary
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
# If time_steps > tcn_layer.receptive_field, then we should not
# be able to solve this task.
batch_size, time_steps, input_dim = None, 20, 1


def get_x_y(size=1000):
    """Build a toy "look back to the first timestep" dataset for the TCN.

    Half of the ``size`` samples (chosen at random, without replacement) get
    a 1.0 in the first timestep of ``x`` and a target of 1.0 in ``y``; every
    other entry is 0.0. Per the comment on ``time_steps`` above, the task is
    only solvable when the model's receptive field covers ``time_steps``.

    Args:
        size: number of samples to generate.

    Returns:
        ``(x_train, y_train)`` with shapes ``(size, time_steps, input_dim)``
        and ``(size, 1)``.
    """
    import numpy as np

    # size // 2 is already an int, so no extra int() cast is needed.
    pos_indices = np.random.choice(size, size=size // 2, replace=False)
    # Use input_dim rather than a hard-coded 1 so the data always matches the
    # model's input_shape=(time_steps, input_dim).
    x_train = np.zeros(shape=(size, time_steps, input_dim))
    y_train = np.zeros(shape=(size, 1))
    x_train[pos_indices, 0] = 1.0  # we introduce the target in the first timestep of the sequence.
    y_train[pos_indices, 0] = 1.0  # the task is to see if the TCN can go back in time to find it.
    return x_train, y_train
# Build the TCN layer with a fully specified input shape.
tcn_layer = TCN(input_shape=(time_steps, input_dim))
# The receptive field tells you how far the model can see in terms of timesteps.
print('Receptive field size =', tcn_layer.receptive_field)
# NOTE(review): under Keras 3.0.5 this constructor raises AttributeError.
# Per the traceback in this issue, Sequential.__init__ immediately rebuilds the
# model, which calls TCN.build, where build_output_shape.as_list() fails because
# the shape is a plain tuple. Presumably the eager build happens because the
# first layer declares input_shape — confirm against the Keras 3 source.
m = Sequential([
tcn_layer,
Dense(1)
])
# Regression setup: Dense(1) outputs one value per sample, trained with MSE.
m.compile(optimizer='adam', loss='mse')
# keras-tcn helper; presumably prints the model summary with residual blocks
# collapsed (expand_residual_blocks=False) — verify against keras-tcn docs.
tcn_full_summary(m, expand_residual_blocks=False)
# Train on the synthetic dataset; validation_split=0.2 holds out 20% of x/y
# for validation.
x, y = get_x_y()
m.fit(x, y, epochs=10, validation_split=0.2)
Error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/opt/miniconda3/lib/python3.11/site-packages/keras/src/models/sequential.py", line 71, in __init__
self._maybe_rebuild()
File "/opt/miniconda3/lib/python3.11/site-packages/keras/src/models/sequential.py", line 136, in _maybe_rebuild
self.build(input_shape)
File "/opt/miniconda3/lib/python3.11/site-packages/keras/src/layers/layer.py", line 224, in build_wrapper
original_build_method(*args, **kwargs)
File "/opt/miniconda3/lib/python3.11/site-packages/keras/src/models/sequential.py", line 177, in build
x = layer(x)
^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 123, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/opt/miniconda3/lib/python3.11/site-packages/tcn/tcn.py", line 316, in build
self.slicer_layer.build(self.build_output_shape.as_list())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'tuple' object has no attribute 'as_list'
Kurdakov commented
Replacing `build_output_shape.as_list()` with `list(build_output_shape)` wherever `as_list` is used works as a fix; no other issues with Keras 3 were observed after the change.
latexalpha commented
> Replacing `build_output_shape.as_list()` with `list(build_output_shape)` wherever `as_list` is used works as a fix; no other issues with Keras 3 were observed after the change.
The modification solved my problem.