dfm/extending-jax

XLA translation rule registration fails

llCurious opened this issue · 1 comment

Hey, I'm trying to add a custom call and define its XLA translation rule by following this doc: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#xla-compilation-rules

However, the doc doesn't cover the custom call part, so I tried to implement that part by following your example code.

You use `functools.partial(translation, platform="cpu")` here, but when I do the same I get the error `func_wrapper() got an unexpected keyword argument 'platform'`:
https://github.com/dfm/extending-jax/blob/main/src/kepler_jax/kepler_jax.py#L190-L193
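The error itself can be reproduced in plain Python. This is a minimal sketch, assuming a translation rule whose signature has no `platform` parameter (`my_translation` is a hypothetical name). `functools.partial` does not check keywords against the wrapped function when it is created, so the failure only shows up when the rule is actually called, which presumably happens when JAX invokes the rule during compilation.

```python
import functools

# Hypothetical single-platform translation rule: its signature has no
# `platform` keyword.
def my_translation(c, x):
    return ("custom_call", x)

# Building the partial succeeds; partial does not validate keywords
# against the wrapped function's signature at construction time.
rule = functools.partial(my_translation, platform="cpu")

# The TypeError only appears once the rule is invoked.
message = None
try:
    rule(None, 1.0)
except TypeError as err:
    message = str(err)

print(message)
```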

Could you please give some suggestions?

P.S. I'm using Cython to implement the C++ XLA custom call function, so the only remaining part is registering the XLA translation rule.

dfm commented

The `platform="cpu"` argument in the translation rule is specific to this tutorial, where we write one function that can generate the translation rule for both the CPU and the GPU:

```python
# We also need a translation rule to convert the function into an XLA op. In
# our case this is the custom XLA op that we've written. We're wrapping two
# translation rules into one here: one for the CPU and one for the GPU
def _kepler_translation(c, mean_anom, ecc, *, platform="cpu"):
    ...  # body omitted; see src/kepler_jax/kepler_jax.py
```
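Here is a rough sketch of how one function with a `platform` keyword can serve two backends. To keep it runnable without JAX, a plain `defaultdict` stands in for `xla.backend_specific_translations` (an API from older JAX releases that may not exist in the version you have installed), and the rule body and custom-call target names (`kepler_cpu_kernel`, `kepler_gpu_kernel`) are hypothetical, not the tutorial's actual code.

```python
import functools
from collections import defaultdict

# Stand-in for a per-backend registry such as
# xla.backend_specific_translations: platform -> {primitive: rule}.
backend_specific_translations = defaultdict(dict)

# Hypothetical primitive key.
KEPLER_PRIMITIVE = "kepler"

def _kepler_translation(c, mean_anom, ecc, *, platform="cpu"):
    # In a real rule, `platform` would select which custom call to emit;
    # here it just picks a hypothetical target name.
    target = f"kepler_{platform}_kernel"
    return (target, mean_anom, ecc)

# One function, two registered rules: partial fixes `platform` per backend.
for platform in ("cpu", "gpu"):
    backend_specific_translations[platform][KEPLER_PRIMITIVE] = functools.partial(
        _kepler_translation, platform=platform
    )

cpu_rule = backend_specific_translations["cpu"][KEPLER_PRIMITIVE]
gpu_rule = backend_specific_translations["gpu"][KEPLER_PRIMITIVE]
print(cpu_rule(None, 0.1, 0.2)[0])  # kepler_cpu_kernel
print(gpu_rule(None, 0.1, 0.2)[0])  # kepler_gpu_kernel
```

Using `functools.partial` instead of a `lambda` inside the loop binds each `platform` value immediately, which sidesteps Python's late-binding closure pitfall.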

In your case, you probably have just a single translation rule, so you can drop the `partial` and register the rule directly:

```python
from jax.interpreters import xla  # older JAX releases

xla.backend_specific_translations["cpu"][YOUR_PRIMITIVE] = YOUR_TRANSLATION_RULE
```
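As a minimal sketch of the single-rule case, assuming a CPU-only rule: the function is stored in the registry as-is, with no `platform` keyword and no `functools.partial`. A plain `defaultdict` stands in for `xla.backend_specific_translations` so this runs without JAX, and `MY_PRIMITIVE`, `my_translation`, and `my_cpu_kernel` are hypothetical placeholders for `YOUR_PRIMITIVE`, `YOUR_TRANSLATION_RULE`, and your custom call target.

```python
from collections import defaultdict

# Stand-in for xla.backend_specific_translations: platform -> {primitive: rule}.
backend_specific_translations = defaultdict(dict)

# Hypothetical primitive key standing in for YOUR_PRIMITIVE.
MY_PRIMITIVE = "my_cython_op"

def my_translation(c, x):
    # A CPU-only rule needs no `platform` keyword; it would emit the
    # single custom call directly (hypothetical target name here).
    return ("my_cpu_kernel", x)

# Register the function object itself; no partial is involved.
backend_specific_translations["cpu"][MY_PRIMITIVE] = my_translation

rule = backend_specific_translations["cpu"][MY_PRIMITIVE]
print(rule(None, 3.0))  # ('my_cpu_kernel', 3.0)
```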