InPlaceABN backward is totally different from nn.BatchNorm2d
Closed this issue · 2 comments
I ran the following test: given the same input, the outputs of nn.BatchNorm2d and InPlaceABN are the same, but the gradients from backward are totally different. I want to know whether there is a bug here or whether the way I test is wrong.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from inplace_abn import InPlaceABN

torch.manual_seed(1234)

t = torch.randn(2, 5, 40, 6).cuda()
t1 = t.detach().clone().requires_grad_(True)
t2 = t.detach().clone().requires_grad_(True)

bn1 = nn.BatchNorm2d(5, eps=1e-9).cuda()
bn2 = InPlaceABN(5, eps=1e-9).cuda()  # default activation: leaky_relu(0.01)

t11 = t1 * 2
out1 = F.leaky_relu(bn1(t11), 0.01)
out1.sum().backward()
grad1 = t1.grad

t22 = t2 * 2
out2 = bn2(t22)
out22 = out2.clone()  # InPlaceABN overwrites its output buffer, so keep a copy
out2.sum().backward()
grad2 = t2.grad

print(grad1)
print(grad2)
print((out1 - out22).abs().sum())
print((grad1 - grad2).abs().sum())
```
@xiaodao2049 thank you for your bug report. This helped us discover an issue with the way we compute the backward pass, which has a nasty interaction with some other PyTorch operations. While in many cases the backward is computed correctly, when InPlaceABN is directly followed by `sum()` the backward computation breaks. You can easily verify this by e.g. replacing `out{1,2}.sum().backward()` with `(out{1,2} * 2).sum().backward()` in your code, and you will see that the results are then identical, as expected. Luckily, the common case where InPlaceABN is followed by other layers such as `Conv2d` or `Linear` was already working correctly.

We issued a work-around fix for this issue in v1.0.9, and are working on a better fix for a future release.
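The shape such a work-around can take is illustrated below on a toy autograd `Function`. This is a hedged sketch of the general pattern (copying the incoming gradient to a contiguous private buffer before updating it in place), not the actual v1.0.9 change; `GuardedScale` is a made-up stand-in, not part of inplace_abn:

```python
import torch

class GuardedScale(torch.autograd.Function):
    """Toy stand-in for a layer whose backward updates grad_output in place."""

    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient arriving straight from sum() is an expanded view,
        # not a freshly allocated buffer; .contiguous() copies it into
        # writable memory of its own.
        g = grad_output.contiguous()
        g.mul_(2.0)  # the in-place update is now harmless
        return g

x = torch.randn(4, requires_grad=True)
GuardedScale.apply(x).sum().backward()
print(x.grad)  # d/dx sum(2x) = 2 everywhere
```

Note that `.contiguous()` is a no-op on an already-contiguous tensor, so a stricter guard would use `.clone()`; the cheaper call is enough for the `sum()` case.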
> @xiaodao2049 thank you for your bug report. This helped us discover an issue with the way we compute the backward pass, which has a nasty interaction with some other PyTorch operations. While in many cases the backward is computed correctly, when InPlaceABN is directly followed by `sum()` the backward computation breaks.
Can you please explain what happens in the backward computation when InPlaceABN performs its in-place operation on the output and is directly followed by `sum()`?