Performance of MNIST example
mar-muel opened this issue · 1 comment
Hi everyone - thanks for your work on this, very exciting!
I've been playing around a bit with the Flax MNIST example (https://github.com/google/aqt/blob/main/aqt/jax/v2/examples/mnist.py). I've benchmarked the training (as well as eval) on TPU v4 and v5 and can't see a performance improvement compared to bfloat16/float32 training. Both training and eval are around 4% slower when using int8 quantized operations.
Am I doing something wrong or is this expected? I could imagine that the overhead of converting from float32 to int8 and back is non-negligible at this small scale.
This is expected!
AQT quantization adds overhead that is O(N^2): the quantize/dequantize steps are elementwise over the matmul inputs and outputs.
The matmul we are speeding up is O(N^3), so the relative cost of that overhead shrinks roughly as 1/N.
The matmuls in the MNIST example are far too small (N is too small) for the int8 speedup to outweigh the overhead.
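To make the scaling concrete, here is a rough, self-contained micro-benchmark. It does not use AQT at all; it hand-rolls a per-tensor int8 quantize → matmul → dequantize so the O(N^2) overhead and the O(N^3) matmul are explicit. Exact numbers depend on the chip (and on whether it has native int8 matmul units), but the int8 path should only pull ahead of bf16 once N is well beyond MNIST-sized layers:

```python
import time
import jax
import jax.numpy as jnp

def bench(f, *args, iters=50):
  f(*args).block_until_ready()                 # compile + warm up
  t0 = time.perf_counter()
  for _ in range(iters):
    out = f(*args)
  out.block_until_ready()
  return (time.perf_counter() - t0) / iters

@jax.jit
def matmul_bf16(x, y):
  return x @ y

@jax.jit
def matmul_int8(x, y):
  # O(N^2) overhead: per-tensor scales, quantize, dequantize.
  sx = jnp.max(jnp.abs(x)) / 127.0
  sy = jnp.max(jnp.abs(y)) / 127.0
  xq = jnp.round(x / sx).astype(jnp.int8)
  yq = jnp.round(y / sy).astype(jnp.int8)
  # O(N^3) work: the int8 matmul itself, accumulated in int32.
  acc = jax.lax.dot(xq, yq, preferred_element_type=jnp.int32)
  return acc.astype(jnp.bfloat16) * (sx * sy).astype(jnp.bfloat16)

key = jax.random.PRNGKey(0)
for n in (256, 1024, 4096, 8192):              # MNIST matmuls sit at the small end
  kx, ky = jax.random.split(key)
  x = jax.random.normal(kx, (n, n), dtype=jnp.bfloat16)
  y = jax.random.normal(ky, (n, n), dtype=jnp.bfloat16)
  print(f"N={n}: bf16 {bench(matmul_bf16, x, y):.2e}s  int8 {bench(matmul_int8, x, y):.2e}s")
```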
The MNIST example is there only to demonstrate the API, which boils down to injecting an AQT dot_general into standard Flax layers, as in the sketch below.
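A minimal sketch of that injection pattern, assuming the `AqtDotGeneral` / `fully_quantized` names from the AQT README (exact config helpers may differ between AQT versions, so treat this as an illustration rather than a verbatim excerpt from the example):

```python
from typing import Any

import flax.linen as nn
from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2.flax import aqt_flax

# int8 forward and backward quantization config (helper name assumed
# from the AQT README; it may differ between AQT versions).
aqt_cfg = aqt_config.fully_quantized(fwd_bits=8, bwd_bits=8)

class MlpBlock(nn.Module):
  aqt_cfg: Any = None  # AQT DotGeneral config, typed loosely here

  @nn.compact
  def __call__(self, x):
    # Replace the default dot_general of a standard Flax layer with AQT's
    # quantized dot_general; pass aqt_cfg=None to fall back to bf16/f32.
    dot_general = aqt_flax.AqtDotGeneral(self.aqt_cfg)
    x = nn.Dense(features=256, dot_general=dot_general)(x)
    return nn.relu(x)
```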
If you want something more realistic, I recommend the MaxText codebase: https://github.com/google/maxtext
AQT is integrated there, and that's how we obtained results in the blog post:
https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e/
Also note that you do indeed need TPU v5 to see any speedup.