Unnecessary loss of precision when computing loss functions
ariasanovsky opened this issue · 2 comments
When computing cross entropy,
pub fn cross_entropy_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    logits: Tensor<S, E, D, T>,
    target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
    let inv_last_axis_numel = 1.0 / <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
    let probs = logits.log_softmax::<S::LastAxis>();
    (probs * target_probs).mean().negate() / inv_last_axis_numel
}
we compute inv_last_axis_numel, the reciprocal of the tensor's last-axis length, as an f64, and later divide by it to renormalize the mean() of the Hadamard product of probs and target_probs. This takes more instructions and causes an unnecessary loss of precision.
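As a scalar illustration of the extra rounding step (a standalone f64 sketch, not dfdx code; it only compares the two formulations of the rescaling):

fn main() {
    // For a fixed mean value, count the axis lengths n for which dividing by a
    // precomputed reciprocal disagrees with multiplying by n directly.
    let mean = 0.125f64;
    let mut mismatches = 0;
    for n in 1..=10_000u32 {
        let n = n as f64;
        let via_reciprocal = mean / (1.0 / n); // current code: round 1/n, then round the division
        let via_multiply = mean * n; // proposed code: a single rounding
        if via_reciprocal != via_multiply {
            mismatches += 1;
        }
    }
    println!("{mismatches} of 10000 axis lengths give a different result");
}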
This could be replaced by
pub fn cross_entropy_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    logits: Tensor<S, E, D, T>,
    target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
    let last_axis_numel = <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
    let probs = logits.log_softmax::<S::LastAxis>();
    (probs * target_probs).mean().negate() * last_axis_numel
}
This pattern is repeated in kl_div_with_logits_loss.
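For reference, the analogous change would look roughly like this (a sketch only: the exact body of kl_div_with_logits_loss may differ, and the KL form shown is an assumption; the relevant edit is simply multiplying by last_axis_numel instead of dividing by its reciprocal):

pub fn kl_div_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    logits: Tensor<S, E, D, T>,
    target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
    let last_axis_numel = <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
    let probs = logits.log_softmax::<S::LastAxis>();
    // Assumed KL form: sum of target_probs * (ln(target_probs) - log_softmax(logits))
    ((probs - target_probs.clone().ln()) * target_probs)
        .mean()
        .negate()
        * last_axis_numel
}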
We may also be able to replace the call to mean() with a call to sum(), but it requires fiddling with the tensor dimensions. Using mean() and rescaling by the last axis size calculates a sum of the form

    (last_axis_numel / numel(S)) * sum_j (probs * target_probs)_j = (1 / batch_numel) * sum_j (probs * target_probs)_j

where batch_numel = numel(S) / last_axis_numel is the product of the sizes of every axis except the last. I couldn't immediately see how to compute that product for S or its axis using the existing trait bounds.
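One way the sum() variant could look, assuming the total element count of S is available (num_elements() below is an assumption about the Shape API, not something confirmed in this issue):

pub fn cross_entropy_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    logits: Tensor<S, E, D, T>,
    target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
    let last_axis_numel = <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
    // batch_numel = product of every axis size except the last; num_elements() is assumed.
    let batch_numel = logits.shape().num_elements() as f64 / last_axis_numel;
    let probs = logits.log_softmax::<S::LastAxis>();
    // Sum over all elements, then divide once by the true count of "rows".
    (probs * target_probs).sum().negate() / batch_numel
}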
This was done specifically for f16 support - notably, the max value an f16 can store is 65504, and it has low precision for large values.
In the scalar version of div/mul the operands will be converted to the dtype before executing - so the f64 would be converted to an f16 before actually running. By using 1 / a, we have a better chance that f16 can store the actual value.
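To make that concrete (a small sketch using the half crate's f16, which is an assumption about the f16 type involved; the values just illustrate the range/precision trade-off):

use half::f16;

fn main() {
    // A last-axis length larger than f16::MAX (65504), e.g. a big vocabulary:
    let n = 100_000.0f64;
    println!("n as f16     = {}", f16::from_f64(n)); // overflows to +inf
    println!("1/n as f16   = {}", f16::from_f64(1.0 / n)); // stays finite (a subnormal f16)
    // Even below the overflow threshold, spacing between f16 values is coarse:
    println!("50001 as f16 = {}", f16::from_f64(50_001.0)); // rounds to a nearby multiple of 32
}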