Add Gradient scaler
coreylowman opened this issue · 1 comment
coreylowman commented
With the addition of the `AMP<F>` dtype, we also need to add gradient scaling, which is commonly used with AMP training.
I think the frontend interface could look something like:
```rust
let mut scaler = GradientScaler { ... }; // similar fields to PyTorch's GradScaler
// This would do both steps that you have to do manually in PyTorch now:
// 1. scale the loss by the current scale factor
// 2. unscale the gradients before returning them
let grads = scaler.scaled_backward(loss);
```
We may have to add some methods to `Gradients` to support scaling them.
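For a concrete starting point, here is a minimal self-contained sketch of that interface. This is only a sketch: `Gradients` is stubbed as a flat `Vec<f32>`, and the `unscale`/`has_non_finite` helpers and the `backward` closure are hypothetical placeholders, not existing dfdx API. The field names and defaults are borrowed from PyTorch's `torch.cuda.amp.GradScaler`.

```rust
/// Hypothetical stub standing in for dfdx's `Gradients`; a flat Vec<f32>
/// keeps the sketch self-contained.
struct Gradients(Vec<f32>);

impl Gradients {
    /// Hypothetical helper this issue suggests adding: divide every
    /// gradient by the loss scale, in place.
    fn unscale(&mut self, scale: f32) {
        self.0.iter_mut().for_each(|g| *g /= scale);
    }

    /// True if any gradient overflowed to inf/NaN under the scaled loss.
    fn has_non_finite(&self) -> bool {
        self.0.iter().any(|g| !g.is_finite())
    }
}

/// Hypothetical scaler; field names/defaults mirror PyTorch's GradScaler.
struct GradientScaler {
    scale: f32,             // PyTorch init_scale default: 65536.0 (2^16)
    growth_factor: f32,     // PyTorch default: 2.0
    backoff_factor: f32,    // PyTorch default: 0.5
    growth_interval: u32,   // PyTorch default: 2000
    steps_since_growth: u32,
}

impl GradientScaler {
    /// Hypothetical: scale the loss, run backward, unscale the grads.
    /// `backward` stands in for whatever turns a loss into `Gradients`.
    fn scaled_backward(&self, loss: f32, backward: impl Fn(f32) -> Gradients) -> Gradients {
        let mut grads = backward(loss * self.scale);
        grads.unscale(self.scale);
        grads
    }
}
```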
Originally posted by @coreylowman in #424 (comment)
coreylowman commented
Posting the PyTorch gradient-scaling documentation here: https://pytorch.org/docs/stable/amp.html#gradient-scaling
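The linked docs also describe how the scale is adjusted dynamically: if inf/NaN gradients show up, the optimizer step is skipped and the scale is multiplied by `backoff_factor`; after `growth_interval` consecutive finite-gradient steps, it is multiplied by `growth_factor`. A hedged sketch of that update rule, continuing the hypothetical struct above:

```rust
impl GradientScaler {
    /// Hypothetical: the dynamic-scale update described in the linked
    /// PyTorch docs. Call once per step after checking the unscaled
    /// gradients; when `found_non_finite` is true, the optimizer step
    /// should be skipped entirely.
    fn update(&mut self, found_non_finite: bool) {
        if found_non_finite {
            // Overflow: back off to a smaller scale and restart the count.
            self.scale *= self.backoff_factor;
            self.steps_since_growth = 0;
        } else {
            self.steps_since_growth += 1;
            if self.steps_since_growth >= self.growth_interval {
                // A long run of finite grads: try a larger scale.
                self.scale *= self.growth_factor;
                self.steps_since_growth = 0;
            }
        }
    }
}
```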