Regarding attention relaxation
PetrosVav opened this issue · 0 comments
Hello @yizhilll,
First of all, very nice work!
I was looking through the code and have a question about the attention relaxation process. In the implementation of `att_relax` you have the following:
```python
attn_weights_relax = attn_weights / self.attention_relax
attn_max_relax = torch.max(attn_weights_relax, dim=-1, keepdim=False).unsqueeze(2)
attn_weights = (attn_weights_relax - attn_max_relax) * self.attention_relax
attn_weights_float = utils.softmax(
    attn_weights, dim=-1, onnx_trace=self.onnx_trace
)
```
In the above snippet, since you divide by `self.attention_relax` and then multiply by the same constant again, the constant cancels out. Writing `c = self.attention_relax`:

`(attn_weights / c - max(attn_weights / c)) * c = attn_weights - max(attn_weights)`

so the quantity that enters the softmax is just the max-shifted logits, and

`softmax(attn_weights - max(attn_weights)) = softmax(attn_weights)`

where `attn_weights` are the original weights, because softmax is invariant under translation by the same value in each coordinate.
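A quick numerical check illustrates the cancellation (a minimal standalone sketch, not the repository's code; the tensor shapes and the value of `c` are arbitrary placeholders for `self.attention_relax`):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
attn_weights = torch.randn(2, 4, 8)  # e.g. (bsz * num_heads, tgt_len, src_len)
c = 32.0                             # stands in for self.attention_relax

# "Relaxed" path: divide by c, subtract the row-wise max, multiply by c again
relaxed = (attn_weights / c - (attn_weights / c).max(dim=-1, keepdim=True).values) * c

# Plain max-subtraction (standard softmax stabilization)
shifted = attn_weights - attn_weights.max(dim=-1, keepdim=True).values

print(torch.allclose(F.softmax(relaxed, dim=-1), F.softmax(shifted, dim=-1)))
print(torch.allclose(F.softmax(relaxed, dim=-1), F.softmax(attn_weights, dim=-1)))
```

Both checks should print True (up to floating-point rounding), i.e. the relaxed computation reproduces the plain stabilized softmax of the original weights.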
Given that `self.attention_relax` cancels out in the formula and plays no role in numerically stabilizing the softmax computation, shouldn't it be omitted?
Thank you!