How does haliax work with mixed precision?
The PyTorch docs have a list of fp16-safe and fp16-unsafe ops. I want to make sure my softmax operations run in fp32.
I read the jmp tutorial for Haliax, but I didn't see anything about promoting the softmax to fp32. Is this done automatically by JAX? Or does Haliax do it automatically somehow?
Neither. It's a fair point. In Levanter I just have flags for places where I want to upcast the op (e.g. https://github.com/stanford-crfm/levanter/blob/main/src/levanter/models/gpt2.py#L181), which I think is more or less how it's done in Flax?
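For anyone landing here later, a minimal sketch of what a per-op upcast flag can look like in plain JAX. The function name `softmax_upcast` and the `upcast` argument are made up for illustration; this is not the Haliax or Levanter API, just the general pattern of casting to fp32, running the op, and casting back.

```python
import jax
import jax.numpy as jnp


def softmax_upcast(scores, axis=-1, upcast=True):
    """Softmax that optionally computes in float32 and casts the result back.

    `upcast` is a hypothetical flag for illustration only.
    """
    orig_dtype = scores.dtype
    if upcast:
        # Do the exponentiation and normalization in float32.
        scores = scores.astype(jnp.float32)
    probs = jax.nn.softmax(scores, axis=axis)
    # Return in the caller's dtype so the rest of the model is unaffected.
    return probs.astype(orig_dtype)


scores = jnp.array([[1.0, 2.0, 3.0], [10.0, 0.0, -10.0]], dtype=jnp.bfloat16)
probs = softmax_upcast(scores)
```

The output keeps the input's low-precision dtype, so only the softmax itself pays the fp32 cost.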
When I first started designing Levanter, I thought about arrays/modules/ops having a "semantic dtype" component (output, compute, parameter) and threading jmp through, but decided against it.
If you want something transparent, Haiku has a mechanism that's worth checking out: it uses context mappings on ops to do it.
What are your thoughts?
I'm very new to JAX and have only used Equinox, without looking much at Flax or Haiku yet. I ended up simply casting everything to bfloat16, since my training runs were diverging with fp16 even when I manually upcast the softmax and layernorms.
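Roughly, "casting everything to bfloat16" can be sketched like this in plain JAX. The helper name `cast_floating` and the toy `params` dict are assumptions for illustration, not code from my actual run. It casts only the floating-point leaves of a pytree and leaves integer leaves (like step counters) alone.

```python
import jax
import jax.numpy as jnp


def cast_floating(tree, dtype):
    """Cast every floating-point array leaf of a pytree to `dtype`.

    Hypothetical helper for illustration; non-float leaves pass through unchanged.
    """

    def cast(x):
        if hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x

    return jax.tree_util.tree_map(cast, tree)


# Toy parameters standing in for a real model's pytree.
params = {
    "w": jnp.ones((2, 2), dtype=jnp.float32),
    "step": jnp.array(0, dtype=jnp.int32),
}
bf16_params = cast_floating(params, jnp.bfloat16)
```

bfloat16 keeps the same exponent range as fp32 (with fewer mantissa bits), which is the usual reason it overflows less readily than fp16.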
I think manually upcasting in model definitions is probably the best practice. I'm used to PyTorch, where I often don't write models from scratch anymore because paper authors provide fairly optimized implementations. But I guess it's fine to write models from scratch in JAX because XLA will optimize the CUDA ops and such.
Thanks for the discussion!