Smooth MaxPool2D rule
rachtibat opened this issue · 2 comments
Hey,
we'd like to add a new rule that smooths the MaxPool2D operation by replacing its backward pass with that of AveragePool2D:
import torch
import torch.nn.functional as F

from zennit.core import BasicHook, Stabilizer


class SmoothMaxPool2dRule(BasicHook):
    def __init__(self, epsilon=1e-6, zero_params=None):
        stabilizer_fn = Stabilizer.ensure(epsilon)
        super().__init__(
            gradient_mapper=(lambda out_grad, outputs: out_grad / stabilizer_fn(outputs[0])),
            reducer=(lambda inputs, gradients: inputs[0] * gradients[0]),
        )

    def backward(self, module, grad_input, grad_output):
        '''Backward hook to compute LRP based on the class attributes.'''
        original_input = self.stored_tensors['input'][0].clone()
        inputs, outputs = [], []
        kernel_size = module.kernel_size
        stride = module.stride
        padding = module.padding
        input = original_input.requires_grad_()
        # replace the MaxPool2d forward by an AvgPool2d with identical parameters
        with torch.autograd.enable_grad():
            output = F.avg_pool2d(
                input, kernel_size, stride, padding,
                ceil_mode=False, count_include_pad=True, divisor_override=None,
            )
            inputs.append(input)
            outputs.append(output)
        grad_outputs = self.gradient_mapper(grad_output[0], outputs)
        gradients = torch.autograd.grad(
            outputs,
            inputs,
            grad_outputs=grad_outputs,
            create_graph=grad_output[0].requires_grad,
        )
        relevance = self.reducer(inputs, gradients)
        return tuple(
            relevance if original.shape == relevance.shape else None
            for original in grad_input
        )
You can test the code with:
import torch
import torch.nn as nn

from zennit.rules import Norm

# assumes SmoothMaxPool2dRule from above is defined in the same file

if __name__ == "__main__":
    input = torch.linspace(0, 35, 36).view(1, 1, 6, 6).requires_grad_()
    layer = nn.MaxPool2d(2, 2, 0)

    # reference: the Norm rule passes the relevance to the maximum pixels only
    norm_rule = Norm()
    h = norm_rule.register(layer)
    output = layer(input)
    grad, = torch.autograd.grad(output, input, torch.ones_like(output))
    h.remove()
    print(input)
    print(output)
    print(grad)
    print("###")

    # the smoothed rule spreads the relevance over each pooling window
    rule = SmoothMaxPool2dRule()
    h = rule.register(layer)
    output = layer(input)
    grad, = torch.autograd.grad(output, input, torch.ones_like(output))
    h.remove()
    print(input)
    print(output)
    print(grad)
Do you think that's fine? I can create a pull request if you want.
Best,
Reduan
Hey Reduan,
thanks for the issue as always!
I think having a way to use the AvgPool2d gradient for MaxPool2d layers is a must-have.
I have some proof-of-concept code, which I implemented back in the day, that directly and explicitly computes the avg-pool gradient with MaxPool parameters using transposed convolutions.
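For intuition, here is a minimal sketch of that idea, not the actual proof-of-concept code: the backward pass of avg_pool2d is a transposed convolution with a constant kernel of value 1 / kernel_size**2. The helper name avg_pool2d_grad is hypothetical, and it assumes integer kernel_size/stride and no padding:

import torch
import torch.nn.functional as F

def avg_pool2d_grad(out_grad, kernel_size, stride):
    '''Gradient of avg_pool2d w.r.t. its input as an explicit transposed
    convolution; assumes int kernel_size/stride and padding=0. Note that the
    output size is (H' - 1) * stride + kernel_size, which matches the original
    input only when the pooling windows tile the input exactly.'''
    channels = out_grad.shape[1]
    # one uniform kernel per channel, applied depthwise via groups
    weight = torch.full(
        (channels, 1, kernel_size, kernel_size),
        1. / kernel_size ** 2,
        dtype=out_grad.dtype,
        device=out_grad.device,
    )
    return F.conv_transpose2d(out_grad, weight, stride=stride, groups=channels)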
While going over your code and seeing the BasicHook.backward structure copied, I had the idea that we could also add a layer of abstraction above ParamMod: a ModuleMod or FuncMod, which would be a general modifier of the forward function. This way, one could add very flexible custom rules based on BasicHook, not limited to the parameters of the module, which would be especially useful for parameter-less modules like MaxPool.
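To make the idea concrete, here is a purely hypothetical sketch of what a FuncMod could look like, since the interface does not exist yet: a context manager that temporarily swaps a module's forward function, analogous to how ParamMod temporarily modifies parameters.

import torch.nn.functional as F

class FuncMod:
    '''Hypothetical sketch: temporarily replace a module's forward function.'''
    def __init__(self, module, func):
        self.module = module
        self.func = func

    def __enter__(self):
        self.original_forward = self.module.forward
        self.module.forward = self.func
        return self.module

    def __exit__(self, *args):
        self.module.forward = self.original_forward

# e.g. inside a rule's backward, treat MaxPool2d as AvgPool2d:
# with FuncMod(module, lambda x: F.avg_pool2d(x, module.kernel_size, module.stride, module.padding)):
#     output = module(input)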
I have a different approach for attributing MaxPool in the pipeline, which could benefit from this. Do you maybe know of another use-case for arbitrary function overrides? Or maybe @sebastian-lapuschkin does?
If it is only for MaxPool, implementing an explicit rule based on Hook may be better, where we could instead use my existing proof-of-concept code. Although, with Hook, the stabilizer would not automatically be part of the rule, which I guess is why you based this off BasicHook rather than Hook; I think that may not be necessary for pooling anyway.
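A rough sketch of how such an explicit Hook-based rule could look, assuming the same backward signature as in the BasicHook snippet above, and reusing the hypothetical avg_pool2d_grad helper sketched earlier:

from zennit.core import Hook

class MaxPoolSmoothGrad(Hook):
    '''Sketch: pass the incoming relevance through the avg-pool gradient.'''
    def backward(self, module, grad_input, grad_output):
        # assumes int kernel_size/stride and padding=0, as in avg_pool2d_grad
        return (avg_pool2d_grad(grad_output[0], module.kernel_size, module.stride),)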
As for the name, maybe it's better to call it something like AvgPoolRule, since it would also be correct for AvgPool, although one could just use the EpsilonRule there.
Hey,
thank you for your prompt and thoughtful response as always.
I like the idea of adding a FuncMod.
I asked Sebastian, and he told me that another use-case would be changing the 1x1 CNN downsample layers with stride=2 in ResNets, which also create such a checkerboard pattern.
The question is whether we should implement it with a FuncMod.
A spontaneous idea that would change the backward pass function instead (a sketch follows the list):
- Compute relevance normally. As a result, only every second or fourth pixel gets relevance; the others are zero.
- Take only those relevance pixels and write them into a smaller image of 1/4 the size.
- Average-upsample the image back to the original size.
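A minimal sketch of those three steps, assuming kernel_size == stride == 2 with no padding; gather_and_upsample is a hypothetical helper that uses the max-pool indices to collect the non-zero relevance pixels:

import torch
import torch.nn.functional as F

def gather_and_upsample(input, relevance, kernel_size=2, stride=2):
    # recover the max locations of the forward pass
    _, indices = F.max_pool2d(input, kernel_size, stride, return_indices=True)
    # steps 1+2: relevance is non-zero only at the max locations, so gathering
    # at those indices packs it into an image of 1/4 the size
    small = relevance.flatten(2).gather(2, indices.flatten(2)).view_as(indices)
    # step 3: average-upsample back to the original size, conserving the sum
    return F.interpolate(small, scale_factor=stride, mode='nearest') / kernel_size ** 2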
With a FuncMod we could do the following instead (sketched below):
- Take the pixels that would be selected by the downsample layer and repeat them 4 times to the original input size, overwriting the ignored pixels.
- Do a 2x2 downsample with a 4-times bigger kernel but 1/4 of the kernel values.
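For the ResNet downsample case, the second step could roughly look like this; a sketch assuming a 1x1, stride-2 nn.Conv2d and even input sizes, where smooth_downsample is a hypothetical name:

import torch.nn.functional as F

def smooth_downsample(module, input):
    # tile the 1x1 weight to 2x2 and scale by 1/4, so every input pixel
    # contributes instead of only every second one
    weight = module.weight.repeat(1, 1, 2, 2) / 4.
    return F.conv2d(input, weight, module.bias, stride=2)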
Best,