yizhilll/MERT

Regarding attention relaxation

PetrosVav opened this issue

Hello @yizhilll,

First of all, very nice work!

I was looking through the code and have a question about the attention relaxation process. In the implementation of att_relax, you write the following:

attn_weights_relax = attn_weights / self.attention_relax
attn_max_relax = torch.max(attn_weights_relax, dim=-1, keepdim=False).unsqueeze(2)
attn_weights = (attn_weights_relax - attn_max_relax) * self.attention_relax
attn_weights_float = utils.softmax(
    attn_weights, dim=-1, onnx_trace=self.onnx_trace
)
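The three lines above can be traced on a single row of logits with the minimal pure-Python sketch below. It is not the MERT or fairseq code: `softmax` and `relaxed_softmax` are hypothetical helpers. The sketch assumes `self.attention_relax` is a positive scalar and that the max is taken over the same axis as the softmax (`dim=-1`, as in the snippet). Plain Python is used instead of PyTorch only so that it runs without dependencies.

```python
import math


def softmax(row):
    """Plain softmax over one row, shifted by the row max for stability."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def relaxed_softmax(row, attention_relax):
    """Hypothetical helper mirroring the quoted snippet on a single row.

    Assumes attention_relax > 0 and that the max is taken over the
    same axis as the softmax (dim=-1 in the snippet).
    """
    # attn_weights_relax = attn_weights / self.attention_relax
    relaxed = [x / attention_relax for x in row]
    # attn_max_relax = torch.max(attn_weights_relax, dim=-1, ...)
    max_relaxed = max(relaxed)
    # attn_weights = (attn_weights_relax - attn_max_relax) * self.attention_relax
    shifted = [(x - max_relaxed) * attention_relax for x in relaxed]
    # shifted equals row - max(row) up to rounding; exponentiate and normalise
    exps = [math.exp(x) for x in shifted]
    total = sum(exps)
    return [e / total for e in exps]


rows = [[2.0, -1.5, 0.25, 3.0], [1000.0, 999.0, 0.0]]
for alpha in (0.5, 1.0, 32.0):
    for row in rows:
        diff = max(abs(p - q) for p, q in zip(relaxed_softmax(row, alpha), softmax(row)))
        print(f"attention_relax={alpha}: max |difference| = {diff:.2e}")
```

For every value of attention_relax, the printed differences should be tiny (floating-point rounding only), even for the row with logits around 1000 where a naive exp would overflow. This is the numerical counterpart of the cancellation discussed next.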

In the snippet above, you divide by and then multiply by the same constant, self.attention_relax. Since max(attn_weights / self.attention_relax) = max(attn_weights) / self.attention_relax for a positive constant, the constant cancels out and the expression reduces to:
softmax(attn_weights - max(attn_weights)) = softmax(attn_weights), where attn_weights are the original weights. The final equality holds because softmax is invariant under translation by the same value in every coordinate.
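Written out for a single row $w = (w_1, \dots, w_n)$ of attn_weights, with $a$ standing for self.attention_relax (both symbols are introduced here only for the derivation, and $c$ is an arbitrary scalar), the cancellation is:

```latex
\begin{aligned}
M &= \max_j \frac{w_j}{a} = \frac{1}{a}\max_j w_j \qquad (a > 0), \\
\left(\frac{w_i}{a} - M\right) a &= w_i - \max_j w_j, \\
\operatorname{softmax}(w - c\,\mathbf{1})_i
  &= \frac{e^{w_i - c}}{\sum_k e^{w_k - c}}
   = \frac{e^{-c}\, e^{w_i}}{e^{-c} \sum_k e^{w_k}}
   = \operatorname{softmax}(w)_i .
\end{aligned}
```

More generally, $(w_i / a - M)\,a = w_i - aM$ and $aM$ is a single constant per row for any nonzero $a$, so the last line applies regardless of the sign of $a$.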

Given that self.attention_relax cancels out of the formula and plays no role in numerically stabilizing the softmax, shouldn't it be omitted?

Thank you!