openxla/xla

BF16 matmul slower than F32 matmul on T4 GPU


T4 GPU doesn't support BF16 matmul. Because of this, XLA switches BF16 matmul to F32 matmul on T4 (IIUC). This is obviously much slower than native BF16 would be, but it turns out it's actually slower than a plain F32 matmul as well (i.e. BF16 appears to be less than 50% of the speed of F32). So there must be something else going on here. If I understand correctly, "BF16" matmul should be the same performance as F32.

I also filed an issue against JAX, since that's where I discovered this issue. jax-ml/jax#21212

As I mentioned in the other issue, you can repro on a T4 Colab with the following code:

import jax
import jax.numpy as jnp
import timeit

def flops_calc(exponent=16, iters=10, dtype=jnp.float16):
  key = jax.random.PRNGKey(0)
  x_i = 2**exponent
  x_j = 4096
  y_j = 4096
  flop_count = x_i * x_j * y_j * 2  # 2*M*K*N FLOPs for an (M, K) @ (K, N) matmul
  x = jax.random.uniform(key, (x_i, x_j), dtype=dtype)
  y = jax.random.uniform(key, (x_j, y_j), dtype=dtype)
  matmul = jax.jit(lambda a, b: a @ b)
  matmul(x, y).block_until_ready()  # warm-up call to trigger compilation
  seconds_per_iter = timeit.timeit(lambda: matmul(x, y).block_until_ready(), number=iters) / iters
  flops = flop_count / seconds_per_iter
  return flop_count, flops

def flops_to_tflops(flops):
  return flops / 1e12

for dtype in [jnp.bfloat16, jnp.float16, jnp.float32]:
  print(dtype)
  for i in range(16):
    op_count, flops = flops_calc(exponent=i, dtype=dtype)
    print(f'Total TFLOP Count: {op_count / 1e12:.5f} | TFLOPS: {flops_to_tflops(flops):.2f}')
  print()

T4 GPU doesn't support BF16 matmul

It actually does, but it doesn't use Tensor Cores and is incredibly slow.

XLA switches BF16 matmul to F32 matmul on T4

This is a fairly recent change of mine; you could try finding the commit that introduced it. Without that change, BF16 matmuls are >4x slower from what I recall (depending on shape).
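
For reference, here is a minimal sketch (assuming a recent JAX that exposes the lower()/compile() AOT API) of how to inspect what XLA actually emits for a BF16 matmul on a given GPU; with the upcast in place, the optimized HLO should show the dot running in F32:

import jax
import jax.numpy as jnp

# Small BF16 matmul; on a GPU without native BF16 matmul support the
# optimized HLO should show the BF16 inputs being converted and the
# gemm/dot itself running in F32.
x = jnp.ones((1024, 1024), dtype=jnp.bfloat16)
y = jnp.ones((1024, 1024), dtype=jnp.bfloat16)

matmul = jax.jit(lambda a, b: a @ b)
print(matmul.lower(x, y).compile().as_text())  # optimized HLO after XLA's GPU passes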

If I understand correctly, "BF16" matmul should be the same performance as F32.

Why would it? T4 has neither vector nor Tensor Core support for BF16, so it has to emulate it, slowly.

Or do you mean why it ends up slower than F32 on T4? There, you can look at the GPU profile.
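
For example, one way to capture such a profile from JAX is with jax.profiler.trace (a sketch; open the resulting trace in TensorBoard's profiler plugin or Perfetto to see whether a separate cast kernel runs before the gemm):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (8192, 4096), dtype=jnp.bfloat16)
y = jax.random.uniform(key, (4096, 4096), dtype=jnp.bfloat16)

matmul = jax.jit(lambda a, b: a @ b)
matmul(x, y).block_until_ready()  # compile outside the trace

# Capture a device trace of just the steady-state matmul call.
with jax.profiler.trace("/tmp/jax-trace"):
  matmul(x, y).block_until_ready()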

Here the problem is that we use Triton for fusions, and Triton recently dropped support for pre-Ampere GPUs (or at least they aren't officially supported anymore). Without fusions, we need to run an extra kernel to cast from BF16 to F32, which can be as expensive as the matmul itself.
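
As a rough back-of-envelope sketch (the ~300 GB/s T4 memory bandwidth figure is approximate, and this ignores the small activation tensors and caching): in the small-batch, memory-bound regime the extra cast roughly doubles the traffic over the large weight operand, which is roughly consistent with BF16 coming in below half of F32's throughput.

# Back-of-envelope for the memory-bound (small-M) regime, assuming a T4
# memory bandwidth of roughly 300 GB/s. All figures are approximate.
MEM_BW = 300e9               # bytes/s (assumed)
K, N = 4096, 4096            # large weight operand, as in the repro

f32_path_bytes = K * N * 4                # gemm streams the F32 weight once
cast_bytes = K * N * 2 + K * N * 4        # cast kernel: read BF16, write F32
bf16_path_bytes = cast_bytes + K * N * 4  # then the gemm rereads the F32 copy

print(f"F32 path : ~{f32_path_bytes / MEM_BW * 1e6:.0f} us of weight traffic")
print(f"BF16 path: ~{bf16_path_bytes / MEM_BW * 1e6:.0f} us (cast + gemm)")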

Why would it?

Sorry, I misspoke a bit. I meant that I'd expect the emulation on T4 to be in the ballpark of (or at least not slower than) F32. But it sounds like it could be slower than F32 because of the extra cast?

Yes. Since we support CUTLASS fusions, I might look into supporting that fusion (cast into matmul) via CUTLASS.