dfm/extending-jax

Interested in help updating these instructions to the new style of XLA translation rules?


EiffL commented

Hey @dfm :-) I had to do a bit of hacking this week, read a lot of XLA source code, and played with the new MLIR approach for specifying translation rules for custom ops in JAX. My understanding is that since April, all built-in LAX primitives have been moved over to MLIR lowerings, and that the old-style CustomCallWithLayout only remains for backward compatibility. Here is an example of what the custom calls look like in current JAX:
https://github.com/google/jax/blob/f697b8e0876f8e1144a53ace02ee6d7eaa43fa14/jaxlib/gpu_solver.py#L66
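
For readers who have not seen the MLIR-style registration yet, here is a minimal sketch of the general shape. It is not taken from the linked file and it does not emit a real custom call (that needs a compiled kernel registered with XLA). It only registers an MLIR lowering for a toy primitive by tracing a plain JAX function with `mlir.lower_fun`. The primitive name `double` and the helper names are hypothetical, and the import fallback is an assumption about where `Primitive` lives in a given JAX version.

```python
import jax
import jax.numpy as jnp
from jax.core import ShapedArray
from jax.interpreters import mlir

# Assumption: newer JAX releases expose Primitive under jax.extend.core,
# older ones under jax.core. Try the newer location first.
try:
    from jax.extend.core import Primitive
except ImportError:
    from jax.core import Primitive

# Hypothetical toy primitive, used only to illustrate the registration steps.
double_p = Primitive("double")


def _double_impl(x):
    # Eager (non-jitted) evaluation of the primitive.
    return 2 * x


def _double_abstract_eval(x):
    # Shape and dtype inference: the output looks like the input.
    return ShapedArray(x.shape, x.dtype)


double_p.def_impl(_double_impl)
double_p.def_abstract_eval(_double_abstract_eval)

# MLIR lowering rule. A real custom op would build a custom call here;
# this sketch derives the lowering from an ordinary JAX function instead.
mlir.register_lowering(
    double_p,
    mlir.lower_fun(lambda x: x * 2, multiple_results=False),
)


def double(x):
    return double_p.bind(x)


result = jax.jit(double)(jnp.arange(3.0))
```

The piece that replaces the old translation rule is the `mlir.register_lowering` call. For an actual custom op, the second argument would be a hand-written lowering function that emits the custom call, as in the linked `gpu_solver.py`.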

Before the knowledge of how to make these things work leaves my short-term memory, would you be interested in something like a PR to this post? If you prefer these posts to stay static, no worries, I can write that info down elsewhere and link to your post for extended context ;-)

dfm commented

👋 @EiffL — Thanks for this note! Yes - that would be awesome to include here and I'd be very happy if you'd be up for writing something.