coreylowman/dfdx

Add Gradient scaler

coreylowman opened this issue

With the addition of the `AMP<F>` dtype, we also need to add gradient scaling. It is commonly used in AMP (automatic mixed precision) training to keep small half-precision gradients from underflowing to zero.
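Scaling is safe because backpropagation is linear in the loss. Let $L$ be the loss, $\theta$ the parameters and $s > 0$ the scale factor. Multiplying the loss by $s$ multiplies every gradient by $s$, which lifts small values out of the half-precision underflow range. Dividing by $s$ afterwards recovers the true gradient:

$$
\nabla_\theta (s \cdot L) = s \cdot \nabla_\theta L
\qquad\Longrightarrow\qquad
\nabla_\theta L = \frac{1}{s}\,\nabla_\theta (s \cdot L)
$$

If $s$ is too large, some scaled gradients overflow to inf/NaN. That step is then skipped and $s$ is reduced, which is why scalers adjust $s$ dynamically.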

I think the frontend interface could look something like:

```rust
let mut scaler = GradientScaler { ... }; // similar fields to PyTorch's GradScaler

// This would do both parts that you currently have to do by hand in PyTorch:
// 1. scale the loss by the current scale factor
// 2. unscale the gradients before returning them
let grads = scaler.scaled_backward(loss);
```
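Below is a minimal sketch of the state such a `GradientScaler` might hold and how its scale could evolve from step to step. It assumes PyTorch-style dynamic scaling and uses the defaults of PyTorch's `GradScaler`: initial scale 2^16, growth factor 2.0, backoff factor 0.5 and growth interval 2000. The field names, `scale_loss` and `update` are hypothetical and are not existing dfdx API. The sketch also works on plain `f32` values rather than tensors, and the caller supplies the `found_inf` flag.

```rust
/// Hypothetical dynamic loss scaler, modeled on PyTorch's GradScaler defaults.
#[derive(Debug, Clone)]
pub struct GradientScaler {
    /// Current multiplier applied to the loss before backward.
    pub scale: f32,
    /// Multiplier applied to `scale` after `growth_interval` finite steps in a row.
    pub growth_factor: f32,
    /// Multiplier applied to `scale` when any gradient was inf/NaN.
    pub backoff_factor: f32,
    /// Number of consecutive finite steps required before the scale grows.
    pub growth_interval: u32,
    /// Consecutive finite steps since the last growth or backoff.
    growth_tracker: u32,
}

impl Default for GradientScaler {
    fn default() -> Self {
        Self {
            scale: 65536.0, // 2^16
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            growth_tracker: 0,
        }
    }
}

impl GradientScaler {
    /// Step 1 of `scaled_backward`: multiply the loss by the current scale.
    pub fn scale_loss(&self, loss: f32) -> f32 {
        loss * self.scale
    }

    /// Adjust the scale after a backward pass.
    /// `found_inf` reports whether any unscaled gradient was inf/NaN.
    /// Returns `true` if the optimizer step should be applied.
    pub fn update(&mut self, found_inf: bool) -> bool {
        if found_inf {
            // Overflow: shrink the scale, reset the streak and skip this step.
            self.scale *= self.backoff_factor;
            self.growth_tracker = 0;
            false
        } else {
            // Finite gradients: count the streak and grow the scale periodically.
            self.growth_tracker += 1;
            if self.growth_tracker == self.growth_interval {
                self.scale *= self.growth_factor;
                self.growth_tracker = 0;
            }
            true
        }
    }
}
```

For example, `GradientScaler::default()` followed by `update(true)` halves the scale to 32768 and returns `false`, so the optimizer step is skipped. Unlike PyTorch, which separates `step` and `update`, this sketch merges the skip decision and the scale adjustment into one method.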

We may have to add some methods to `Gradients` to support scaling them.
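Here is one possible shape for those methods. This sketch uses a hypothetical stand-in for `Gradients` that stores one flat `f32` buffer per tensor id in a `HashMap`. The real dfdx type is generic over dtype and device, so the names `scale_`, `all_finite` and `unscale_` are illustrative assumptions rather than existing API. `unscale_` carries out step 2 of `scaled_backward` and also reports whether an overflow occurred, which a scaler needs in order to back off.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for dfdx's `Gradients`: one flat f32 buffer per tensor id.
#[derive(Debug, Default)]
pub struct Gradients {
    pub buffers: HashMap<u64, Vec<f32>>,
}

impl Gradients {
    /// Multiply every stored gradient element by `factor`, in place.
    pub fn scale_(&mut self, factor: f32) {
        for buf in self.buffers.values_mut() {
            for g in buf.iter_mut() {
                *g *= factor;
            }
        }
    }

    /// `true` if no stored gradient element is inf or NaN.
    pub fn all_finite(&self) -> bool {
        self.buffers
            .values()
            .all(|buf| buf.iter().all(|g| g.is_finite()))
    }

    /// Divide out the loss scale in place.
    /// Returns `false` if any gradient is inf/NaN, meaning the step should be skipped.
    pub fn unscale_(&mut self, loss_scale: f32) -> bool {
        self.scale_(1.0 / loss_scale);
        self.all_finite()
    }
}
```

Returning a flag from `unscale_` means an inf/NaN check can run right after unscaling, with no extra pass over the buffers in a separate call site.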

Originally posted by @coreylowman in #424 (comment)