Some confusion about the case 'stride = 2'
shuizidesu opened this issue · 0 comments
shuizidesu commented
Dear authors:
The forward procedure is given as
$$ \tilde{x}_{j+1} = x_{j} + F_{j+1} \tilde{x}_{j} $$
However, the code for the case 'stride = 2' leads to the following form:
class irevnet_block(nn.Module):
    ...
    def forward(self, x):
        """ bijective or injective block forward """
        if self.pad != 0 and self.stride == 1:
            x = merge(x[0], x[1])
            x = self.inj_pad.forward(x)
            x1, x2 = split(x)
            x = (x1, x2)
        x1 = x[0]
        x2 = x[1]
        Fx2 = self.bottleneck_block(x2)
        if self.stride == 2:
            x1 = self.psi.forward(x1)
            x2 = self.psi.forward(x2)
        y1 = Fx2 + x1
        return (x2, y1)
which means
$$ \tilde{x}_{j+1} = S_{j+1} x_{j} + F_{j+1} \tilde{x}_{j} $$
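As a sanity check on this reading, the stride-2 path can be mimicked with a toy 1-D sketch in plain Python. Everything in it is an assumption for illustration only: `psi` is a hypothetical even/odd split standing in for the invertible downsampling $S_{j+1}$, `F` is an arbitrary stand-in for `bottleneck_block` with stride 2, and `forward_stride2` / `inverse_stride2` are made-up names, not the repository's code. Here `x1` plays the role of $x_{j}$ and `x2` the role of $\tilde{x}_{j}$.

```python
# Toy 1-D analogue of the stride-2 path. A tensor is a list of channels,
# and each channel is a list of numbers. All names are hypothetical.

def psi(x):
    """Invertible downsampling: split each channel into even and odd samples."""
    out = []
    for ch in x:
        out.append(ch[0::2])
        out.append(ch[1::2])
    return out

def psi_inverse(y):
    """Undo psi by interleaving each (even, odd) pair of channels."""
    out = []
    for even, odd in zip(y[0::2], y[1::2]):
        ch = [0] * (len(even) + len(odd))
        ch[0::2] = even
        ch[1::2] = odd
        out.append(ch)
    return out

def F(x):
    """Stand-in for a stride-2 residual function (assumed, not invertible).
    It only has to produce the same shape as psi(x)."""
    out = []
    for ch in x:
        pairs = list(zip(ch[0::2], ch[1::2]))
        out.append([a * b for a, b in pairs])
        out.append([a + b for a, b in pairs])
    return out

def add(a, b):
    return [[u + v for u, v in zip(ca, cb)] for ca, cb in zip(a, b)]

def sub(a, b):
    return [[u - v for u, v in zip(ca, cb)] for ca, cb in zip(a, b)]

def forward_stride2(x1, x2):
    """Same order of operations as the stride == 2 branch quoted above."""
    Fx2 = F(x2)              # F is applied to the un-downsampled x2
    x1 = psi(x1)             # S x_j
    x2 = psi(x2)             # S x~_j
    y1 = add(Fx2, x1)        # S x_j + F x~_j
    return (x2, y1)

def inverse_stride2(out):
    """Recover (x1, x2) from the output pair, assuming psi is invertible."""
    x2_down, y1 = out
    x2 = psi_inverse(x2_down)
    x1 = psi_inverse(sub(y1, F(x2)))
    return (x1, x2)

x1 = [[1, 2, 3, 4]]
x2 = [[5, 6, 7, 8]]
out = forward_stride2(x1, x2)
print(out)
print(inverse_stride2(out) == (x1, x2))
```

If this sketch mirrors the real block, the returned pair is $(S_{j+1} \tilde{x}_{j},\ S_{j+1} x_{j} + F_{j+1} \tilde{x}_{j})$, and the block remains invertible because $S_{j+1}$ can be undone on both halves.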
Do I understand this correctly? I would appreciate it if you could take the time to answer my question.