jhjacobsen/pytorch-i-revnet

Some confusion about the case 'stride = 2'

shuizidesu opened this issue

Dear authors:

The forward procedure for $i$-RevNet described in the paper (Eq. (1)) is:

$$ \tilde{x}_{j+1} = x_{j} + F_{j+1} \tilde{x}_{j} $$
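For reference, here is a minimal numpy sketch of Eq. (1) as an additive coupling step together with its inverse. It is my own illustration, not code from the paper or repository. It assumes stride 1 (no downsampling) and that the block's other output is $\tilde{x}_j$ passed through unchanged, as in the stride-1 path of `irevnet_block.forward`. `F` is a hypothetical placeholder for the residual branch $F_{j+1}$ and is deliberately not invertible.

```python
import numpy as np

def F(x_tilde):
    """Hypothetical stand-in for the residual branch F_{j+1}; need not be invertible."""
    return np.square(x_tilde)

def forward_stride1(x, x_tilde):
    """(x_j, x~_j) -> (x~_j, x_j + F x~_j), i.e. Eq. (1) with no downsampling."""
    return x_tilde, x + F(x_tilde)

def inverse_stride1(x_next, x_tilde_next):
    """x~_j is passed through unchanged, so x_j = x~_{j+1} - F x~_j."""
    x_tilde = x_next
    return x_tilde_next - F(x_tilde), x_tilde

rng = np.random.default_rng(0)
x, x_tilde = rng.standard_normal((2, 3, 8, 8))  # two (C, H, W) halves
x_next, x_tilde_next = forward_stride1(x, x_tilde)
x_rec, x_tilde_rec = inverse_stride1(x_next, x_tilde_next)
print(np.allclose(x_rec, x), np.allclose(x_tilde_rec, x_tilde))  # True True
```

Inversion only ever evaluates `F` forward, which is why the branch itself can be an arbitrary (non-invertible) network.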

However, the code for the case 'stride = 2' leads to a different form:

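```python
class irevnet_block(nn.Module):
    ...
    def forward(self, x):
        """ bijective or injective block forward """
        if self.pad != 0 and self.stride == 1:
            # injective channel padding: merge halves, pad, split again
            x = merge(x[0], x[1])
            x = self.inj_pad.forward(x)
            x1, x2 = split(x)
            x = (x1, x2)
        x1 = x[0]  # x_j
        x2 = x[1]  # x~_j
        Fx2 = self.bottleneck_block(x2)  # F_{j+1} x~_j
        if self.stride == 2:
            # psi (invertible downsampling) is applied to BOTH halves
            x1 = self.psi.forward(x1)
            x2 = self.psi.forward(x2)
        y1 = Fx2 + x1  # with stride 2, x1 here is psi(x_j), not x_j
        return (x2, y1)
```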

which means

$$ \tilde{x}_{j+1} = S_{j+1} x_{j} + F_{j+1} \tilde{x}_{j} $$
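Under this reading the stride-2 block is still exactly invertible, because $S_{j+1}$ only rearranges entries: $\tilde{x}_j = S_{j+1}^{-1} x_{j+1}$ and then $x_j = S_{j+1}^{-1}(\tilde{x}_{j+1} - F_{j+1} \tilde{x}_j)$. The numpy sketch below checks this. It is my own illustration under stated assumptions: `space_to_depth` / `depth_to_space` are hypothetical helpers standing in for `psi` (their channel ordering may not match the repository's), and `F` is a hypothetical, deliberately non-invertible placeholder for a stride-2 `bottleneck_block`.

```python
import numpy as np

def space_to_depth(x, b=2):
    """Invertible downsampling S: (C, H, W) -> (C*b*b, H/b, W/b)."""
    c, h, w = x.shape
    x = x.reshape(c, h // b, b, w // b, b)   # axes (c, i, p, k, q)
    x = x.transpose(0, 2, 4, 1, 3)           # axes (c, p, q, i, k)
    return x.reshape(c * b * b, h // b, w // b)

def depth_to_space(y, b=2):
    """Exact inverse S^{-1} of space_to_depth."""
    cbb, hb, wb = y.shape
    c = cbb // (b * b)
    y = y.reshape(c, b, b, hb, wb)           # axes (c, p, q, i, k)
    y = y.transpose(0, 3, 1, 4, 2)           # axes (c, i, p, k, q)
    return y.reshape(c, hb * b, wb * b)

def F(x_tilde):
    """Hypothetical stride-2 residual branch: output shape matches S x, not invertible."""
    return np.square(space_to_depth(x_tilde))

def forward_stride2(x, x_tilde):
    """(x_j, x~_j) -> (S x~_j, S x_j + F x~_j), mirroring `return (x2, y1)`."""
    return space_to_depth(x_tilde), space_to_depth(x) + F(x_tilde)

def inverse_stride2(x_next, x_tilde_next):
    """x~_j = S^{-1} x_{j+1};  x_j = S^{-1}(x~_{j+1} - F x~_j)."""
    x_tilde = depth_to_space(x_next)
    x = depth_to_space(x_tilde_next - F(x_tilde))
    return x, x_tilde

rng = np.random.default_rng(0)
x, x_tilde = rng.standard_normal((2, 3, 8, 8))  # two (C, H, W) halves
x_next, x_tilde_next = forward_stride2(x, x_tilde)
x_rec, x_tilde_rec = inverse_stride2(x_next, x_tilde_next)
print(x_tilde_next.shape)  # (12, 4, 4)
print(np.allclose(x_rec, x), np.allclose(x_tilde_rec, x_tilde))  # True True
```

Applying $S_{j+1}$ to $x_j$ as well is what makes the shapes of the two summands match, since `F` halves the spatial size and multiplies the channel count by four.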

Do I understand this correctly? I would appreciate an answer whenever you have time.