tensorflow/mlir-hlo

MLIR in TensorFlow training?

Cuttstage opened this issue · 5 comments

Hi, is it possible to use MLIR to generate LLVM IR for model training, e.g. with GPU support?
I cannot find any code in TensorFlow that uses MLIR and then hands the result back to the TensorFlow executor. Does that mean MLIR is only useful for inference? I wonder whether TensorFlow could lower some ops to IR, process them, and then merge the IR results back into the TensorFlow process.

MLIR is used in Grappler now, which is plugged into the executor. MLIR is also used more and more to implement the XLA CPU/GPU compiler (which emits LLVM), and it is used for both inference and training (it powers JAX for example).

You'd have to be more specific about what you're trying to achieve, though. So far, MLIR has been used for fairly low-level pieces of infrastructure inside the TensorFlow/XLA ecosystem.
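As a rough illustration of the JAX/XLA path described above, the sketch below uses JAX's ahead-of-time lowering API, `jax.jit(...).lower(...)`, to print the MLIR module that XLA receives before it compiles the function. This assumes a recent JAX release, where the printed dialect is StableHLO (the successor to MHLO). Older releases printed MHLO. The function `predict` and its shapes are hypothetical and only serve as an example.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Hypothetical tiny dense layer: matmul followed by tanh.
    return jnp.tanh(x @ w)

w = jnp.ones((3, 2))
x = jnp.ones((4, 3))

# Tracing + lowering produces an MLIR module (StableHLO in recent JAX,
# MHLO in older releases) before XLA compiles it for CPU/GPU/TPU.
lowered = jax.jit(predict).lower(w, x)
ir_text = lowered.as_text()
print(ir_text)

# The lowered module can then be compiled by XLA and executed directly.
compiled = lowered.compile()
out = compiled(w, x)
print(out.shape)  # (4, 2)
```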

Seems like you have some misconfigured auto-reply here :)


Thanks for your reply. I found some code related to TFG in the latest TensorFlow repo. I am new to MLIR and hope to use this feature in our model training. :)

MHLO (which is what this repo contains) is the native IR of JAX, which is used heavily for training (on CPU, GPU, and TPU via XLA). However, "training" can mean many things. For example, here is a prototype of a new API we are working on for saving an entire JAX training program (in this case, a simple MNIST model) so that it can be run offline via IREE on single CPU/GPU systems (which happens to include a large swath of mobile and embedded devices):

This is but one example of a training setup. If you are looking for distributed training, that is a more advanced topic. Also, the above is just a prototype: we are looking to finish it for everyone to use in the coming months.
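To make the point about JAX training concrete, here is a minimal sketch of a JAX training step. It uses plain SGD on a synthetic linear-regression problem, chosen only for illustration. It is not the MNIST prototype or the IREE export API mentioned above. The whole step (forward pass, gradient, and parameter update) is traced by `jax.jit`, so it goes through the same MLIR lowering and XLA compilation as inference code. The names `loss_fn`, `train_step`, and `LEARNING_RATE` are hypothetical.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical hyperparameter for this sketch

def loss_fn(params, x, y):
    # Mean squared error of a linear model: pred = x @ w + b.
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y):
    # Forward pass and gradient in one call.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD update applied to every array in the params tuple.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss

# Synthetic data: y = x @ true_w + 0.3 (made up for illustration).
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
true_w = jnp.array([[1.0], [-2.0], [0.5]])
y = x @ true_w + 0.3

params = (jnp.zeros((3, 1)), jnp.zeros((1,)))
losses = []
for _ in range(100):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))
print("first loss:", losses[0], "last loss:", losses[-1])

# The training step lowers to an MLIR module just like an inference function.
train_ir = train_step.lower(params, x, y).as_text()
print(train_ir[:200])
```

Because the gradient computation and the update live inside one jitted function, XLA sees and optimizes the full training step as a single program rather than op by op.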


Thanks a lot, we will give it a try. But right now our production model runs on TF, and moving it from TF to JAX would be a lot of work.