chr5tphr/zennit

Core: Second Order Gradients

chr5tphr opened this issue · 4 comments

I am currently working on supporting second-order gradients, i.e. gradients of the modified gradients, which are used, for example, to compute adversarial explanations.
The current issue preventing second-order gradients is that the gradient modification introduced by rules is also applied in the second-order backward pass.
This will be prevented by temporarily disabling the modification when computing the second-order gradient, likely using something like a no_modification context for composites/attributors/rules.
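
For reference, the underlying mechanics in plain PyTorch: the first torch.autograd.grad call needs create_graph=True so that the gradient itself is differentiable, and the second call then traverses that same graph, which is exactly where the rule hooks currently fire again. A minimal sketch without any zennit modification:

import torch

x = torch.tensor([2.0], requires_grad=True)
y = (x ** 3).sum()

# first-order gradient; create_graph=True keeps the gradient differentiable
grad_x, = torch.autograd.grad(y, x, create_graph=True)  # 3 * x**2 = 12 at x = 2

# a scalar function of the gradient, e.g. its squared norm
loss = (grad_x ** 2).sum()

# second-order gradient of loss w.r.t. x: 2 * (3x**2) * 6x = 288 at x = 2
second_order, = torch.autograd.grad(loss, x)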

As also pointed out in #125, handles are not stored for the backward hooks on tensors. Storing the handles and removing the hooks before the second-order backward pass would correctly compute the modified-gradient derivatives, although the same graph could then no longer be used to compute the modified gradient for a different gradient output. While adding the context, I am also considering enabling complete removal of the tensor backward hooks.
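
For context, the relevant PyTorch mechanism: tensor.register_hook returns a torch.utils.hooks.RemovableHandle, and keeping that handle is what makes later removal possible. A minimal sketch of the store-and-remove pattern, independent of zennit's internals:

import torch

x = torch.ones(3, requires_grad=True)

# register a backward hook on the tensor and keep its handle
handle = x.register_hook(lambda grad: grad * 2.0)

(x ** 2).sum().backward()
print(x.grad)  # hook active: doubled gradient, tensor([4., 4., 4.])

# removing the handle restores the unmodified gradient
handle.remove()
x.grad = None
(x ** 2).sum().backward()
print(x.grad)  # hook removed: plain gradient, tensor([2., 2., 2.])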

How far has the work progressed?
Actually, this is exactly what I would need for my research. It would be awesome if zennit had this functionality, or at least a workaround, for further research and testing.

Edit: I tried and failed to get the gradient of an attribution map (relevances).

Hey @HeinrichAD

thanks for the bump!

I have pushed a draft version in #159 and will try to get around to finishing it up in the following days.
I have not yet fully tested this, but if you are feeling brave, you can try it out:

$ pip install git+https://github.com/chr5tphr/zennit.git@second-order-gradients

Here's an example of how to compute the gradient of (a function of) the attribution w.r.t. the input:

Code Example
import os

import torch
from torchvision.models import vgg11
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
from PIL import Image

from zennit.composites import EpsilonGammaBox
from zennit.image import imgify


fname = 'dornbusch-lighthouse.jpg'

if not os.path.exists(fname):
    torch.hub.download_url_to_file(
        'https://upload.wikimedia.org/wikipedia/commons/thumb/8/8b/2006_09_06_180_Leuchtturm.jpg/640px-2006_09_06_181_Leuchtturm.jpg',
        fname,
    )

# define the base image transform
transform_img = Compose([
    Resize(256),
    CenterCrop(224),
])
# define the normalization transform
transform_norm = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
# define the full tensor transform
transform = Compose([
    transform_img,
    ToTensor(),
    transform_norm,
])

# load the image
image = Image.open(fname)

# transform the PIL image and insert a batch-dimension
data = transform(image)[None]

# load the pre-trained model and define the composite
model = vgg11(weights='DEFAULT')
composite = EpsilonGammaBox(low=-3., high=3.)

# prepare the input and a one-hot target for class 437
input = data.clone().requires_grad_(True)
target = torch.eye(1000)[[437]]
with composite.context(model) as modified_model:
    out = modified_model(input)
    relevance, = torch.autograd.grad(out, input, target, create_graph=True)
    # create a target heatmap, rolled 12 pixels south east
    target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
    loss = ((relevance - target_heat) ** 2).mean()
    # deactivate the rule hooks in order to leave the second order gradient untouched
    with composite.inactive():
        adv_grad, = torch.autograd.grad(loss, input)

imgify(relevance[0].detach().sum(0), cmap='coldnhot', symmetric=True).show()
imgify(target_heat[0].sum(0), cmap='coldnhot', symmetric=True).show()
imgify(adv_grad[0].sum(0), cmap='coldnhot', symmetric=True).show()

[Three heatmaps: the relevance, the target heatmap rolled 12 pixels south-east, and the adversarial gradient]
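
A quick aside: since imgify returns a PIL image, the heatmaps can also be written to disk instead of shown, e.g. when running without a display:

imgify(relevance[0].detach().sum(0), cmap='coldnhot', symmetric=True).save('relevance.png')
imgify(adv_grad[0].sum(0), cmap='coldnhot', symmetric=True).save('adv_grad.png')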

Thank you for the fast response and for your commitment.
I will test it, maybe not today, but tomorrow. I will let you know if I succeeded (or not).

I added my comments/feedback to the pull request itself.