philipperemy/keras-tcn

Possible bug in use_skip_connections

Closed this issue · 4 comments

Hey folks,

I've been playing around a bit with the TCN model class and just wanted to check with you whether there is an issue in the way that the skip connections are included.

The residual block returns two outputs:

keras-tcn/tcn/tcn.py

Lines 151 to 156 in 2483cd9

x = inputs
self.layers_outputs = [x]
for layer in self.layers:
    training_flag = 'training' in dict(inspect.signature(layer.call).parameters)
    x = layer(x, training=training) if training_flag else layer(x)
    self.layers_outputs.append(x)

keras-tcn/tcn/tcn.py

Lines 157 to 164 in 2483cd9

x2 = self.shape_match_conv(inputs)
self.layers_outputs.append(x2)
res_x = layers.add([x2, x])
self.layers_outputs.append(res_x)
res_act_x = self.final_activation(res_x)
self.layers_outputs.append(res_act_x)
return [res_act_x, x]

These are, respectively, the channel-matched input plus the output of the convolutional layers (the typical output of a block with a skip connection), and the output of the convolutional layers alone. In my mind, the first output (input + "residual") is what the network should use when use_skip_connections is flagged on, and the second is what should be used when it's flagged off.
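To make the distinction concrete, here is a minimal sketch (not the library's actual class) of how a residual block can return both tensors; conv_stack, shape_match_conv and final_activation are placeholders standing in for the block's dilated convolutions, its 1x1 channel-matching convolution, and its output activation:

from tensorflow.keras import layers

def residual_block_sketch(inputs, conv_stack, shape_match_conv, final_activation):
    x = conv_stack(inputs)                              # output of the dilated conv stack alone
    x2 = shape_match_conv(inputs)                       # channel-matched copy of the input
    res_act_x = final_activation(layers.add([x2, x]))   # input + "residual"
    return res_act_x, x                                  # (residual output, conv-stack output)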

Later on, when the TCN output is built from the residual blocks, these two outputs are assigned to the variables x and skip_out respectively, and I think this may be the wrong way round?

keras-tcn/tcn/tcn.py

Lines 316 to 320 in 2483cd9

try:
    x, skip_out = layer(x, training=training)
except TypeError:  # compatibility with tensorflow 1.x
    x, skip_out = layer(K.cast(x, 'float32'), training=training)
self.skip_connections.append(skip_out)

keras-tcn/tcn/tcn.py

Lines 323 to 324 in 2483cd9

if self.use_skip_connections:
    x = layers.add(self.skip_connections)

Specifically, when use_skip_connections is on, the TCN outputs the sum of those second outputs (somewhat confusingly named skip_out), i.e. the conv-stack outputs themselves, without adding the inputs back.
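Putting the two excerpts together, here is a hedged sketch of what the TCN ends up computing; residual_blocks is a hypothetical list of blocks, each returning (residual_output, conv_stack_output) as above:

from tensorflow.keras import layers

def tcn_output_sketch(inputs, residual_blocks, use_skip_connections):
    x = inputs
    skip_connections = []
    for block in residual_blocks:
        x, skip_out = block(x)              # skip_out is the conv-stack output, not the residual sum
        skip_connections.append(skip_out)
    if use_skip_connections:
        # with the flag on, the output is the sum of the conv-stack outputs of all blocks
        return layers.add(skip_connections)
    # with the flag off, it is simply the residual output of the last block
    return x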

I think the error is in the documentation. If I understood the code correctly, the skip_out connections connect the output of each residual block directly to the output of the TCN. The documentation says the use_skip_connections parameter specifies whether there should be a skip connection from the input to each residual block.

The skip connections inside the residual block are not affected by the parameter at all, if I read the code correctly.
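For reference, toggling the parameter under discussion looks roughly like this (a minimal sketch; the input shape and the Dense head are arbitrary placeholders):

from tcn import TCN
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

inp = Input(shape=(None, 1))
x = TCN(use_skip_connections=True)(inp)   # or False to keep only the last block's output
out = Dense(1)(x)
model = Model(inp, out)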

I think it is an error: res_act_x in the residual block maps to x in the TCN, and x in the residual block maps to skip_out in the TCN. self.skip_connections saves the x from the residual block.

Hey guys,

I will update the README.

Here is the structure of one residual block (using tensorboard):

tensorboard --logdir logs

And I added:

from tensorflow.keras.callbacks import TensorBoard

tensorboard = TensorBoard(
    log_dir='/tmp/logs',
    histogram_freq=1,
    write_images=True
)

As a callback in the .fit() function
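For example (model, x_train and y_train are hypothetical placeholders):

model.fit(x_train, y_train, epochs=10, callbacks=[tensorboard])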

[Image: TensorBoard graph of a single residual block]

The skip connections connect together the outputs of the dilated conv stacks (and not the residual sums) of all the residual blocks.

It can be visualized here:

[Image: TensorBoard graph showing the skip connections summed at the TCN output]

By dilated conv stack I mean this stack:

[Image: the dilated conv stack inside a residual block]
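For readers without the graph in front of them, such a stack is roughly two causal dilated Conv1D layers, each followed by an activation and dropout (a simplified sketch; the real block also applies weight normalization, omitted here):

from tensorflow.keras import layers

def dilated_conv_stack_sketch(x, nb_filters=64, kernel_size=3, dilation_rate=1, dropout_rate=0.0):
    for _ in range(2):
        x = layers.Conv1D(nb_filters, kernel_size,
                          dilation_rate=dilation_rate, padding='causal')(x)
        x = layers.Activation('relu')(x)
        x = layers.SpatialDropout1D(dropout_rate)(x)
    return x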

I'll close this issue now. But feel free to re-open it. I pushed a new version 3.4.1 reflecting the updates (mostly variable renaming and README changes).