An alternative reproduction of inplace_abn that does not work
blueardour opened this issue · 0 comments
blueardour commented
Hi, everyone:
I tried to re-implement inplace-ABN based on my understanding of the concept.
Contrary to my expectation, the code below costs more GPU memory rather than reducing the footprint. Maybe my idea is too naive, or I lack some knowledge of PyTorch.
Could anyone point out any mistake I made? (I discussed it with colleagues but did not get any advice.)
```python
import torch

class custom_norm_relu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, running_var, running_mean, eps):
        # fold BN into a per-channel affine transform: result = input * scale + bias
        scale = weight * (running_var + eps).rsqrt()
        if bias is not None:
            bias = bias - running_mean * scale
        else:
            bias = - running_mean * scale
        scale = scale.reshape(1, -1, 1, 1).detach()
        bias = bias.reshape(1, -1, 1, 1).detach()
        weight = weight.reshape(1, -1, 1, 1).detach()
        result = input * scale + bias
        select = result < 0.0
        result.masked_fill_(select, 0.0)  # ReLU
        # save the post-activation output instead of the input
        ctx.save_for_backward(result, weight, bias, scale, select)
        #ctx.mark_dirty(result)
        return result  # input of next conv layer

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, grad_weight, grad_bias = None, None, None
        output, weight, bias, scale, select = ctx.saved_tensors
        grad_output.masked_fill_(select, 0.0)
        if ctx.needs_input_grad[0]:
            grad_input = grad_output * scale
        if ctx.needs_input_grad[1]:
            # recover the normalized input from the saved output
            grad_weight = grad_output * (output - bias).div(weight)
            grad_weight = grad_weight.sum(dim=[0, 2, 3])
        if ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=[0, 2, 3])
        return grad_input, grad_weight, grad_bias, None, None, None

class BatchNorm2d_ReLU(torch.nn.Module):
    """
    BatchNorm2d and ReLU in one Module
    """
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(num_features), requires_grad=True)
        self.bias = torch.nn.Parameter(torch.zeros(num_features), requires_grad=True)
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)
        self.fn = custom_norm_relu.apply

    def forward(self, x):
        # compute self.running_mean and self.running_var here, omitted temporarily
        return self.fn(x, self.weight, self.bias, self.running_var, self.running_mean, self.eps)
```
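For what it's worth, this is roughly how one might compare the peak memory against a plain `torch.nn.BatchNorm2d` + `torch.nn.ReLU` pair; the `peak_memory` helper, the toy conv stack, and the tensor sizes below are just placeholders for illustration, not my actual test setup:

```python
import torch

def peak_memory(block, shape=(8, 64, 56, 56)):
    # rough peak-memory probe for one forward+backward pass (CUDA only)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(*shape, device='cuda', requires_grad=True)
    block(x).sum().backward()
    return torch.cuda.max_memory_allocated() / 2**20  # MiB

baseline = torch.nn.Sequential(
    torch.nn.BatchNorm2d(64), torch.nn.ReLU(inplace=True),
    torch.nn.Conv2d(64, 64, 3, padding=1),
).cuda()

fused = torch.nn.Sequential(
    BatchNorm2d_ReLU(64),
    torch.nn.Conv2d(64, 64, 3, padding=1),
).cuda()

print('baseline MiB:', peak_memory(baseline))
print('fused    MiB:', peak_memory(fused))
```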
The `select` variable introduces extra memory, but the `result` variable can be reused by the next conv layer. Since the memory consumption of the former (a boolean mask) is smaller than that of the latter (a float tensor), overall memory should be reduced, from my perspective.
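To make the size argument concrete: a `torch.bool` mask stores one byte per element while a float32 tensor stores four, so `select` should only be about a quarter of the size of `result`. The shape below is again just an example:

```python
import torch

result = torch.randn(8, 64, 56, 56)   # float32: 4 bytes per element
select = result < 0.0                 # bool mask: 1 byte per element

def mbytes(t):
    return t.nelement() * t.element_size() / 2**20

print('result: %.1f MiB' % mbytes(result))  # ~6.1 MiB
print('select: %.1f MiB' % mbytes(select))  # ~1.5 MiB
```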