NVIDIA/TransformerEngine

[PyTorch] Swiglu implementation occasionally not aligned with jiterator version

tylaar opened this issue · 4 comments

Hello there, sorry to bother again. During my investigation of issue #709, I found that the swiglu implementation diverges from a jiterator version with some probability once the hidden size becomes larger. Here is a unit test to reproduce it:

import torch
import transformer_engine 
from transformer_engine.pytorch import cpp_extensions as tex
from transformer_engine.pytorch.constants import TE_DType

# Adjust for different input
d1 = 256
d2 = 512

swiglu_fwd = """
template <typename T> T swiglu_fwd(T x, T y) {
    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""
swiglu_bwd = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
    dy = float(x) * x_sigmoid * float(g);
}
"""
swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd)
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd, num_outputs=2)

class MySwiglu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputmat):
        x1, x2 = torch.chunk(inputmat, 2, dim=-1)
        ctx.save_for_backward(x1, x2)
        return swiglu_fwd(x1, x2)

    @staticmethod
    def backward(ctx, dout):
        x1, x2 = ctx.saved_tensors
        dx, dy = swiglu_bwd(x1, x2, dout)
        # forward took a single tensor, so return a single gradient (the two halves concatenated back)
        return torch.cat((dx, dy), dim=-1)

myswiglu = MySwiglu.apply
input_mat1 = torch.rand(d2, d1).bfloat16().to('cuda')
input_mat2 = input_mat1.clone().detach().bfloat16().to('cuda')

r1 = myswiglu(input_mat1)
r2 = tex.swiglu(input_mat2, None, tex.FP8FwdTensors.GEMM2_INPUT, otype=TE_DType[torch.bfloat16])
print(r1)
print(r2)
print((r1-r2).nonzero())

So the thing is: I implemented a MySwiglu class whose forward and backward are compiled by torch.cuda.jiterator, and compared it against the tex version of swiglu. The interesting thing is that if you set d1 = 128, the final diff line shows almost no mismatches at all, while with d1 = 256 there is some probability that a few output elements mismatch, and with a larger d1, say 1024 or 2048, mismatching elements appear more often ...
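Here is a rough sketch (illustration only, reusing the definitions from the script above; the 100 seeds and the list of sizes are arbitrary choices of mine) of how that mismatch frequency can be estimated for different d1:

# Count, for several d1 sizes, how many random inputs produce at least one element where
# the jiterator result and the tex result differ.
for trial_d1 in (128, 256, 1024, 2048):
    mismatched_runs = 0
    for seed in range(100):
        torch.manual_seed(seed)
        inp = torch.rand(d2, trial_d1).bfloat16().to('cuda')
        r_jit = myswiglu(inp)
        r_tex = tex.swiglu(inp.clone(), None, tex.FP8FwdTensors.GEMM2_INPUT,
                           otype=TE_DType[torch.bfloat16])
        mismatched_runs += int((r_jit - r_tex).nonzero().numel() > 0)
    print(f"d1={trial_d1}: {mismatched_runs}/100 runs had at least one differing element")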

I am not sure whether this is due to a bug in the jiterator swiglu forward implementation, or whether something interesting is going on inside the tex version of swiglu ...

Some more setup information: torch 2.1.0+cu122, TE 1.3.0, Python 3.9.

Thanks @tylaar for reporting this. Let me take a look.

I did a little bit of digging and (at least looking at the case I tried) it does not seem to be a bug, but rather an artifact of the finite precision of the computations.

I modified your script a little bit to have deterministic execution and better show what happens:

import torch
import transformer_engine
from transformer_engine.pytorch import cpp_extensions as tex
from transformer_engine.pytorch.constants import TE_DType

torch.manual_seed(1234)
torch.set_printoptions(precision=10)

# Adjust for different input
d1 = 256
d2 = 512

swiglu_fwd = """
template <typename T> T swiglu_fwd(T x, T y) {
    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""
swiglu_bwd = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
    dy = float(x) * x_sigmoid * float(g);
}
"""
swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd)
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd, num_outputs=2)

class MySwiglu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputmat):
        x1, x2 = torch.chunk(inputmat, 2, dim=-1)
        ctx.save_for_backward(x1, x2)
        return swiglu_fwd(x1, x2)

    @staticmethod
    def backward(ctx, dout):
        x1, x2 = ctx.saved_tensors
        dx, dy = swiglu_bwd(x1, x2, dout)
        # forward took a single tensor, so return a single gradient (the two halves concatenated back)
        return torch.cat((dx, dy), dim=-1)

myswiglu = MySwiglu.apply
input_mat1 = torch.rand(d2, d1).bfloat16().to('cuda')
input_mat2 = input_mat1.clone().detach().bfloat16().to('cuda')

r1 = myswiglu(input_mat1)
r2 = tex.swiglu(input_mat2, None, tex.FP8FwdTensors.GEMM2_INPUT, otype=TE_DType[torch.bfloat16])
print(r1)
print(r2)
mismatches = (r1-r2).nonzero()
print(mismatches)
for m in mismatches:
    index = tuple(m)
    print(r1[index])
    print(r2[index])
    in1, in2 = torch.chunk(input_mat1, 2, dim=-1)
    x = in1.float()[index]
    y = in2.float()[index]
    out = x * y / (1 + torch.exp(-x))
    print(out)
    print(out.bfloat16())
    temp1 = 1 / (1 + torch.exp(-x))
    temp = x * temp1
    out2 = temp * y
    print(out2)
    print(out2.bfloat16())

I tried it on H100, CUDA 12.4 and TE 1.4 (we did not change the logic of swiglu in 1.4 compared with 1.3 so it should be completely equivalent).
As you noted, there are some mismatches:

tensor([[0.0128173828, 0.0776367188, 0.1162109375,  ..., 0.3769531250,
         0.0162353516, 0.0206298828],
        [0.4765625000, 0.2636718750, 0.0295410156,  ..., 0.1748046875,
         0.1196289062, 0.1044921875],
        [0.1738281250, 0.0634765625, 0.0056457520,  ..., 0.0844726562,
         0.0844726562, 0.0212402344],
        ...,
        [0.0524902344, 0.0249023438, 0.6328125000,  ..., 0.2734375000,
         0.0625000000, 0.1972656250],
        [0.1269531250, 0.0532226562, 0.0167236328,  ..., 0.0145874023,
         0.3593750000, 0.6250000000],
        [0.1166992188, 0.0213623047, 0.0751953125,  ..., 0.2236328125,
         0.1250000000, 0.3027343750]], device='cuda:0', dtype=torch.bfloat16)
tensor([[0.0128173828, 0.0776367188, 0.1162109375,  ..., 0.3769531250,
         0.0162353516, 0.0206298828],
        [0.4765625000, 0.2636718750, 0.0295410156,  ..., 0.1748046875,
         0.1196289062, 0.1044921875],
        [0.1738281250, 0.0634765625, 0.0056457520,  ..., 0.0844726562,
         0.0844726562, 0.0212402344],
        ...,
        [0.0524902344, 0.0249023438, 0.6328125000,  ..., 0.2734375000,
         0.0625000000, 0.1972656250],
        [0.1269531250, 0.0532226562, 0.0167236328,  ..., 0.0145874023,
         0.3593750000, 0.6250000000],
        [0.1166992188, 0.0213623047, 0.0751953125,  ..., 0.2236328125,
         0.1250000000, 0.3027343750]], device='cuda:0', dtype=torch.bfloat16)
tensor([[456, 127]], device='cuda:0')
tensor(0.4707031250, device='cuda:0', dtype=torch.bfloat16)
tensor(0.4687500000, device='cuda:0', dtype=torch.bfloat16)
tensor(0.4697265923, device='cuda:0')
tensor(0.4707031250, device='cuda:0', dtype=torch.bfloat16)
tensor(0.4697265625, device='cuda:0')
tensor(0.4687500000, device='cuda:0', dtype=torch.bfloat16)

There is a mismatch at position [456, 127], where your function produces 0.4707031250 whereas tex.swiglu produces 0.4687500000. While it looks like a large difference, it is actually expected: bfloat16 has a small number of mantissa bits, and it turns out that these two values are simply two consecutive representable values in the bfloat16 range.
The reason we see the difference is that the FP32 value, which is the "real" result of the swiglu computation, lies very close to the boundary between these two bfloat16 values, so small numerical differences can make the rounding go in either direction. This is shown here:

    in1, in2 = torch.chunk(input_mat1, 2, dim=-1)
    x = in1.float()[index]
    y = in2.float()[index]
    out = x * y / (1 + torch.exp(-x))
    print(out)
    print(out.bfloat16())
    temp1 = 1 / (1 + torch.exp(-x))
    temp = x * temp1
    out2 = temp * y
    print(out2)
    print(out2.bfloat16())

which produces

tensor(0.4697265923, device='cuda:0')
tensor(0.4707031250, device='cuda:0', dtype=torch.bfloat16)
tensor(0.4697265625, device='cuda:0')
tensor(0.4687500000, device='cuda:0', dtype=torch.bfloat16)

out is computed according to the formula you used in your swiglu_fwd function. As you can see, it is 0.4697265923, which is 0.0009765327 away from 0.4707031250 and 0.0009765923 away from 0.4687500000. Note that those two differences are very close to each other, but ultimately this value is closer to the bf16 value produced by your function.

out2 is computed with the mathematically equivalent formula, but with the operations performed in a slightly different order (which corresponds to what we do internally in the swiglu implementation). You can see that this FP32 result is very slightly different (0.4697265625), which in this case makes the rounding go in the other direction.
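To make the "two consecutive bfloat16 values" point concrete, here is a small standalone check (my own addition, just reusing the numbers printed above) showing that 0.4687500000 and 0.4707031250 are exactly one bfloat16 ulp apart, and that the two FP32 intermediates land just above and exactly on the rounding midpoint between them:

import torch

# In [0.25, 0.5) a bfloat16 ulp is 2**-9, so these two results are adjacent representable values.
lo = torch.tensor(0.46875, dtype=torch.bfloat16)       # value produced by tex.swiglu
hi = torch.tensor(0.470703125, dtype=torch.bfloat16)   # value produced by the jiterator version
print(hi.float() - lo.float())                         # 0.001953125 == 2**-9 (one ulp)

midpoint = (lo.float() + hi.float()) / 2               # 0.4697265625, the rounding boundary
print(midpoint)

# The two FP32 intermediates from the printout above:
print(torch.tensor(0.4697265923).bfloat16())  # just above the midpoint -> rounds up to 0.470703125
print(torch.tensor(0.4697265625).bfloat16())  # exactly on the midpoint -> ties to the even mantissa, 0.46875

The tie resolving down to 0.46875 matches round-to-nearest-even: of the two candidates, 0.46875 is the one with the even stored mantissa.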

Thanks @ptrendx for your dedicated investigation and reply! That makes sense to me!

Just to add a little which could be interesting at the compiler level: I made a small change in my fork of TE to the swish implementation inside math.h, so that swish no longer calls the sigmoid template and instead looks like this:

template <typename OType, typename IType>
__device__ inline OType swish(const IType val, const Empty& e) {
    const float cval = val;
    return cval * 1.f / (1.f + exp(-cval));  // the sigmoid is expanded manually here
}
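For comparison, the version this replaces goes through the sigmoid template, roughly like the following (a paraphrase from memory, not the exact TE math.h source; the only point is that the sigmoid is a separate template call rather than being inlined into one expression):

// Paraphrased pre-change version (NOT the exact TE source): the compiler sees
// cval * sigmoid(cval) through a template call instead of the single inlined
// expression cval * 1.f / (1.f + exp(-cval)), which can lead to a slightly
// different ordering of the FP32 operations.
template <typename OType, typename IType>
__device__ inline OType sigmoid(const IType val, const Empty& e) {
    const float cval = val;
    return 1.f / (1.f + exp(-cval));
}

template <typename OType, typename IType>
__device__ inline OType swish(const IType val, const Empty& e) {
    const float cval = val;
    return cval * sigmoid<float, float>(cval, e);
}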

With this change, the forward diff shrinks to an almost zero mismatch probability when iterating thousands of times with d1 = 1024 and d2 = 2048.

I guess it is a C++-level template type-casting difference that is causing the diff, but I haven't looked at the PTX level yet. However, since you've explained it quite thoroughly, I think this is not going to be an issue for me. Thanks a lot!