ReZero implementation?
tbachlechner opened this issue · 1 comment
tbachlechner commented
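ReZero is implemented as the residual connection x = x + alpha * F(x), where alpha scales the residual branch F(x) and the skip path is left untouched. The current code seems to do something else instead: x = alpha * x + F(x), which scales the skip path. How about changing the forward function as follows: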
def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    if self.ReZero:
        # ReZero: scale the residual branch, leave the shortcut unscaled
        out = self.alpha_i * out + self.shortcut(x)
    else:
        out = out + self.shortcut(x)
    out = F.relu(out)
    return out
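
To make the difference between the two update rules concrete, here is a minimal sketch in plain Python, with lists standing in for tensors. The names `residual_branch`, `rezero_update`, and `gated_skip_update` are hypothetical and not part of the repository, and the shortcut projection and final ReLU of the block above are left out for brevity. It assumes alpha starts at zero, as ReZero prescribes.

```python
# Illustrative sketch only: plain Python lists stand in for tensors, and
# residual_branch is a hypothetical placeholder for F (the conv/BN layers).

def residual_branch(x):
    """Stand-in for F(x); any nonlinear transform would do here."""
    return [max(0.0, 2.0 * v - 1.0) for v in x]

def rezero_update(x, alpha):
    """ReZero form: x + alpha * F(x). The gate scales the residual branch."""
    return [xi + alpha * fi for xi, fi in zip(x, residual_branch(x))]

def gated_skip_update(x, alpha):
    """Other form: alpha * x + F(x). The gate scales the skip path instead."""
    return [alpha * xi + fi for xi, fi in zip(x, residual_branch(x))]

x = [0.5, 1.0, -2.0]
alpha = 0.0  # assumed initial value of the gate

# With alpha = 0 the ReZero form passes the input through unchanged,
# while the other form discards the input and keeps only F(x).
print(rezero_update(x, alpha))
print(gated_skip_update(x, alpha))
```

With alpha at zero, the first form reduces each block to the identity at initialization, while the second form removes the skip connection entirely until alpha grows.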
fabio-deep commented
Absolutely, I have just changed it as per your suggestion. Thank you for pointing it out!