BF16 matmul slower than F32 matmul on T4 GPU
Opened this issue · 3 comments
T4 GPUs don't support BF16 matmul. Because of this, XLA switches BF16 matmuls to F32 matmuls on T4 (IIUC). That fallback is expected to be slower than native BF16 would be, but it turns out to be slower than a plain F32 matmul as well (i.e. BF16 appears to run at less than 50% of the speed of F32). So there must be something else going on here. If I understand correctly, the "BF16" matmul should perform the same as F32.
I also filed an issue against JAX, since that's where I discovered this issue. jax-ml/jax#21212
As I mentioned in the other issue, you can repro on a T4 Colab with the following code:
```python
import jax
import jax.numpy as jnp
import timeit


def flops_calc(exponent=16, iters=10, dtype=jnp.float16):
    key = jax.random.PRNGKey(0)
    x_i = 2**exponent
    x_j = 4096
    y_j = 4096
    flop_count = x_i * x_j * y_j * 2
    x = jax.random.uniform(key, (x_i, x_j), dtype=dtype)
    y = jax.random.uniform(key, (x_j, y_j), dtype=dtype)
    matmul = jax.jit(lambda a, b: a @ b)
    matmul(x, y).block_until_ready()  # warm up / compile
    seconds_per_iter = timeit.timeit(
        lambda: matmul(x, y).block_until_ready(), number=iters) / iters
    flops = flop_count / seconds_per_iter
    return flop_count, flops


def flops_to_tflops(flops):
    return flops / 1e12


for dtype in [jnp.bfloat16, jnp.float16, jnp.float32]:
    print(dtype)
    for i in range(16):
        op_count, flops = flops_calc(exponent=i, dtype=dtype)
        print(f'Total TFLOP Count: {op_count / 1e12:.5f} | TFLOPS: {flops_to_tflops(flops):.2f}')
    print()
```
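To confirm what XLA actually receives, you can inspect the lowered computation. This is a sketch: `lower(...).as_text()` shows the StableHLO that JAX emits, while the backend-specific rewrite (BF16 → F32 on T4) would only show up later in the compiled HLO, and the exact text varies by backend and version.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((8, 8), dtype=jnp.bfloat16)
y = jnp.ones((8, 8), dtype=jnp.bfloat16)

# jit(...).lower(...) shows the StableHLO handed to XLA; on a T4 the
# compiled HLO (lowered.compile().as_text()) would additionally show
# convert ops around the dot where the BF16 -> F32 upcast happens.
lowered = jax.jit(lambda a, b: a @ b).lower(x, y)
print(lowered.as_text())
```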
> T4 GPU doesn't support BF16 matmul

It actually does, but it wouldn't use Tensor Cores and is incredibly slow.
> XLA switches BF16 matmul to F32 matmul on T4

This is a fairly recent change I made; you could try a build from before that commit. Without the change, matmuls are >4x slower from what I recall (depending on shape).
> If I understand correctly, "BF16" matmul should be the same performance as F32.
Why would it? T4 has neither vector nor TensorCore support for BF16, so it has to emulate it, slowly.
Or do you mean on T4? On T4, you can look at the GPU profile.
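One way to collect that profile from JAX is to capture a trace around the matmul. A minimal sketch (`/tmp/jax-trace` is an arbitrary output path):

```python
import jax
import jax.numpy as jnp

x = jax.random.uniform(jax.random.PRNGKey(0), (2048, 2048), dtype=jnp.bfloat16)
matmul = jax.jit(lambda a, b: a @ b)
matmul(x, x).block_until_ready()  # compile first so the trace shows steady state

# Writes a trace viewable in TensorBoard (or Perfetto); on a T4 you'd look
# for a separate convert/cast kernel running next to the F32 matmul kernel.
with jax.profiler.trace("/tmp/jax-trace"):
    matmul(x, x).block_until_ready()
```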
Here the problem is we use Triton for fusions, which recently dropped support for pre-Ampere GPUs (or at least they aren't officially supported). Without fusions, we need to run an extra kernel to cast from BF16 to F32, which can be as expensive as the matmul itself.
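One way to sidestep that extra cast kernel from user code is to upcast to F32 once, ahead of time, instead of paying a BF16→F32 convert on every call. This is a workaround sketch, not the fused path XLA would ideally emit:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.uniform(key, (1024, 1024), dtype=jnp.bfloat16)
b = jax.random.uniform(key, (1024, 1024), dtype=jnp.bfloat16)

matmul = jax.jit(lambda x, y: x @ y)

# Per-call path: on T4, XLA must upcast both BF16 operands before the
# F32 matmul, which can cost as much as the matmul itself if not fused.
out_bf16 = matmul(a, b)

# Amortized path: cast once up front and keep working in F32, so repeated
# matmuls skip the convert kernels (at the cost of 2x the memory footprint).
a32, b32 = a.astype(jnp.float32), b.astype(jnp.float32)
out_f32 = matmul(a32, b32)
```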
> Why would it?

Sorry, misspoke a bit. I meant that I'd expect the emulation on T4 to be in the ballpark of (or at least not slower than) F32. But it sounds like it could be slower than F32 because of the extra cast?
Yes. Since we support cutlass fusions, I might look into supporting that fusion (cast into matmul) via cutlass.