HomebrewNLP/revlib

convolution network, striding

If I need to stride the input, what is the recommended way to do that with RevLib?

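import torch
from torch import nn

class Stride(nn.Module):
    def forward(self, input):
        # Keep batch and channel dims; take every second element along each spatial dim.
        return input[(slice(None),) * 2 + (slice(0, None, 2),) * (input.ndim - 2)]
Stride()(torch.randn(1, 1, 11, 16)).shape  # torch.Size([1, 1, 6, 8])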

You might be interested in #5, where we discussed the same issue but across the features instead of pixels. To summarize a long thread: a reversible layer cannot have a different number of output elements than input elements. This is not an issue with RevLib but a design choice from its backbone, RevNet. The original paper worked around this issue by using multiple reversible sequences instead of one, with average pooling between them.
Suppose you're set on using strides within the same reversible sequence, for example, to avoid the memory overhead you'd get from running multiple ReversibleSequences. In that case, you must wrap your function in downsampling and upsampling blocks, so that its output has the same shape as its input.
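As a rough illustration of that wrapping, here is a minimal sketch in plain PyTorch. `StridedWrapper` is a hypothetical name and not part of RevLib; the sketch assumes that nearest-neighbour interpolation is an acceptable way to upsample and that the inner module preserves the channel count. It only shows how to keep the output shape equal to the input shape, not how to plug the block into a reversible sequence.

```python
import torch
from torch import nn
import torch.nn.functional as F


class StridedWrapper(nn.Module):
    """Hypothetical helper: stride the input, run `inner`, then upsample back."""

    def __init__(self, inner: nn.Module, stride: int = 2):
        super().__init__()
        self.inner = inner
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        spatial = tuple(x.shape[2:])
        # Keep batch and channel dims; subsample every spatial dim by `stride`.
        index = (slice(None), slice(None)) + tuple(
            slice(0, None, self.stride) for _ in spatial
        )
        out = self.inner(x[index])
        # Upsample to the exact input size so odd sizes (e.g. 11) are restored too.
        return F.interpolate(out, size=spatial, mode="nearest")


# Assumed inner function: a shape-preserving 3x3 convolution.
block = StridedWrapper(nn.Conv2d(1, 1, kernel_size=3, padding=1))
x = torch.randn(1, 1, 11, 16)
y = block(x)
print(y.shape)  # same shape as x
```

Passing the original spatial size to `F.interpolate` instead of a scale factor avoids an off-by-one mismatch when a dimension is not divisible by the stride.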