keras-team/tf-keras

Support for StableHLO generation with JAX backend


System information.

TensorFlow version (you are using): N/A
Are you willing to contribute it (Yes/No): Not immediately.

Describe the feature and the current behavior/state.

Some of the Keras Core examples use the JAX backend. IIUC, this flow JIT-compiles the model via XLA. I assume the lowering generates StableHLO as an intermediate representation before it is consumed by XLA. If that is the case, is it viable to emit that StableHLO when using JAX as the backend? This would make it possible to compile the same model with IREE instead of XLA.

Existing flow: Keras Core model --> jax.jit (XLA)
Desired flow: Keras Core model --> jax.jit --> side output: StableHLO --> IREE
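At the plain-JAX level, the "side output" step can be sketched with JAX's ahead-of-time lowering API: `jax.jit(f).lower(...)` traces and lowers a function without compiling it, and `.as_text()` returns the lowered module as text (StableHLO in recent JAX releases). This is a minimal sketch, assuming the model's forward pass can be written as a pure function of its weights and inputs; `dense_forward` and the output filename `dense_forward.mlir` are hypothetical stand-ins, not part of Keras Core.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a Keras Core model's forward pass:
# one dense layer with tanh activation, written as a pure JAX function.
def dense_forward(params, x):
    w, b = params
    return jnp.tanh(x @ w + b)


# Example arguments; lowering only depends on their shapes and dtypes.
params = (jnp.ones((4, 3), jnp.float32), jnp.zeros((3,), jnp.float32))
x = jnp.ones((2, 4), jnp.float32)

# Trace and lower the jitted function without compiling it with XLA.
lowered = jax.jit(dense_forward).lower(params, x)

# In recent JAX releases the default textual form is the StableHLO MLIR dialect.
stablehlo_text = lowered.as_text()

# Hypothetical side output: save the module so another compiler (e.g. IREE) can consume it.
with open("dense_forward.mlir", "w") as f:
    f.write(stablehlo_text)

print(stablehlo_text)
```

For an actual Keras model, the same lowering could in principle be applied to a stateless forward function (for example, Keras 3's `stateless_call` on the JAX backend), with the saved `.mlir` file then passed to IREE's `iree-compile` tool. Whether Keras Core exposes this directly is exactly what this issue asks.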

Will this change the current api? How?

I am not sure.

Who will benefit from this feature?

IREE users.

Contributing

  • Do you want to contribute a PR? (yes/no): no
  • If yes, please read this page for instructions
  • Briefly describe your candidate solution (if contributing):

Thanks for filing the issue. We currently support model.export. We plan to add ONNX and StableHLO export support in the future.

Can the export support generate PyTorch and JAX programs from the original Keras Core models?