borgwang/tinynn

Step should average the gradients by batch size.

w32zhong opened this issue · 8 comments

It seems to me that the optimizer methods, taking SGD as an example, multiply the learning rate directly by the sum of the gradients over a batch:

def _compute_step(self, grad):
    return - self.lr * grad

It is often suggested that we average the gradient over the batch size; the benefits of doing so are listed in this post. Basically, you do not have to adjust the learning rate when you change the batch size.
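To make the learning-rate argument concrete, here is a minimal NumPy sketch that assumes a toy linear model with squared-error loss. The batch_grad helper and its reduction argument are hypothetical and are not tinynn's API. With a sum, the step grows roughly linearly with the batch size m. With a mean, it stays about the same size.

```python
import numpy as np

# Hypothetical toy setup: linear model y_hat = X @ w, loss 0.5 * (x.w - y)^2.
# The per-sample gradient w.r.t. w is (x.w - y) * x.
def batch_grad(X, y, w, reduction):
    residual = X @ w - y                 # shape (m,)
    g_sum = X.T @ residual               # sum of per-sample gradients
    if reduction == "sum":
        return g_sum
    if reduction == "mean":
        return g_sum / len(y)            # average over the batch size m
    raise ValueError(f"unknown reduction: {reduction!r}")

rng = np.random.default_rng(0)
w_true = np.array([2.0, -1.0])
w = np.zeros(2)
lr = 0.1

norms = {}
for m in (8, 64, 512):
    X = rng.normal(size=(m, 2))
    y = X @ w_true
    step_sum = -lr * batch_grad(X, y, w, "sum")
    step_mean = -lr * batch_grad(X, y, w, "mean")
    norms[m] = (np.linalg.norm(step_sum), np.linalg.norm(step_mean))
    # The sum-based step scales with m; the mean-based step stays roughly constant.
    print(m, norms[m])
```

With the sum, a fixed lr that works for m=8 produces a step about 64 times larger at m=512. With the mean, the same lr keeps working.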

If you agree, I can create a pull request that adds an option to use mean gradients while keeping the option to use the plain sum of gradients (for efficiency).

@t-k- If you look inside losses.py, the returned loss and grad are already averaged over the batch size.

Ah, I see. So the 1/m factor is applied at the very start of backprop, in the loss layer, and carried through from there. Sorry, I missed that part.
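Here is a minimal sketch of why this works. It assumes an MSE-style loss whose grad already divides by m. The loss/grad method names follow the thread's mention of loss.loss and loss.grad, but the class body is an assumption and not tinynn's actual code. Because the 1/m enters at the loss layer, backprop through a linear layer yields the mean of the per-sample weight gradients.

```python
import numpy as np

class MeanSquaredError:
    """Hypothetical MSE loss: both the value and the gradient carry 1/m."""

    def loss(self, predicted, actual):
        m = predicted.shape[0]                     # batch size
        return 0.5 * np.sum((predicted - actual) ** 2) / m

    def grad(self, predicted, actual):
        m = predicted.shape[0]
        return (predicted - actual) / m            # d(loss)/d(predicted)

rng = np.random.default_rng(1)
X = rng.normal(size=(16, 3))     # batch of m=16 inputs
W = rng.normal(size=(3, 1))      # weights of a single linear layer
y = rng.normal(size=(16, 1))

pred = X @ W
loss_fn = MeanSquaredError()
# Backprop through the linear layer: the 1/m from the loss flows into dW.
dW = X.T @ loss_fn.grad(pred, y)

# Reference: average of per-sample gradients, computed one sample at a time.
per_sample = [X[i:i + 1].T @ (pred[i:i + 1] - y[i:i + 1]) for i in range(len(X))]
dW_ref = np.mean(per_sample, axis=0)
print(np.allclose(dW, dW_ref))   # True
```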

Wait, wouldn't it be better to apply that 1/m factor in the optimizer instead (moving it from the loss to the optimizer)? This would make all the loss functions simpler, since it seems every loss function needs to be averaged. What do you think? @borgwang
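A minimal sketch of the proposal, assuming a hypothetical SGDWithMean class (not part of tinynn) whose step method receives a sum-reduced gradient and the batch size. Note that the optimizer now has to be told m, which the existing `_compute_step(self, grad)` signature does not receive. For a single batch, the resulting step matches the current style, where the loss has already divided by m.

```python
import numpy as np

class SGDWithMean:
    """Hypothetical SGD variant that applies the 1/m factor itself."""

    def __init__(self, lr):
        self.lr = lr

    def compute_step(self, grad, batch_size):
        # grad is a sum over the batch; average it here, then scale by lr
        return -self.lr * grad / batch_size

rng = np.random.default_rng(4)
X = rng.normal(size=(32, 3))
W = rng.normal(size=(3, 1))
y = rng.normal(size=(32, 1))
pred = X @ W
lr = 0.05

# Current style: the loss gradient already contains 1/m.
step_current = -lr * (X.T @ ((pred - y) / len(X)))

# Proposed style: the loss gradient is a plain sum; the optimizer divides.
step_proposed = SGDWithMean(lr).compute_step(X.T @ (pred - y), batch_size=len(X))

print(np.allclose(step_current, step_proposed))  # True
```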

Also, by deferring the 1/m factor to the end of the backprop process, you may avoid floating-point precision issues, since you would not be propagating very small numbers from the very start of backprop.

Moreover, this may also save a little computation when people accumulate gradients over several batches before calling apply_grad. (I would not bother to mention this if division were a cheap operation.)
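The accumulation point can be sketched as follows, assuming four equal-size micro-batches and a hypothetical grad_sum helper (not part of tinynn). Deferring the scaling requires one division at the end instead of one per batch. When the batch sizes are equal, both orderings give the same averaged gradient.

```python
import numpy as np

rng = np.random.default_rng(2)
W = rng.normal(size=(3, 1))
batches = [(rng.normal(size=(8, 3)), rng.normal(size=(8, 1))) for _ in range(4)]

def grad_sum(X, y, W):
    # Hypothetical helper: sum of per-sample gradients of 0.5 * ||XW - y||^2 w.r.t. W.
    return X.T @ (X @ W - y)

# Deferred scaling: accumulate raw sums, then divide once by the total sample count.
acc = np.zeros_like(W)
n_seen = 0
for X, y in batches:
    acc += grad_sum(X, y, W)
    n_seen += len(X)
g_deferred = acc / n_seen              # a single division for all batches

# Early scaling (current style): each batch gradient already has its 1/m,
# so the accumulated means still have to be divided by the number of batches.
acc_mean = np.zeros_like(W)
for X, y in batches:
    acc_mean += grad_sum(X, y, W) / len(X)
g_early = acc_mean / len(batches)

print(np.allclose(g_deferred, g_early))   # True for equal-size batches
```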

One disadvantage: doing so would make the loss.loss and loss.grad functions inconsistent. It may confuse people that loss.loss includes the extra 1/m factor but loss.grad does not.

Another disadvantage is that if a user accumulates gradients over batches of different sizes and then invokes apply_grad, my proposal would not handle this correctly.
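The uneven-batch problem can be sketched with two hypothetical batches of 8 and 24 samples (all names here are illustrative, not tinynn's). Dividing the accumulated sum by a single batch size gives the wrong scale, and averaging the per-batch means also gives the wrong result. Only dividing by the total number of samples seen recovers the mean over all 32 samples, so the optimizer would have to track that count.

```python
import numpy as np

rng = np.random.default_rng(3)
W = rng.normal(size=(3, 1))
sizes = [8, 24]                      # hypothetical uneven batch sizes
batches = [(rng.normal(size=(m, 3)), rng.normal(size=(m, 1))) for m in sizes]

def grad_sum(X, y, W):
    # Sum of per-sample gradients of 0.5 * ||XW - y||^2 w.r.t. W.
    return X.T @ (X @ W - y)

# Ground truth: mean gradient over all 32 samples treated as one batch.
X_all = np.vstack([X for X, _ in batches])
y_all = np.vstack([y for _, y in batches])
g_true = grad_sum(X_all, y_all, W) / len(X_all)

acc = sum(grad_sum(X, y, W) for X, y in batches)
g_last_m = acc / sizes[-1]           # dividing by one batch size: wrong scale
g_total = acc / sum(sizes)           # dividing by the total sample count: correct

# Averaging per-batch means weights each batch equally, not each sample.
g_mean_of_means = np.mean([grad_sum(X, y, W) / len(X) for X, y in batches], axis=0)

print(np.allclose(g_total, g_true), np.allclose(g_last_m, g_true))  # True False
```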

I think I will close this issue for now, since the proposed change is not as simple and elegant as the current implementation. Although it has some advantages, they are not that important compared to keeping this project simple.

I see your point. One possible approach is to add a reduction parameter (as TF and PyTorch do). But I would suggest keeping everything simple for now.
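Here is a minimal sketch of such a reduction option, loosely modeled on the `reduction` argument of PyTorch's loss modules. PyTorch also supports "none", which is omitted here. The MSE class below is hypothetical and is not tinynn's losses.py. It applies the same scale factor to both loss and grad, which keeps the two methods consistent and sidesteps the loss/grad mismatch raised earlier in the thread.

```python
import numpy as np

class MSE:
    """Hypothetical loss with a PyTorch-style `reduction` option."""

    def __init__(self, reduction="mean"):
        if reduction not in ("mean", "sum"):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.reduction = reduction

    def _scale(self, m):
        # 1/m for "mean", 1 for "sum"; applied identically to loss and grad
        return 1.0 / m if self.reduction == "mean" else 1.0

    def loss(self, predicted, actual):
        m = predicted.shape[0]
        return 0.5 * np.sum((predicted - actual) ** 2) * self._scale(m)

    def grad(self, predicted, actual):
        m = predicted.shape[0]
        return (predicted - actual) * self._scale(m)

pred = np.array([[1.0], [2.0], [3.0], [4.0]])
target = np.zeros((4, 1))
print(MSE("sum").loss(pred, target))    # 15.0
print(MSE("mean").loss(pred, target))   # 3.75
```

Defaulting to "mean" would preserve the current behavior, while "sum" would let a user defer the 1/m to the optimizer as proposed above.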