mapillary/inplace_abn

An alternative reproduction of inplace_abn, but it does not work

blueardour opened this issue · 0 comments

Hi, everyone:

I tried to re-implement inplace-ABN based on my understanding of the concept.

To my surprise, the code below costs more GPU memory rather than reducing the footprint. Maybe my idea is too naive, or I am missing something about PyTorch.

Could anyone point out the mistake I made? (I discussed it with colleagues but did not get any advice.)

import torch

class custom_norm_relu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, running_var, running_mean, eps):
        # fold the normalization (using the running statistics) into a per-channel affine transform
        scale = weight * (running_var + eps).rsqrt()
        if bias is not None:
            bias = bias - running_mean * scale
        else:
            bias = - running_mean * scale
        scale = scale.reshape(1, -1, 1, 1).detach()
        bias = bias.reshape(1, -1, 1, 1).detach()
        weight = weight.reshape(1, -1, 1, 1).detach()
        result = input * scale + bias

        select = result < 0.0
        result.masked_fill_(select, 0.0) # ReLU

        # save the output (instead of the input) plus the small per-channel tensors and the ReLU mask
        ctx.save_for_backward(result, weight, bias, scale, select)
        #ctx.mark_dirty(result)
        return result # input of next conv layer

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, grad_weight, grad_bias = None, None, None
        output, weight, bias, scale, select = ctx.saved_tensors
        grad_output.masked_fill_(select, 0.0) # zero the gradient where the ReLU clipped (note: in-place on grad_output)

        if ctx.needs_input_grad[0]:
            grad_input = grad_output * scale

        if ctx.needs_input_grad[1]:
            # (output - bias) / weight recovers the normalized input where the ReLU
            # did not clip; elsewhere grad_output is already zero
            grad_weight = grad_output * (output - bias).div(weight)
            grad_weight = grad_weight.sum(dim=[0,2,3])

        if ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=[0,2,3])

        return grad_input, grad_weight, grad_bias, None, None, None

class BatchNorm2d_ReLU(torch.nn.Module):
    """
    BatchNorm2d and ReLU in one Module
    """
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(num_features), requires_grad=True)
        self.bias = torch.nn.Parameter(torch.zeros(num_features), requires_grad=True)
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)
        self.fn = custom_norm_relu.apply

    def forward(self, x):
        # compute self.running_mean and self.running_var here, omit temporarily
        return self.fn(x, self.weight, self.bias, self.running_var, self.running_mean, self.eps)
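
For reference, a rough sanity check of the backward pass against eval-mode nn.BatchNorm2d followed by ReLU (just a sketch; the sizes and tolerances are arbitrary):

# Sanity check (sketch): gradients should match eval-mode BatchNorm2d + ReLU
C = 8
fused = BatchNorm2d_ReLU(C)
ref = torch.nn.Sequential(torch.nn.BatchNorm2d(C, eps=fused.eps), torch.nn.ReLU())
ref.eval()  # use the running statistics, matching the custom forward above
with torch.no_grad():
    ref[0].weight.copy_(fused.weight)
    ref[0].bias.copy_(fused.bias)
    ref[0].running_mean.copy_(fused.running_mean)
    ref[0].running_var.copy_(fused.running_var)

x1 = torch.randn(2, C, 4, 4, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)
g = torch.randn(2, C, 4, 4)
fused(x1).backward(g.clone())  # clone because the custom backward writes into grad_output
ref(x2).backward(g.clone())
print(torch.allclose(x1.grad, x2.grad, atol=1e-5),
      torch.allclose(fused.weight.grad, ref[0].weight.grad, atol=1e-4),
      torch.allclose(fused.bias.grad, ref[0].bias.grad, atol=1e-4))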

The select mask does introduce extra memory, but result is saved instead of the input, and the next conv layer already keeps result around as its own saved input, so that part should be free. Since the boolean mask is smaller than the full activation tensor that would otherwise be saved, I expected the overall memory footprint to be reduced.
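
For comparison, a rough way to measure the peak memory of one forward/backward pass against the stock BatchNorm2d + ReLU (a sketch only: the peak_memory helper and all sizes are made up, and it ignores that the custom module does not update the running statistics yet):

import torch.nn as nn

def peak_memory(block, shape=(16, 64, 128, 128)):
    # rough peak-memory measurement for one forward/backward pass
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(*shape, device="cuda", requires_grad=True)
    y = block(x)
    y.backward(torch.ones_like(y))  # explicit, contiguous upstream gradient
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20  # MiB

# conv -> BN+ReLU -> conv, so the BN+ReLU output is reused by the next conv layer
custom_net = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), BatchNorm2d_ReLU(64),
                           nn.Conv2d(64, 64, 3, padding=1)).cuda()
stock_net = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64),
                          nn.ReLU(inplace=True), nn.Conv2d(64, 64, 3, padding=1)).cuda()

print("custom BN+ReLU:", peak_memory(custom_net), "MiB")
print("stock BN+ReLU :", peak_memory(stock_net), "MiB")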