google/aqt

Performance of MNIST example

mar-muel opened this issue · 1 comment

Hi everyone - thanks for your work on this, very exciting!

I've been playing around a bit with the Flax MNIST example (https://github.com/google/aqt/blob/main/aqt/jax/v2/examples/mnist.py). I've benchmarked the training (as well as eval) on TPU v4 and v5 and can't see a performance improvement compared to bfloat16/float32 training. Both training and eval are around 4% slower when using int8 quantized operations.

Am I doing something wrong or is this expected? I could imagine that the overhead of converting from float32 to int8 and back is non-negligible at this small scale.
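For context, the comparison I have in mind boils down to something like the sketch below (simplified to a single layer-sized matmul rather than the full MNIST training loop; the shapes and the per-tensor int8 scheme are just illustrative, not the exact AQT configuration):

```python
import time
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
kx, kw = jax.random.split(key)
x = jax.random.normal(kx, (128, 784), dtype=jnp.bfloat16)   # MNIST-ish batch of flattened images
w = jax.random.normal(kw, (784, 256), dtype=jnp.bfloat16)   # first dense layer weights

@jax.jit
def matmul_bf16(x, w):
    return x @ w

@jax.jit
def matmul_int8(x, w):
    # Per-tensor symmetric quantization: elementwise (O(N^2)) overhead.
    sx = jnp.max(jnp.abs(x)) / 127.0
    sw = jnp.max(jnp.abs(w)) / 127.0
    xq = jnp.round(x / sx).astype(jnp.int8)
    wq = jnp.round(w / sw).astype(jnp.int8)
    # int8 matmul with int32 accumulation: this is the O(N^3) part.
    acc = jax.lax.dot(xq, wq, preferred_element_type=jnp.int32)
    return acc.astype(jnp.bfloat16) * (sx * sw)

def bench(f, *args, iters=1000):
    f(*args).block_until_ready()                  # compile + warm up
    t0 = time.perf_counter()
    for _ in range(iters):
        out = f(*args)
    out.block_until_ready()
    return (time.perf_counter() - t0) / iters

print("bf16 matmul:", bench(matmul_bf16, x, w))
print("int8 matmul:", bench(matmul_int8, x, w))
```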

This is expected!
AQT quantization has overheads that are O(N^2) (elementwise work on the operands), while the matmul it speeds up is O(N^3). The matmuls in MNIST are far too small (N is too small) for the savings to show up.
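As a back-of-the-envelope sketch of that scaling argument (the op counts are idealized; in practice the elementwise quantization work also runs at much lower throughput than the MXU matmul, so its wall-clock share is larger than the raw ratio suggests):

```python
# Quantizing the two NxN operands touches O(N^2) elements, while the
# NxN @ NxN matmul does O(N^3) multiply-adds, so the relative overhead
# shrinks roughly like 1/N as the matmul grows.
for n in (256, 1024, 4096, 16384):
    quant_elems = 2 * n * n      # elementwise scale/round/cast of both operands
    matmul_macs = n ** 3         # multiply-accumulates in the matmul
    print(f"N={n:6d}  quantization work / matmul work ~ {quant_elems / matmul_macs:.3%}")
```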

The MNIST example is there only to demonstrate the APIs.

If you want something more realistic, I recommend the MaxText codebase: https://github.com/google/maxtext
AQT is integrated there, and that's how we obtained the results in the blog post:
https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e/

Also, note that you do need TPU v5 for any speedup.