graphcore-research/jax-scalify

Generalize `Scalify` to support MX microscaling formats

Opened this issue · 5 comments

We would like to support the MX microscaling formats in the Scalify formalism: https://arxiv.org/abs/2310.10537
In principle, it should be as "simple" as allowing the scale array/tensor to be broadcastable against the data component.
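
As a rough illustration (hypothetical shapes, not the existing jax_scalify API): the represented value is always `data * scale`, whether the scale is a per-tensor scalar or a per-block array, as long as the shapes broadcast.

```python
import jax.numpy as jnp

data = jnp.ones((4, 32), dtype=jnp.bfloat16)
scalar_scale = jnp.float32(0.5)        # current Scalify case: one scale for the whole tensor
row_scale = jnp.full((4, 1), 0.5)      # MX-like case: one scale per 32-element row block

for scale in (scalar_scale, row_scale):
    # Both scale shapes broadcast against the data, so "unscaling" is the same expression.
    value = data.astype(jnp.float32) * scale
    assert value.shape == (4, 32)
```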

Should there be stronger restrictions, e.g. a fixed block size of 32, or would you rather have MX formats as just a special case of a broadcastable scale?

Some ops, e.g. reductions, would need different implementations depending on whether the scale is a scalar or a broadcastable vector. Should these be handled by separate ops, letting the interpreter infer which version to insert?
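
For instance, a sum reduction can factor out a scalar scale, but a per-row scale forces a de-scale (or a rescaling of partial sums) before reducing. A minimal sketch, with hypothetical function names rather than existing jax_scalify ops:

```python
import jax.numpy as jnp

def scaled_sum_scalar(data, scale):
    # A scalar scale factors out of the reduction: sum(data) * s == sum(data * s).
    return data.sum(), scale

def scaled_sum_blocked(data, scale):
    # A broadcastable (e.g. per-row) scale does not factor out: de-scale into a
    # wider dtype, reduce, then choose a fresh scalar scale for the output.
    full = data.astype(jnp.float32) * scale
    out_scale = jnp.max(scale)
    return (full.sum() / out_scale).astype(data.dtype), out_scale
```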

I think we should have ScaledArray support a broadcastable scale, so MX formats just become a simple special case.
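
Something along these lines, as a sketch only (the real ScaledArray implementation may differ, and bfloat16 stands in here for the narrow MX element types): the one new invariant is that the scale's shape broadcasts against the data, and an MX layout then reduces to a particular choice of scale shape.

```python
from dataclasses import dataclass

import jax.numpy as jnp
import numpy as np

@dataclass
class ScaledArraySketch:
    """Hypothetical ScaledArray with a broadcastable scale (sketch only)."""
    data: jnp.ndarray
    scale: jnp.ndarray  # scalar, or any shape broadcastable against `data`

    def __post_init__(self):
        # The only new invariant: the scale's shape must broadcast against the data's.
        np.broadcast_shapes(self.data.shape, self.scale.shape)

    def to_array(self):
        return self.data.astype(jnp.float32) * self.scale

# MX-like layout: narrow-precision data grouped in 32-element blocks along the last
# axis, one scale per block, kept broadcastable via a trailing singleton dimension.
data = jnp.zeros((128, 8, 32), dtype=jnp.bfloat16)   # (rows, n_blocks, block_size=32)
scale = jnp.ones((128, 8, 1), dtype=jnp.float32)
mx_like = ScaledArraySketch(data, scale)
```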

Then for the ops implementation: we don't need to cover everything. We can start by supporting only matmul and cast with a broadcastable scale, and nothing else. Since these ops would return a (scalar-scaled) FP16 or BF16 ScaledArray, that is fine.
This mirrors how FP8 op support is currently minimal in most frameworks. We can then expand in whatever directions we see fit.
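
A rough sketch of what that initial matmul could look like, assuming 2D operands whose per-block scales already broadcast against them (hypothetical helper, not the existing implementation): de-scale into BF16, run the dot product, and return an ordinary scalar-scaled result that the rest of the pipeline already understands.

```python
import jax.numpy as jnp

def scaled_matmul_sketch(a_data, a_scale, b_data, b_scale):
    """Hypothetical scaled matmul: blocked-scale inputs, scalar-scaled BF16 output."""
    # De-scale into BF16 first; the per-block scales broadcast against the data.
    a = (a_data.astype(jnp.float32) * a_scale).astype(jnp.bfloat16)
    b = (b_data.astype(jnp.float32) * b_scale).astype(jnp.bfloat16)
    out = jnp.matmul(a, b)
    # Hand back an ordinary scalar-scaled BF16 array, so no other op needs MX support.
    return out, jnp.float32(1.0)
```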

In my mind, the interpreter does the minimal job of tracing + transforming the graph. Scaled ops are responsible for checking that their inputs are supported (and raising an error if necessary).
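
For example (sketch only, hypothetical op name), an op that hasn't been extended yet simply rejects non-scalar scales itself, rather than the interpreter having to know which ops support what:

```python
import jax.numpy as jnp

def scaled_exp_sketch(data, scale):
    """Hypothetical scaled op that only supports a scalar scale for now."""
    if jnp.ndim(scale) != 0:
        raise NotImplementedError(
            f"scaled exp only supports a scalar scale, got scale shape {jnp.shape(scale)}"
        )
    # exp does not commute with scaling, so fold the scale into the data first.
    return jnp.exp(data.astype(jnp.float32) * scale), jnp.float32(1.0)
```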

One design issue: some information needs to be kept about how the scale broadcasts against the data, e.g. a 32x32 matrix could be scaled row-wise or column-wise by a 32-element vector (one scale per 32-element block). This could be tracked with singleton dimensions, e.g. scales of shape 32x1 and 1x32.
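
I.e. something like this, where only the shape of the scale distinguishes row-wise from column-wise scaling (illustrative shapes only):

```python
import jax.numpy as jnp

data = jnp.ones((32, 32), dtype=jnp.bfloat16)
row_scale = jnp.ones((32, 1))   # one scale per 32-element row block
col_scale = jnp.ones((1, 32))   # one scale per 32-element column block

# Only the scale's shape distinguishes the two layouts; both broadcast as-is.
assert (data.astype(jnp.float32) * row_scale).shape == (32, 32)
assert (data.astype(jnp.float32) * col_scale).shape == (32, 32)
```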

This singleton-dimension approach wouldn't generalise if blocks need to cross rows or columns: e.g. a 10x15 matrix scaled by a 6-element vector (blocks of 25 elements) would require some reshaping of the data matrix to make it broadcastable. It's possible, but you'd need to carry some extra information.
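
Roughly (illustrative only): the scale broadcasts against a blocked view of the data rather than its natural shape, and that block size is exactly the extra piece of information that needs tracking somewhere.

```python
import jax.numpy as jnp

data = jnp.ones((10, 15), dtype=jnp.bfloat16)   # 150 elements
scale = jnp.ones((6, 1))                        # 6 blocks of 25 elements each

# The scale only broadcasts against a blocked *view* of the data, not its natural
# shape, so the block size (25 here) has to be recorded somewhere.
value = (data.reshape(6, 25).astype(jnp.float32) * scale).reshape(10, 15)
```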

Should there be some attribute attached to the ScaledArray to keep track of this? Or should it stay implicit in the shape of the scale?