graphcore-research/jax-scalify

Support mixed normal/scaled graph in `AutoScale`


The `AutoScale` interpreter needs to be generalized to support mixed graphs, in which some tensors are still plain JAX arrays rather than `ScaledArray`s.

This means we need dispatch and promotion rules covering:

  • When to use scaled primitives;
  • When to automatically promote plain JAX arrays to `ScaledArray`.
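One possible shape for these rules, as a minimal sketch under stated assumptions: the normal primitive runs if no input is scaled; otherwise every plain input is promoted to a scaled array with scale `1` (a lossless promotion) and the scaled primitive runs. `ScaledArray`, `as_scaled`, `PRIMITIVE_RULES` and `autoscale_apply` below are hypothetical stand-ins, not the jax-scalify API, and NumPy replaces JAX only to keep the example self-contained.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ScaledArray:
    """Hypothetical stand-in: represents the tensor data * scale."""
    data: np.ndarray
    scale: np.ndarray  # scalar scale factor

    def to_array(self) -> np.ndarray:
        return self.data * self.scale


def as_scaled(x) -> ScaledArray:
    """Assumed promotion rule: a plain array becomes ScaledArray(x, 1)."""
    if isinstance(x, ScaledArray):
        return x
    x = np.asarray(x)
    return ScaledArray(x, np.asarray(1, dtype=x.dtype))


def scaled_mul(a: ScaledArray, b: ScaledArray) -> ScaledArray:
    # (a.data * a.scale) * (b.data * b.scale): data and scales multiply separately.
    return ScaledArray(a.data * b.data, a.scale * b.scale)


def scaled_add(a: ScaledArray, b: ScaledArray) -> ScaledArray:
    # Rescale both operands to a common output scale before adding.
    out_scale = np.maximum(a.scale, b.scale)
    data = a.data * (a.scale / out_scale) + b.data * (b.scale / out_scale)
    return ScaledArray(data, out_scale)


# Hypothetical registry: primitive name -> (normal impl, scaled impl).
PRIMITIVE_RULES = {
    "mul": (np.multiply, scaled_mul),
    "add": (np.add, scaled_add),
}


def autoscale_apply(prim: str, *args):
    normal_fn, scaled_fn = PRIMITIVE_RULES[prim]
    if not any(isinstance(a, ScaledArray) for a in args):
        # Rule 1: all inputs are plain arrays, so keep the normal primitive.
        return normal_fn(*args)
    # Rule 2: at least one scaled input, so promote the rest and use the scaled primitive.
    return scaled_fn(*(as_scaled(a) for a in args))
```

For example, `autoscale_apply("mul", x, x)` on a plain array returns a plain array, while `autoscale_apply("mul", s, x)` with `s` scaled returns a `ScaledArray`. The design choice here is "promote on contact": scaling spreads only where a scaled tensor already flows, leaving purely normal subgraphs untouched.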