Relevance value explosion when applying the method to a Vision Transformer
zmy1116 opened this issue · 0 comments
Hello,
First thank you for this interesting work!
I'm trying to test your algorithm on a Vision Transformer. However, I run into a "relevance explosion" problem: every time relevance is distributed from one block to the previous one, its total scale jumps by more than 10x, and after 12 blocks the values are on the order of 10^10 or more.
I rewrote the LayerNorm and Linear layers in a similar way to your implementation.
For Linear, I implemented the gamma rule with gamma = 0.02:
```python
def alternative_inference(self, input):
    if 'player' not in self.__dict__:
        out_size, in_size = self.weight.shape
        player = Linear(in_size, out_size)
        player.weight = torch.nn.Parameter(self.weight + 0.02 * self.weight.clamp(min=0))
        player.bias = torch.nn.Parameter(self.bias + 0.02 * self.bias.clamp(min=0))
        self.player = player
    z = self(input)
    zp = self.player(input)
    return zp * (z / zp).data
```
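For reference, here is a minimal standalone sketch of what I believe this trick does: the `zp * (z / zp).data` expression returns the original output `z` in value, while routing the gradient through the gamma-modified layer `zp`, so gradient × input yields the gamma-rule relevance. All names here are my own for illustration (I add a `.detach()` when building the standalone gamma layer so the sketch runs outside a module):

```python
import torch
from torch.nn import Linear

torch.manual_seed(0)

layer = Linear(4, 3)
gamma_layer = Linear(4, 3)
# gamma rule with gamma = 0.02: boost positive weights/biases
# (.detach() added so the standalone sketch can wrap them in Parameters)
gamma_layer.weight = torch.nn.Parameter(
    (layer.weight + 0.02 * layer.weight.clamp(min=0)).detach())
gamma_layer.bias = torch.nn.Parameter(
    (layer.bias + 0.02 * layer.bias.clamp(min=0)).detach())

x = torch.rand(2, 4, requires_grad=True)
z = layer(x)
zp = gamma_layer(x)
out = zp * (z / zp).data  # value == z, gradient flows through zp only

out.sum().backward()
relevance = x.grad * x  # LRP relevance via gradient * input
```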
For LayerNorm, I only detached the std:
```python
def alternative_inference(self, input):
    mean = input.mean(dim=-1, keepdim=True)
    std = input.std(dim=-1, keepdim=True)
    std = std.detach()
    input_norm = (input - mean) / (std + self.eps)
    input_norm = input_norm * self.weight + self.bias
    return input_norm
```
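To make this concrete, a standalone sketch of the same idea (function and variable names are mine): with the std detached, LayerNorm is locally linear in its input, so seeding a backward pass with a downstream relevance map and multiplying by the input gives the propagated relevance:

```python
import torch

torch.manual_seed(0)

def layernorm_detached_std(x, weight, bias, eps=1e-5):
    # LayerNorm with std treated as a constant: the op is then locally
    # linear in x, so gradient * input can redistribute relevance
    mean = x.mean(dim=-1, keepdim=True)
    std = x.std(dim=-1, keepdim=True).detach()
    return (x - mean) / (std + eps) * weight + bias

d = 8
x = torch.randn(2, d, requires_grad=True)
out = layernorm_detached_std(x, torch.ones(d), torch.zeros(d))

# seed the backward pass with a (hypothetical) downstream relevance map
R_out = torch.rand(2, d)
out.backward(R_out)
R_in = x.grad * x  # relevance at the LayerNorm input
```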
All relevance backpropagation is done via gradient × input, the same way as in your code.
This problem is also described in Chefer et al. 2021 ("Transformer Interpretability Beyond Attention Visualization"). They handle it explicitly by forcing a normalization at every add layer and by using only the LRP alpha-beta rule with alpha = 1 and beta = 0.
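If it helps to compare, here is my rough sketch of those two ingredients: the z+ rule (alpha = 1, beta = 0, i.e. only positive contributions), which conserves the total relevance through a linear layer, plus a small renormalization helper in the spirit of what Chefer et al. apply at residual adds. The function names and the demo tensors are hypothetical, not from either codebase:

```python
import torch

torch.manual_seed(0)

def lrp_zplus(x, weight, R_out, eps=1e-9):
    # z+ rule (LRP alpha-beta with alpha=1, beta=0): only positive
    # inputs/weights contribute, so relevance stays non-negative and is
    # conserved layer to layer (up to eps)
    xp = x.clamp(min=0)
    wp = weight.clamp(min=0)
    z = xp @ wp.t() + eps   # positive pre-activations
    s = R_out / z           # per-output relevance/activation ratio
    return xp * (s @ wp)    # redistribute to the inputs

def normalize_add(R_a, R_b, incoming_sum):
    # rescale the two streams of a residual add so their total matches
    # the relevance entering the block (hypothetical helper, mirroring
    # the per-add normalization in Chefer et al.)
    scale = incoming_sum / (R_a.sum() + R_b.sum())
    return R_a * scale, R_b * scale

x = torch.rand(2, 4)        # non-negative inputs for a clean check
W = torch.randn(3, 4)
W[:, 0] = W[:, 0].abs() + 0.1  # avoid degenerate all-negative rows in this tiny demo
R_out = torch.rand(2, 3)
R_in = lrp_zplus(x, W, R_out)  # total relevance is preserved
```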
In your paper you report results on DistilBERT from Hugging Face, so your method should be able to run on a full-scale transformer (12 blocks with all the bells and whistles). I'm wondering whether you had to apply other tricks to get it working.
Thanks