
Relevance value explosion when applying the method to a Vision Transformer

zmy1116 opened this issue

Hello,

First thank you for this interesting work!

I'm trying to test your algorithm on a Vision Transformer. However, I run into a "relevance explosion" problem:
when relevance is propagated from one block to the previous one, the total relevance scale jumps by more than 10x, and after 12 blocks the values are on the order of 10^10 or more.
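
To quantify this, I log the total absolute relevance after every block, roughly like the sketch below (the `per_block_relevance` list and how I collect it are my own scaffolding, not code from this repo):

    def report_relevance_scale(per_block_relevance):
        # per_block_relevance: list of gradient * input tensors collected after each block
        for i, R in enumerate(per_block_relevance):
            print(f"block {i:2d}: sum |R| = {R.abs().sum().item():.3e}")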

I rewrote the LayerNorm and Linear layers similarly to your implementation:

For Linear, I implemented the gamma rule with gamma = 0.02:

    def alternative_inference(self, input):

        # Build (and cache) a gamma-modified copy of this layer: positive weights
        # and biases are amplified by a factor of (1 + gamma) with gamma = 0.02.
        if 'player' not in self.__dict__:
            out_size, in_size = self.weight.shape
            player = Linear(in_size, out_size)
            player.weight = torch.nn.Parameter(self.weight + 0.02 * self.weight.clamp(min=0))
            player.bias = torch.nn.Parameter(self.bias + 0.02 * self.bias.clamp(min=0))
            self.player = player

        # Detach trick: the forward value equals the unmodified output z, while the
        # gradient flows through the gamma-modified output zp.
        z = self(input)
        zp = self.player(input)
        return zp * (z / zp).data
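
As a sanity check on the detach trick itself (a standalone sketch of my own, not code from your repo), I verified on a single layer that the forward value stays equal to the unmodified output while gradient * input follows the gamma-modified weights:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    gamma = 0.02
    layer = nn.Linear(4, 3)

    # gamma-modified copy: positive weights and biases amplified by (1 + gamma)
    player = nn.Linear(4, 3)
    with torch.no_grad():
        player.weight.copy_(layer.weight + gamma * layer.weight.clamp(min=0))
        player.bias.copy_(layer.bias + gamma * layer.bias.clamp(min=0))

    x = torch.randn(1, 4, requires_grad=True)
    z = layer(x)
    zp = player(x)
    out = zp * (z / zp).data           # value == z, gradient flows through zp

    print(torch.allclose(out, z))      # True: the forward pass is unchanged

    out.backward(torch.ones_like(out))
    R_in = x * x.grad                  # gradient * input at the layer input
    print(R_in.sum().item(), z.sum().item())  # bias terms break exact conservation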

For LayerNorm, I only detached the std:

    def alternative_inference(self, input):
        mean = input.mean(dim=-1, keepdim=True)
        # Detach the std so that, in the backward pass, the normalization acts as a
        # (locally) linear map and gradient * input stays meaningful as relevance.
        std = input.std(dim=-1, keepdim=True)
        std = std.detach()
        input_norm = (input - mean) / (std + self.eps)

        input_norm = input_norm * self.weight + self.bias
        return input_norm
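
For reference, here is how I sanity-check this re-implementation against `nn.LayerNorm` (a standalone sketch of my own; note that `torch.Tensor.std` is unbiased by default and the eps ends up outside the square root, so I only expect approximate agreement):

    import torch
    import torch.nn as nn

    d = 768
    ln = nn.LayerNorm(d)               # default weight = 1, bias = 0
    x = torch.randn(2, 5, d)

    mean = x.mean(dim=-1, keepdim=True)
    std = x.std(dim=-1, keepdim=True).detach()
    y = (x - mean) / (std + ln.eps) * ln.weight + ln.bias

    print((y - ln(x)).abs().max())     # small but nonzero, see the caveat above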

All relevance backpropagation is then done via gradient * input, the same way as in your code.
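
Concretely, my relevance computation has this shape (the `model` wrapper with the detach-based rules and the `target_class` index are placeholders for my own code):

    import torch

    def gradient_x_input_relevance(model, embeddings: torch.Tensor, target_class: int) -> torch.Tensor:
        # model: maps patch/token embeddings to logits, with the detach-based
        # Linear/LayerNorm rules already in place (my own wrapper, hypothetical)
        x = embeddings.detach().requires_grad_(True)
        logit = model(x)[0, target_class]
        logit.backward()
        return (x * x.grad).sum(dim=-1)   # per-token relevance, summed over features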

This problem is also described in Chefer et al. 2021 ("Transformer Interpretability Beyond Attention Visualization"). They handle it explicitly by forcing a normalization at every add (residual) layer and by using only the LRP alpha-beta rule with alpha = 1 and beta = 0.
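
As I understand it, the normalization they apply at a residual add is roughly the following (my simplified paraphrase, not their actual code):

    import torch

    def normalize_add_relevance(Ra: torch.Tensor, Rb: torch.Tensor,
                                R_out: torch.Tensor, eps: float = 1e-9):
        # Rescale the relevance of the two branches of out = a + b so that their
        # total matches the relevance arriving from above, i.e. conservation is
        # enforced explicitly instead of letting the totals drift block by block.
        scale = R_out.sum() / (Ra.sum() + Rb.sum() + eps)
        return Ra * scale, Rb * scale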

In your paper you report results on DistilBERT from Hugging Face, so your method should be able to run on a full-scale transformer (12 blocks with all the bells and whistles). I'm wondering whether you had to apply other tricks to get it working.

Thanks