Question about the noisy top-k gating
Hi! Thanks for your implementation of MoE! I am confused about the derivatives of w_gate and w_noise. It seems that the computation of the logits at
https://github.com/davidmrau/mixture-of-experts/blob/master/moe.py#L239
is not differentiable because of the top-k operation, so w_gate and w_noise cannot be updated from the NLL loss. I am not sure whether this is an appropriate way to train the MoE.
Please have a look at the original paper. The idea is to approximate the hard top-k gating with noisy smoothing during training with the function noisy_top_k_gating. Hope this helps.
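For reference, the gating function in the original paper (Shazeer et al., 2017, Section 2.1) can be written roughly as below. This is transcribed from memory in the paper's notation, so please check it against the paper itself; W_g and W_noise correspond to w_gate and w_noise in this repository.

```latex
\begin{aligned}
G(x) &= \operatorname{Softmax}\bigl(\operatorname{KeepTopK}(H(x), k)\bigr) \\
H(x)_i &= (x \cdot W_g)_i + \operatorname{StandardNormal}() \cdot \operatorname{Softplus}\bigl((x \cdot W_{noise})_i\bigr) \\
\operatorname{KeepTopK}(v, k)_i &=
\begin{cases}
v_i & \text{if } v_i \text{ is among the top } k \text{ elements of } v \\
-\infty & \text{otherwise}
\end{cases}
\end{aligned}
```

Experts outside the top k get a gate of exactly zero, while the k kept gates are an ordinary softmax over their noisy logits.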
Hi @davidmrau
Could you please elaborate on what you meant by your comment? I was not able to find in the paper how the gradients are backpropagated.
Thanks in advance
Sorry if I was not clear enough. The noise is added to the gating for load balancing. The top-k gating is indeed not differentiable, as stated in the paper: "While this form of sparsity creates some theoretically scary discontinuities in the output of gating function, we have not yet observed this to be a problem in practice. The noise term helps with load balancing, as will be discussed in Appendix A."
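To make the gradient question concrete: the choice of which experts are kept is piecewise constant, but the gate values of the kept experts are a softmax over their noisy logits, so they still depend smoothly on w_gate and w_noise. Below is a minimal pure-Python sketch that checks this numerically with central differences. It is not the code from moe.py; the helper names, the toy weights, the fixed noise sample eps, and the scalar "expert outputs" are all assumptions made up for illustration.

```python
import math

def softplus(z):
    return math.log1p(math.exp(z))

def noisy_top_k_gates(x, w_gate, w_noise, eps, k):
    """Toy noisy top-k gating for one input vector.

    x: list of d floats; w_gate, w_noise: d x n nested lists;
    eps: n noise samples, held fixed so the function is deterministic.
    """
    d, n = len(x), len(eps)
    clean = [sum(x[j] * w_gate[j][i] for j in range(d)) for i in range(n)]
    raw_std = [sum(x[j] * w_noise[j][i] for j in range(d)) for i in range(n)]
    h = [clean[i] + eps[i] * softplus(raw_std[i]) for i in range(n)]
    # Keep the k largest noisy logits and take a softmax over only those.
    top = sorted(range(n), key=lambda i: h[i], reverse=True)[:k]
    m = max(h[i] for i in top)
    exps = {i: math.exp(h[i] - m) for i in top}
    total = sum(exps.values())
    return [exps[i] / total if i in exps else 0.0 for i in range(n)]

def central_difference(f, matrix, row, col, step=1e-6):
    """Numerical derivative of f() with respect to matrix[row][col]."""
    original = matrix[row][col]
    matrix[row][col] = original + step
    up = f()
    matrix[row][col] = original - step
    down = f()
    matrix[row][col] = original
    return (up - down) / (2 * step)

# Hypothetical toy setup: 2 input features, 4 experts, k = 2.
x = [1.0, 2.0]
w_gate = [[0.5, -0.3, 0.1, 0.0], [0.2, 0.4, -0.6, 0.1]]
w_noise = [[0.0] * 4, [0.0] * 4]
eps = [0.1, -0.2, 0.3, 0.0]          # fixed noise sample
expert_out = [1.0, 2.0, 3.0, 4.0]    # stand-in scalar expert outputs

def mixture_output():
    gates = noisy_top_k_gates(x, w_gate, w_noise, eps, k=2)
    return sum(g * y for g, y in zip(gates, expert_out))

gates = noisy_top_k_gates(x, w_gate, w_noise, eps, k=2)
grad_gate_selected = central_difference(mixture_output, w_gate, 0, 0)
grad_gate_dropped = central_difference(mixture_output, w_gate, 0, 2)
grad_noise_selected = central_difference(mixture_output, w_noise, 0, 0)

print("gates:", gates)
print("d out / d w_gate[0][0]  (selected expert):", grad_gate_selected)
print("d out / d w_gate[0][2]  (dropped expert): ", grad_gate_dropped)
print("d out / d w_noise[0][0] (selected expert):", grad_noise_selected)
```

In this toy setup experts 0 and 1 are selected. The derivatives with respect to their columns of w_gate and w_noise are nonzero, while the column of a dropped expert gets a zero gradient for this input. An autograd framework gives the same picture, since a top-k followed by a softmax over the selected values passes gradients to the selected entries only. Note that with k = 1 the single kept gate would be a constant 1, so k > 1 is needed for the gates to carry any gradient this way.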