coreylowman/dfdx

Unnecessary loss of precision when computing loss functions

ariasanovsky opened this issue · 2 comments

When computing cross entropy,

pub fn cross_entropy_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    logits: Tensor<S, E, D, T>,
    target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
    let inv_last_axis_numel = 1.0 / <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
    let probs = logits.log_softmax::<S::LastAxis>();
    (probs * target_probs).mean().negate() / inv_last_axis_numel
}

of a tensor along an axis of length $a$, we calculate $1/a$ as an f64 and later divide by it to rescale the mean() of the Hadamard product of probs and target_probs. This results in extra instructions and an unnecessary loss of precision.

This could be replaced by

pub fn cross_entropy_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    logits: Tensor<S, E, D, T>,
    target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
    let last_axis_numel = <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
    let probs = logits.log_softmax::<S::LastAxis>();
    (probs * target_probs).mean().negate() * last_axis_numel
}

This pattern is repeated in kl_div_with_logits_loss.
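
For reference, the analogous change would look something like this (a sketch only - I'm assuming the body of kl_div_with_logits_loss mirrors the cross-entropy version above, with an extra ln(target_probs) term; the actual source may differ):

pub fn kl_div_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    logits: Tensor<S, E, D, T>,
    target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
    let last_axis_numel = <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
    let probs = logits.log_softmax::<S::LastAxis>();
    // KL(target || softmax(logits)) sums target * (ln(target) - log_softmax(logits));
    // the only change is again multiplying by the axis size instead of dividing
    // by its reciprocal.
    ((probs - target_probs.clone().ln()) * target_probs)
        .mean()
        .negate()
        * last_axis_numel
}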

We may also be able to replace the call to mean() with a call to sum(), but that requires some fiddling with the tensor dimensions. Writing $A$ for the index set of the last axis and $I$ for the index set of the remaining axes, using mean() and rescaling by the last axis size calculates a sum of the form $$\text{mean} \cdot \text{last axis numel} = \left(\dfrac{1}{|A||I|}\sum_{(a, i)\in A\times I}p_{a, i}\right)\cdot |A|,$$ whereas we could instead compute $$\text{sum} / |I| = \left(\sum_{(a, i)\in A\times I}p_{a, i}\right)\cdot |I|^{-1}.$$

I couldn't immediately see how to compute $|I|$ from the generic parameter S or its axes using the existing trait bounds.
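
One untested possibility (a sketch; it assumes Shape exposes a num_elements() method returning the total element count, which I haven't checked against the current trait bounds) is to derive $|I|$ as the total numel divided by the last axis numel:

pub fn cross_entropy_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    logits: Tensor<S, E, D, T>,
    target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
    let last_axis_numel = <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
    // |I| = (total numel) / (last axis numel), i.e. the product of the other axis sizes
    let non_last_numel = logits.shape().num_elements() as f64 / last_axis_numel;
    let probs = logits.log_softmax::<S::LastAxis>();
    // sum() over all axes, then divide by |I|, instead of mean() followed by a rescale
    (probs * target_probs).sum().negate() / non_last_numel
}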

This was done specifically for f16 support - notably, the maximum value an f16 can store is 65504, and it has low precision for large values.

In the scalar version of div/mul, the operands are converted to the tensor's dtype before the op executes - so the f64 would be converted to an f16 before actually running. By using 1 / a, we have a better chance that an f16 can store the actual value.
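
To see the overflow concern concretely, here is a small standalone check (a sketch using the half crate directly, not dfdx; the axis length 70_000 is made up):

use half::f16;

fn main() {
    // Hypothetical last-axis length that exceeds the f16 maximum (65504).
    let a: f64 = 70_000.0;

    // Converting the axis length itself to f16 overflows to infinity...
    println!("a as f16    = {}", f16::from_f64(a)); // prints "inf"

    // ...while its reciprocal is still representable (as a subnormal f16).
    let inv = f16::from_f64(1.0 / a);
    println!("1/a as f16  = {}", inv);
    println!("back to f64 = {}", inv.to_f64());
}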