BF16 matmul slower than F32 matmul on T4 GPU
Closed · 4 comments
Description
BF16 matmul appears to be slower than F32 matmul on a T4. In my test, BF16 runs at roughly half the speed. I believe this is a bug: bf16 should be at least as fast as f32, and possibly faster.
You can reproduce this in a T4 Colab with the following:
import jax
import jax.numpy as jnp
import timeit
def flops_calc(exponent=16, iters=10, dtype=jnp.float16):
    key = jax.random.PRNGKey(0)
    x_i = 2**exponent
    x_j = 4096
    y_j = 4096
    flop_count = x_i * x_j * y_j * 2  # 2 * M * N * K flops for an (M, K) @ (K, N) matmul
    x = jax.random.uniform(key, (x_i, x_j), dtype=dtype)
    y = jax.random.uniform(key, (x_j, y_j), dtype=dtype)
    matmul = jax.jit(lambda a, b: a @ b)
    matmul(x, y).block_until_ready()  # warm-up call so compilation is excluded from timing
    seconds_per_iter = timeit.timeit(lambda: matmul(x, y).block_until_ready(), number=iters) / iters
    flops = flop_count / seconds_per_iter
    return flop_count, flops

def flops_to_tflops(flops):
    return flops / 1e12

for dtype in [jnp.bfloat16, jnp.float16, jnp.float32]:
    print(dtype)
    for i in range(16):
        op_count, flops = flops_calc(exponent=i, dtype=dtype)
        print(f'Total TFLOP Count: {op_count / 1e12:.5f} | TFLOPS: {flops_to_tflops(flops):.2f}')
    print()
This results in the following output:
<class 'jax.numpy.bfloat16'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.10
Total TFLOP Count: 0.00007 | TFLOPS: 0.04
Total TFLOP Count: 0.00013 | TFLOPS: 0.09
Total TFLOP Count: 0.00027 | TFLOPS: 0.16
Total TFLOP Count: 0.00054 | TFLOPS: 0.35
Total TFLOP Count: 0.00107 | TFLOPS: 0.61
Total TFLOP Count: 0.00215 | TFLOPS: 1.09
Total TFLOP Count: 0.00429 | TFLOPS: 1.22
Total TFLOP Count: 0.00859 | TFLOPS: 1.74
Total TFLOP Count: 0.01718 | TFLOPS: 2.27
Total TFLOP Count: 0.03436 | TFLOPS: 2.36
Total TFLOP Count: 0.06872 | TFLOPS: 2.36
Total TFLOP Count: 0.13744 | TFLOPS: 2.16
Total TFLOP Count: 0.27488 | TFLOPS: 2.19
Total TFLOP Count: 0.54976 | TFLOPS: 2.14
Total TFLOP Count: 1.09951 | TFLOPS: 2.09
<class 'jax.numpy.float16'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.11
Total TFLOP Count: 0.00007 | TFLOPS: 0.22
Total TFLOP Count: 0.00013 | TFLOPS: 0.44
Total TFLOP Count: 0.00027 | TFLOPS: 0.92
Total TFLOP Count: 0.00054 | TFLOPS: 1.76
Total TFLOP Count: 0.00107 | TFLOPS: 3.53
Total TFLOP Count: 0.00215 | TFLOPS: 6.99
Total TFLOP Count: 0.00429 | TFLOPS: 14.04
Total TFLOP Count: 0.00859 | TFLOPS: 23.47
Total TFLOP Count: 0.01718 | TFLOPS: 25.02
Total TFLOP Count: 0.03436 | TFLOPS: 35.24
Total TFLOP Count: 0.06872 | TFLOPS: 37.16
Total TFLOP Count: 0.13744 | TFLOPS: 31.20
Total TFLOP Count: 0.27488 | TFLOPS: 24.41
Total TFLOP Count: 0.54976 | TFLOPS: 23.02
Total TFLOP Count: 1.09951 | TFLOPS: 22.13
<class 'jax.numpy.float32'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.08
Total TFLOP Count: 0.00007 | TFLOPS: 0.16
Total TFLOP Count: 0.00013 | TFLOPS: 0.31
Total TFLOP Count: 0.00027 | TFLOPS: 0.66
Total TFLOP Count: 0.00054 | TFLOPS: 1.34
Total TFLOP Count: 0.00107 | TFLOPS: 2.61
Total TFLOP Count: 0.00215 | TFLOPS: 4.18
Total TFLOP Count: 0.00429 | TFLOPS: 4.92
Total TFLOP Count: 0.00859 | TFLOPS: 5.32
Total TFLOP Count: 0.01718 | TFLOPS: 4.59
Total TFLOP Count: 0.03436 | TFLOPS: 4.31
Total TFLOP Count: 0.06872 | TFLOPS: 4.19
Total TFLOP Count: 0.13744 | TFLOPS: 4.04
Total TFLOP Count: 0.27488 | TFLOPS: 4.30
Total TFLOP Count: 0.54976 | TFLOPS: 4.31
Total TFLOP Count: 1.09951 | TFLOPS: 4.37
Note how bf16 is much slower than f32. (Side note: I also see that bf16 is much slower than f16, but my understanding is that this is because the T4 doesn't natively support bf16, so JAX falls back to f32 for the computation.)
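One way to check what actually gets run for the bf16 matmul (a minimal sketch, not part of the original report) is to dump the post-optimization HLO for the jitted function via JAX's ahead-of-time lowering API and look for f32 converts/dots:

import jax
import jax.numpy as jnp

x = jnp.ones((1024, 4096), dtype=jnp.bfloat16)
y = jnp.ones((4096, 4096), dtype=jnp.bfloat16)

matmul = jax.jit(lambda a, b: a @ b)

# Post-optimization HLO as compiled for the attached GPU; on a T4 the bf16
# operands are expected to be upcast (convert) to f32 before the dot.
print(matmul.lower(x, y).compile().as_text())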
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='063d876e5268', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Sun Apr 28 14:29:16 UTC 2024', machine='x86_64')
$ nvidia-smi
Mon May 13 18:04:34 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 75C P0 30W / 70W | 11457MiB / 15360MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
Thanks for the report!
Here's a ref to the T4 architecture spec: https://images.nvidia.com/aem-dam/Solutions/design-visualization/technologies/turing-architecture/NVIDIA-Turing-Architecture-Whitepaper.pdf
T4 doesn't support bfloat16, but JAX (via the XLA GPU compiler) should be falling back to float32. The fact that the result is appreciably slower than native float32 may indicate a bug in the XLA GPU compiler.
I'd suggest reporting at https://github.com/openxla/xla
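In the meantime, one possible workaround is to run the matmul in float16, which the T4's tensor cores support natively. This is only a hedged sketch (the matmul_via_f16 helper is hypothetical, not from this thread), and float16's narrower exponent range means the numerics differ from bfloat16:

import jax
import jax.numpy as jnp

# Hypothetical workaround sketch: cast bf16 operands to float16 for the matmul
# and cast the result back to the input dtype. float16 has a narrower exponent
# range than bfloat16, so overflow is possible for large values.
@jax.jit
def matmul_via_f16(a, b):
    out = a.astype(jnp.float16) @ b.astype(jnp.float16)
    return out.astype(a.dtype)

x = jnp.ones((1024, 4096), dtype=jnp.bfloat16)
y = jnp.ones((4096, 4096), dtype=jnp.bfloat16)
print(matmul_via_f16(x, y).dtype)  # bfloat16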
Replied on the OpenXLA bug.
I tested the mentioned code on a Google Colab T4 GPU with JAX versions 0.4.26, 0.4.27, and later. From JAX 0.4.27 onward, BF16 matmul speed is roughly on par with F32 matmul on the T4 GPU. I get the following output when testing with JAX 0.4.27:
import jax
import jax.numpy as jnp
import timeit
print(jax.__version__)
def flops_calc(exponent=16, iters=10, dtype=jnp.float16):
    key = jax.random.PRNGKey(0)
    x_i = 2**exponent
    x_j = 4096
    y_j = 4096
    flop_count = x_i * x_j * y_j * 2
    x = jax.random.uniform(key, (x_i, x_j), dtype=dtype)
    y = jax.random.uniform(key, (x_j, y_j), dtype=dtype)
    matmul = jax.jit(lambda a, b: a @ b)
    matmul(x, y).block_until_ready()
    seconds_per_iter = timeit.timeit(lambda: matmul(x, y).block_until_ready(), number=iters) / iters
    flops = flop_count / seconds_per_iter
    return flop_count, flops

def flops_to_tflops(flops):
    return flops / 1e12

for dtype in [jnp.bfloat16, jnp.float16, jnp.float32]:
    print(dtype)
    for i in range(16):
        op_count, flops = flops_calc(exponent=i, dtype=dtype)
        print(f'Total TFLOP Count: {op_count / 1e12:.5f} | TFLOPS: {flops_to_tflops(flops):.2f}')
    print()
Output:
0.4.27
<class 'jax.numpy.bfloat16'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.10
Total TFLOP Count: 0.00007 | TFLOPS: 0.06
Total TFLOP Count: 0.00013 | TFLOPS: 0.13
Total TFLOP Count: 0.00027 | TFLOPS: 0.23
Total TFLOP Count: 0.00054 | TFLOPS: 0.46
Total TFLOP Count: 0.00107 | TFLOPS: 0.88
Total TFLOP Count: 0.00215 | TFLOPS: 1.47
Total TFLOP Count: 0.00429 | TFLOPS: 1.95
Total TFLOP Count: 0.00859 | TFLOPS: 4.33
Total TFLOP Count: 0.01718 | TFLOPS: 5.20
Total TFLOP Count: 0.03436 | TFLOPS: 5.17
Total TFLOP Count: 0.06872 | TFLOPS: 4.74
Total TFLOP Count: 0.13744 | TFLOPS: 4.69
Total TFLOP Count: 0.27488 | TFLOPS: 4.98
Total TFLOP Count: 0.54976 | TFLOPS: 4.96
Total TFLOP Count: 1.09951 | TFLOPS: 4.89
<class 'jax.numpy.float16'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.11
Total TFLOP Count: 0.00007 | TFLOPS: 0.22
Total TFLOP Count: 0.00013 | TFLOPS: 0.44
Total TFLOP Count: 0.00027 | TFLOPS: 0.77
Total TFLOP Count: 0.00054 | TFLOPS: 1.66
Total TFLOP Count: 0.00107 | TFLOPS: 3.32
Total TFLOP Count: 0.00215 | TFLOPS: 6.68
Total TFLOP Count: 0.00429 | TFLOPS: 10.77
Total TFLOP Count: 0.00859 | TFLOPS: 14.33
Total TFLOP Count: 0.01718 | TFLOPS: 22.38
Total TFLOP Count: 0.03436 | TFLOPS: 29.81
Total TFLOP Count: 0.06872 | TFLOPS: 31.90
Total TFLOP Count: 0.13744 | TFLOPS: 27.16
Total TFLOP Count: 0.27488 | TFLOPS: 24.00
Total TFLOP Count: 0.54976 | TFLOPS: 23.06
Total TFLOP Count: 1.09951 | TFLOPS: 24.39
<class 'jax.numpy.float32'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.08
Total TFLOP Count: 0.00007 | TFLOPS: 0.15
Total TFLOP Count: 0.00013 | TFLOPS: 0.32
Total TFLOP Count: 0.00027 | TFLOPS: 0.59
Total TFLOP Count: 0.00054 | TFLOPS: 1.14
Total TFLOP Count: 0.00107 | TFLOPS: 1.53
Total TFLOP Count: 0.00215 | TFLOPS: 2.22
Total TFLOP Count: 0.00429 | TFLOPS: 2.44
Total TFLOP Count: 0.00859 | TFLOPS: 2.64
Total TFLOP Count: 0.01718 | TFLOPS: 2.74
Total TFLOP Count: 0.03436 | TFLOPS: 2.94
Total TFLOP Count: 0.06872 | TFLOPS: 3.88
Total TFLOP Count: 0.13744 | TFLOPS: 4.17
Total TFLOP Count: 0.27488 | TFLOPS: 4.54
Total TFLOP Count: 0.54976 | TFLOPS: 4.46
Total TFLOP Count: 1.09951 | TFLOPS: 4.51
Please find the gist for reference.
Thank you.
Perfect, thanks! I was able to reproduce the fix post-0.4.27. I think we can close this out.