keras-team/keras-nlp

Gemma discrepancies

Closed · 1 comment

Describe the bug
unsloth.ai recently pointed out several discrepancies in the Gemma model implementations. It would be good to have them verified and addressed.

Specifically, they reported that RoPE under Keras mixed_bfloat16 is wrong. According to their findings, TPUs on Colab keep RoPE positions in int32, whereas Keras casts them to the compute_dtype (bfloat16). Because bfloat16 has only 8 significant bits, positions [8190, 8191] both become [8192, 8192].
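To see why [8190, 8191] collapse to [8192, 8192], here is a minimal, stdlib-only sketch of float32 to bfloat16 rounding. It assumes round-to-nearest-even, the usual conversion mode; `to_bfloat16` is a hypothetical helper, not a Keras API. bfloat16 keeps the top 16 bits of a float32 (1 sign, 8 exponent, 7 explicit mantissa bits), so between 4096 and 8192 only every 32nd integer is representable.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round half to even).

    Hypothetical helper for illustration. NaN/inf are not handled specially.
    """
    # Reinterpret the float32 encoding of x as an unsigned 32-bit integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest even on the low 16 bits that bfloat16 discards.
    lsb = (bits >> 16) & 1
    bits += 0x7FFF + lsb
    # Keep only the upper 16 bits (sign, exponent, 7 mantissa bits).
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


positions = [8190.0, 8191.0]
print([to_bfloat16(p) for p in positions])  # [8192.0, 8192.0]
```

Values that are multiples of 32 in this range, such as 8160, survive unchanged, which is why the error only shows up for some positions.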

The line in Keras that incorrectly casts positions to bfloat16
DeepMind uses int32 for positions, and I also verified on TPUs that it is int32.
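The practical consequence is that neighbouring tokens receive identical rotary angles. The sketch below compares angles computed from exact integer positions with angles computed after rounding the position to bfloat16. It is an illustration under assumptions, not the Keras or DeepMind code: `head_dim=256` and `base=10000.0` are assumed Gemma-like values, `bf16_round` is a compact approximation valid for normal floats (keep 8 significant bits, round half to even), and the angles themselves are computed in Python floats so that only the position rounding differs.

```python
import math


def bf16_round(x: float) -> float:
    """Approximate bfloat16 rounding: keep 8 significant bits, half to even.

    Hypothetical helper; ignores subnormals, overflow, NaN and inf.
    """
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)  # x == m * 2**e with 0.5 <= |m| < 1
    # round() in Python uses round-half-to-even.
    return math.ldexp(round(m * 256) / 256, e)


def rope_angles(position, head_dim=256, base=10000.0, round_position=False):
    """Rotary angles position * inv_freq[i], one per frequency pair.

    head_dim and base are assumed values. round_position=True mimics
    casting the position to bfloat16 before computing the angles.
    """
    pos = bf16_round(float(position)) if round_position else float(position)
    return [pos / base ** (2 * i / head_dim) for i in range(head_dim // 2)]


exact = [rope_angles(p) for p in (8190, 8191)]
rounded = [rope_angles(p, round_position=True) for p in (8190, 8191)]
print(exact[0] == exact[1])  # False: int positions stay distinct
print(rounded[0] == rounded[1])  # True: both positions became 8192
```

Keeping positions as integers and casting only the final angles (or computing them in float32) preserves the distinction between adjacent positions.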

[Image: RoPE precision]

Hi @awsaf49, thank you for bringing this up! We've actually fixed this in #1497 and #1508. Thank you!