astramind-ai/Mixture-of-depths

Misaligned implementation with paper

starsholic opened this issue · 1 comment

Hi! According to Eq. 1 in Section 3.4 of the paper,
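$$
x_i^{l+1} =
\begin{cases}
r_i^l \, f_i(\tilde{X}^l) + x_i^l, & \text{if } r_i^l > P_\beta(R^l) \\
x_i^l, & \text{if } r_i^l < P_\beta(R^l)
\end{cases}
$$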

the output should be

`output = processed_tokens + x`

I wonder why the current implementation adds `x` only for the tokens that were not selected:

`output = processed_tokens + (x * (~selected_mask).unsqueeze(-1).to(x.dtype))`
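To make the difference concrete, here is a minimal pure-Python sketch (not the repository's code) that compares Eq. 1 with both output lines. It assumes, as the question implies, that `processed_tokens` holds r_i · f_i at the selected positions and zeros elsewhere, i.e. the wrapped block adds no residual of its own. Routing is modeled as a top-k over router weights. `top_k_mask`, `eq1_reference`, `build_processed_tokens`, `current_output`, `proposed_output` and the toy block are hypothetical names introduced only for illustration.

```python
from typing import Callable, List, Tuple

Vector = List[float]  # hidden state of one token
Block = Callable[[List[Vector]], List[Vector]]  # f: outputs for the selected tokens only


def top_k_mask(weights: List[float], capacity: int) -> List[bool]:
    # True for the `capacity` tokens with the largest router weights
    # (the paper's r_i > P_beta(R) test, expressed as a top-k).
    ranked = sorted(range(len(weights)), key=lambda i: weights[i], reverse=True)
    keep = set(ranked[:capacity])
    return [i in keep for i in range(len(weights))]


def eq1_reference(x: List[Vector], w: List[float], f: Block, capacity: int) -> List[Vector]:
    # Eq. 1 directly: selected tokens -> r_i * f_i + x_i, others -> x_i.
    mask = top_k_mask(w, capacity)
    idx = [i for i, m in enumerate(mask) if m]
    f_out = f([x[i] for i in idx])
    out = [list(row) for row in x]
    for j, i in enumerate(idx):
        out[i] = [w[i] * fv + xv for fv, xv in zip(f_out[j], x[i])]
    return out


def build_processed_tokens(
    x: List[Vector], w: List[float], f: Block, capacity: int
) -> Tuple[List[Vector], List[bool]]:
    # ASSUMPTION: processed_tokens = r_i * f_i at selected positions, zeros
    # elsewhere, and f itself does not add a residual.
    mask = top_k_mask(w, capacity)
    idx = [i for i, m in enumerate(mask) if m]
    f_out = f([x[i] for i in idx])
    processed = [[0.0] * len(row) for row in x]
    for j, i in enumerate(idx):
        processed[i] = [w[i] * fv for fv in f_out[j]]
    return processed, mask


def current_output(processed: List[Vector], x: List[Vector], mask: List[bool]) -> List[Vector]:
    # Mirrors `processed_tokens + (x * (~selected_mask).unsqueeze(-1))`:
    # the residual x survives only where the token was NOT selected.
    return [
        [p + (0.0 if m else xv) for p, xv in zip(prow, xrow)]
        for prow, xrow, m in zip(processed, x, mask)
    ]


def proposed_output(processed: List[Vector], x: List[Vector]) -> List[Vector]:
    # Mirrors `processed_tokens + x`: the residual is added to every token.
    return [[p + xv for p, xv in zip(prow, xrow)] for prow, xrow in zip(processed, x)]


def toy_block(tokens: List[Vector]) -> List[Vector]:
    # Hypothetical stand-in for the transformer block: scales each token by 10.
    return [[10.0 * v for v in tok] for tok in tokens]


if __name__ == "__main__":
    x = [[1.0], [2.0], [3.0]]
    w = [0.25, 0.75, 0.5]
    processed, mask = build_processed_tokens(x, w, toy_block, capacity=2)
    print(eq1_reference(x, w, toy_block, capacity=2))  # [[1.0], [17.0], [18.0]]
    print(current_output(processed, x, mask))  # [[1.0], [15.0], [15.0]]
    print(proposed_output(processed, x))  # [[1.0], [17.0], [18.0]]
```

Under that assumption, `processed_tokens + x` reproduces Eq. 1 for every token, while the masked version drops the residual x_i on the selected tokens. If the wrapped block already adds its own residual internally, the comparison changes, so the right fix depends on what the block returns.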

Hello,
Thank you so much for pointing this out. I agree with the issue and with the solution you proposed. Have you applied this fix? I think the implementation still has this issue.