Determinism in Gemma model
I believe this is a side effect of how JAX (and XLA) compute matrix-matrix products vs matrix-vector products. In short, batched matrix multiplication can cause XLA to use a different algorithm to compute the products, which can slightly change results due to the non-associativity of floating point operations.
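Here's a minimal sketch of the effect (the shapes and values are arbitrary, and whether the two results actually differ, and by how much, depends on your backend and XLA version; on CPU they may well be bitwise identical):

```python
import jax
import jax.numpy as jnp

key_a, key_x = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (1024, 1024), dtype=jnp.bfloat16)
X = jax.random.normal(key_x, (8, 1024), dtype=jnp.bfloat16)

# Matrix-vector product for a single row of X.
single = jax.jit(lambda a, x: a @ x)(A, X[0])

# The same row computed as part of a batched matrix-matrix product.
batched = jax.jit(lambda a, xs: xs @ a.T)(A, X)[0]

# On GPU/TPU these can differ in the low bits, because XLA may pick a
# different (and sometimes lower-precision) matmul algorithm for the batched case.
print(jnp.max(jnp.abs(single.astype(jnp.float32) - batched.astype(jnp.float32))))
```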
Here's a related comment that seems to describe the same issue: jax-ml/jax#20047 (comment)
Relevant bit:
My gut feeling is that this is working as intended. Adding a batch dimension will allow and encourage XLA to change the order of operations, and floating point computations will not produce bitwise exact results if you change the order of operations. Indeed: that's sort of the point of vmap: we can and will compute things in a different and possibly more efficient order if there's a batch dimension.
The most common place this surprises people is that adding a batch dimension to a matrix-vector multiplication makes it a matrix-matrix multiplication, which triggers different and sometimes lower precision matmul algorithms especially on GPUs and TPUs.
The size of the difference here is probably because you are using bfloat16, which is a very low precision format and can lead to cascading differences in model outputs. You could try initializing the model with upcast_activations_to_float32=True, which will use a higher precision to store the model activations and should lead to smaller discrepancies (at the cost of somewhat slower speed and higher memory usage).
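For a rough sense of what upcasting buys you, here's a toy comparison. This is not the Gemma code path itself, just the same idea: keep the weights in bfloat16 but carry the activations in float32 through the matmul, and see how the matvec-vs-batched discrepancy changes.

```python
import jax
import jax.numpy as jnp

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (1024, 1024), dtype=jnp.bfloat16)  # stand-in "weights"
X = jax.random.normal(key_x, (8, 1024), dtype=jnp.bfloat16)     # stand-in "activations"

def forward(w, xs, activation_dtype):
    # Cast the activations (and the weights for this matmul) to the chosen dtype,
    # loosely mimicking what storing model activations in float32 does.
    return xs.astype(activation_dtype) @ w.T.astype(activation_dtype)

for dtype in (jnp.bfloat16, jnp.float32):
    fwd = jax.jit(forward, static_argnums=2)
    single = fwd(W, X[0], dtype)       # matrix-vector path
    batched = fwd(W, X, dtype)[0]      # same row via the batched path
    diff = jnp.max(jnp.abs(single.astype(jnp.float32) - batched.astype(jnp.float32)))
    print(dtype.__name__, diff)        # the float32 discrepancy should be much smaller (often zero)
```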