f90/Wave-U-Net-Pytorch

Depth other than 1 does not work

B-lanc opened this issue

Changing the depth to anything other than 1 results in this error:

```
RuntimeError: Given groups=1, weight of size [512, 512, 5], expected input[1, 1024, 357] to have 512 channels, but got 1024 channels instead
```

I am fairly certain the error is caused by these two lines:

```python
[ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])
```

and

```python
combined = conv(torch.cat([combined, centre_crop(upsampled, combined)], dim=1))
```

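Line 24 builds every post-shortcut conv after the first with n_outputs input channels (as opposed to n_outputs + n_shortcut), which implies those convs will not receive the shortcut. Line 38, however, performs the concatenation on every iteration of the loop, not just the first, so from the second conv onward the input has n_outputs + n_outputs channels (1024 here) while the conv expects n_outputs (512).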
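To make the mismatch concrete, here is a minimal channel-count trace of those two lines, written in plain Python rather than PyTorch. It assumes n_outputs = n_shortcut = 512, which is consistent with the 512 and 1024 in the error message. The helper names post_shortcut_in_channels and trace_forward are hypothetical, and each ConvLayer is modelled only by its input channel count plus the assumption that it outputs n_outputs channels.

```python
# Hypothetical channel bookkeeping for the post-shortcut convs; no PyTorch needed.
# Assumes every ConvLayer outputs n_outputs channels.

def post_shortcut_in_channels(n_outputs, n_shortcut, depth):
    """Input channels each post-shortcut ConvLayer is built with (line 24)."""
    return [n_outputs + n_shortcut] + [n_outputs for _ in range(depth - 1)]


def trace_forward(n_outputs, n_shortcut, depth):
    """Replay the forward loop (line 38); return the first mismatch or None."""
    expected = post_shortcut_in_channels(n_outputs, n_shortcut, depth)
    combined = n_shortcut   # combined starts as the cropped shortcut
    upsampled = n_outputs   # output of the pre-shortcut convs
    for i, in_channels in enumerate(expected):
        actual = combined + upsampled  # torch.cat(..., dim=1) sums channel counts
        if actual != in_channels:
            return (i, in_channels, actual)  # (conv index, expected, got)
        combined = n_outputs  # the conv's output becomes the next combined
    return None


for depth in (1, 2, 3):
    print(depth, trace_forward(512, 512, depth))
# 1 None
# 2 (1, 512, 1024)
# 3 (1, 512, 1024)
```

With depth = 1 there is only the first conv, which is built for n_outputs + n_shortcut, so nothing breaks; any larger depth fails at the second conv with exactly the 512 vs 1024 mismatch reported above.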

Either all post-shortcut convs should be built to accept the concatenated input, or the concatenation should happen only once, before the loop. I think the latter makes more sense.
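Both options can be checked with the same kind of channel arithmetic. The sketch below is a hypothetical illustration, not a patch from the repository: the function names are made up, each ConvLayer is again assumed to output n_outputs channels, and the commented lines only indicate roughly how the concatenate-once variant might look in forward().

```python
# Hypothetical channel bookkeeping for the two possible fixes; no PyTorch needed.
# Every ConvLayer is assumed to output n_outputs channels.

def option_all_take_concat(n_outputs, n_shortcut, depth):
    """Keep the cat inside the loop, but size the later convs for it (changes line 24)."""
    in_channels = [n_outputs + n_shortcut] + [2 * n_outputs for _ in range(depth - 1)]
    combined, upsampled = n_shortcut, n_outputs
    for expected in in_channels:
        if combined + upsampled != expected:  # cat on every iteration
            return False
        combined = n_outputs
    return True


def option_concat_once(n_outputs, n_shortcut, depth):
    """Concatenate once before the loop; line 24 stays as it is."""
    in_channels = [n_outputs + n_shortcut] + [n_outputs for _ in range(depth - 1)]
    combined = n_shortcut + n_outputs  # single cat of shortcut and upsampled
    for expected in in_channels:
        if combined != expected:
            return False
        combined = n_outputs
    return True

# In forward(), the concatenate-once option would look roughly like:
#     combined = torch.cat([combined, centre_crop(upsampled, combined)], dim=1)
#     for conv in self.post_shortcut_convs:
#         combined = conv(combined)

for depth in (1, 2, 4):
    print(depth, option_all_take_concat(512, 512, depth), option_concat_once(512, 512, depth))
# 1 True True
# 2 True True
# 4 True True
```

The concatenate-once option leaves the constructor untouched and adds no parameters, which is presumably why it reads as the simpler fix; the other option re-injects the upsampled features at every conv at the cost of wider weights in the later convs.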