microsoft/torchscale

Question about the normalization in attention

Cranial-XIX opened this issue · 2 comments

Dear authors,

Nice work! I have a few questions about the normalization in the RetNet implementation and would like to hear your thoughts on them:

Here,

qk_mat = qk_mat / qk_mat.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1)

why is there a normalization of the attention scores here? Is it mentioned anywhere in the paper? Why do you use abs()?
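For concreteness, here is a toy illustration of what I understand this line to do (assumed shapes, not the actual chunkwise code): each row of the scores is divided by the magnitude of its own sum, clamped below at 1, so a row is only ever shrunk, never amplified.

```python
import torch

# Toy illustration (assumed shapes, not the actual chunkwise code):
# each row is divided by |row sum| clamped at 1, so rows are only shrunk,
# and the normalized row sums are bounded in magnitude.
torch.manual_seed(0)
qk_mat = torch.randn(1, 4, 4) * 1000.0    # large scores that would stress fp16
denom = qk_mat.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1)
qk_norm = qk_mat / denom
print(qk_norm.sum(dim=-1).abs())          # every row sum is now at most ~1 in magnitude
```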

Here,

value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)

for the decay term, why do you normalize the mask here? Isn't the unnormalized mask correct?

Here,

scale = mask.sum(dim=-1, keepdim=True).sqrt()

what is the role of this scale variable, and why is it divided out of both the inner_mask and the query?
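For reference, a toy version of the splitting I am asking about (illustrative names and shapes, not the actual chunkwise implementation): dividing both the query and the decay mask by the same per-row factor divides the scores by its square, so the net effect is normalizing the scores by the mask's row sum, while each intermediate tensor stays in a smaller range.

```python
import torch

# Illustrative only, not the real code: dividing q and mask by `scale` divides
# the retention scores by scale ** 2 (= the mask's row sum), but each factor
# individually stays in a smaller, fp16-friendlier range.
torch.manual_seed(0)
tlen, dim = 4, 8
q = torch.randn(tlen, dim)
k = torch.randn(tlen, dim)
mask = torch.rand(tlen, tlen).tril()              # stand-in for the decay mask
scale = mask.sum(dim=-1, keepdim=True).sqrt()

scores = (q @ k.T) * mask                         # unscaled retention scores
scores_split = ((q / scale) @ k.T) * (mask / scale)

print(torch.allclose(scores / scale**2, scores_split, atol=1e-5))  # True
```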

Thank you very much in advance!

It's a bit complicated. The key idea is to keep the normalized computation identical to the parallel representation. Because of the limited fp16 data range, in the chunkwise forward we have to normalize the data range in the different parts. After these modifications, the output is the same as the totally naive version.

We discuss this problem in the Retention Score Normalization section of our paper.
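For anyone landing here later, a minimal sketch of the scale-invariance argument as I read it from the paper (not the torchscale code itself): since the retention output is passed through a GroupNorm, rescaling the scores by a positive per-row factor leaves the final output unchanged, while keeping intermediates in fp16 range. The names and shapes below (heads, qk_mat, v) are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# Sketch only: per-row rescaling of the retention scores (as in the first
# snippet above) does not change the post-GroupNorm output, because GroupNorm
# is invariant to a positive per-(token, group) scale.
torch.manual_seed(0)
heads, tlen, dim = 2, 4, 8
qk_mat = torch.randn(heads, tlen, tlen) * 100.0   # scores large enough to stress fp16
v = torch.randn(heads, tlen, dim)

def normalized_retention(scores):
    out = scores @ v                               # (heads, tlen, dim)
    out = out.transpose(0, 1).reshape(tlen, heads * dim)
    # eps=0 so the scale-invariance is exact in this toy check
    return F.group_norm(out, num_groups=heads, eps=0.0)

denom = qk_mat.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1)
print(torch.allclose(normalized_retention(qk_mat),
                     normalized_retention(qk_mat / denom), atol=1e-4))  # True
```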

Thanks a lot!