
ReZero implementation?

tbachlechner opened this issue · 1 comment

ReZero is implemented as the residual connection x = x + alpha * F(x). It seems you are currently doing something else instead: x = alpha * x + F(x). How about changing the forward function as follows?

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

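        if self.ReZero:
            # ReZero: x + alpha * F(x), where alpha_i is a learnable scalar
            out = self.alpha_i * out + self.shortcut(x)
        else:
            # Standard residual connection: x + F(x)
            out = out + self.shortcut(x)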
        out = F.relu(out)
        return out
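To make the difference between the two formulas concrete, here is a minimal pure-Python sketch (no PyTorch) of just the residual sum, ignoring the final ReLU. It assumes alpha starts at zero, which is how ReZero initializes it so that each residual block begins as the identity map. The names `rezero_residual`, `scaled_skip`, and `toy_branch` are hypothetical; `toy_branch` stands in for the conv/bn branch F(x).

```python
from typing import Callable, List

Vector = List[float]


def rezero_residual(x: Vector, f: Callable[[Vector], Vector], alpha: float) -> Vector:
    """ReZero residual connection: x + alpha * F(x), elementwise."""
    fx = f(x)
    return [xi + alpha * fi for xi, fi in zip(x, fx)]


def scaled_skip(x: Vector, f: Callable[[Vector], Vector], alpha: float) -> Vector:
    """The variant described in the issue: alpha * x + F(x), elementwise."""
    fx = f(x)
    return [alpha * xi + fi for xi, fi in zip(x, fx)]


def toy_branch(x: Vector) -> Vector:
    # Hypothetical stand-in for the conv -> bn -> relu -> conv -> bn branch
    return [2.0 * xi + 1.0 for xi in x]


x = [1.0, -2.0, 3.0]
alpha = 0.0  # assumed ReZero initialization: alpha starts at zero

print(rezero_residual(x, toy_branch, alpha))  # [1.0, -2.0, 3.0]
print(scaled_skip(x, toy_branch, alpha))      # [3.0, -3.0, 7.0]
```

With alpha = 0, the ReZero form passes the input through unchanged, while the alpha * x + F(x) form discards the skip path entirely and returns only the branch output.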

Absolutely, I have just changed it as per your suggestion. Thank you for pointing it out!