Possible bug in use_skip_connections
Closed this issue · 4 comments
Hey folks,
I've been playing around a bit with the TCN model class and just wanted to check with you whether there is an issue in the way that the skip connections are included.
The residual block returns two outputs:
Lines 151 to 156 in 2483cd9
Lines 157 to 164 in 2483cd9
These are, respectively, the channel-matched input plus the output of the convolutional layers (the typical output of a block with a skip connection), and the output of the convolutional layers alone. To my mind, this first output (input + "residual") is what the network should use when use_skip_connections is flagged on, and the second is what should be used when it's flagged off.
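For reference, here is a minimal, hypothetical sketch of a residual block that returns both tensors — not the library's actual code (the conv stack is collapsed to a single dilated causal conv, and all names are made up):

```python
import tensorflow as tf
from tensorflow.keras import layers

def residual_block(x, filters, kernel_size, dilation_rate):
    """Hypothetical sketch: returns (residual output, conv output)."""
    # Dilated causal conv stack (simplified to one conv for brevity).
    conv_out = layers.Conv1D(filters, kernel_size,
                             dilation_rate=dilation_rate,
                             padding='causal', activation='relu')(x)
    # 1x1 conv so the input's channel count matches conv_out's for the add.
    matched_input = layers.Conv1D(filters, 1, padding='same')(x)
    # First output: channel-matched input + conv output (the residual path).
    res_out = layers.Activation('relu')(layers.add([matched_input, conv_out]))
    # Second output: the conv-stack output alone.
    return res_out, conv_out
```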
Later on, when the TCN outputs are built from the residual blocks, these outputs are assigned to the variables x and skip_out respectively, and I think this may be the wrong way round?
Lines 316 to 320 in 2483cd9
Lines 323 to 324 in 2483cd9
Specifically, when use_skip_connections is on, the TCN outputs the sum of those second outputs (somewhat confusingly named skip_out) — the conv block outputs themselves, without adding the inputs back.
I think the error is in the documentation. If I understood the code correctly, the skip_out connections connect the output of each residual block directly to the output of the TCN, whereas the documentation says the skip_connections parameter specifies whether there should be a skip connection from the input to each residual block.
The skip connections inside the residual blocks are not affected by the parameter at all, if I read the code correctly.
I think it is an error: res_act_x in the residual block maps to x in the TCN, and x in the residual block maps to skip_out in the TCN. self.skip_connections stores the x from each residual block.
Hey guys,
I will update the README.
Here is the structure of one residual block (using TensorBoard):
tensorboard --logdir logs
And I added:
from tensorflow.keras.callbacks import TensorBoard
tensorboard = TensorBoard(
log_dir='/tmp/logs',
histogram_freq=1,
write_images=True
)
as a callback in the .fit() function.
The skip connections connect the output of each dilated conv stack (and not the residual) of all the residual blocks together.
It can be visualized here:
- From the WaveNet paper: https://arxiv.org/pdf/1609.03499.pdf
By dilated conv stack I mean this stack:
- From the empirical TCN paper: https://arxiv.org/pdf/1803.01271.pdf
I'll close this issue now, but feel free to re-open it. I pushed a new version, 3.4.1, reflecting the updates (mostly renaming variables and the README).