Registering an XLA translation rule fails
llCurious opened this issue · 1 comment
Hey, I'm trying to add a custom call and define its XLA translation rule by following this doc: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#xla-compilation-rules
However, the doc doesn't cover the custom call part, so I tried to implement that part by following your example code.
Your example uses functools.partial(translation, platform="cpu"), but when I do the same I get the error func_wrapper() got an unexpected keyword argument 'platform'. The relevant lines are:
https://github.com/dfm/extending-jax/blob/main/src/kepler_jax/kepler_jax.py#L190-L193
Could you please give some suggestions?
P.S. I use Cython to implement the C++ XLA custom call function, so the only remaining part is registering the XLA translation rule.
That platform="cpu" argument in the translation rule is specific to this tutorial, where we write one function that can generate the translation rule for both the GPU and the CPU:
extending-jax/src/kepler_jax/kepler_jax.py, lines 59 to 62 (at commit 77c3466)
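To see why the error appears, here is a minimal pure-Python sketch that does not import JAX. The names translation, single_backend_rule, cpu_rule, gpu_rule and bad_rule, and the strings they return, are hypothetical stand-ins for real translation rules. functools.partial only stores the platform keyword; it is forwarded to the wrapped function when the rule is invoked, so a function that has no platform parameter fails with a TypeError at call time, not when the partial is created.

```python
import functools


# Hypothetical multi-backend rule: `platform` selects which
# (made-up) custom-call target to emit, as in the tutorial's pattern.
def translation(operand, *, platform="cpu"):
    if platform == "cpu":
        return f"cpu_custom_call({operand})"
    if platform == "gpu":
        return f"gpu_custom_call({operand})"
    raise ValueError(f"unsupported platform: {platform}")


# functools.partial pre-binds `platform`, giving one rule per backend.
cpu_rule = functools.partial(translation, platform="cpu")
gpu_rule = functools.partial(translation, platform="gpu")


# Hypothetical single-backend rule: it has no `platform` parameter.
def single_backend_rule(operand):
    return f"cpu_custom_call({operand})"


# Creating this partial succeeds; calling it raises TypeError because
# `platform` is forwarded to a function that does not accept it.
bad_rule = functools.partial(single_backend_rule, platform="cpu")

if __name__ == "__main__":
    print(cpu_rule("x"))  # cpu_custom_call(x)
    print(gpu_rule("x"))  # gpu_custom_call(x)
    try:
        bad_rule("x")
    except TypeError as err:
        print("TypeError:", err)
```

The fix is either to give the rule a platform keyword (if one function really serves several backends) or to skip the partial and register the rule directly.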
In your case, you probably have just a single translation rule, so you can drop the partial entirely:
xla.backend_specific_translations["cpu"][YOUR_PRIMITIVE] = YOUR_TRANSLATION_RULE
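The registry that this one-liner writes into can be pictured with the sketch below. It is a hypothetical stand-in, not JAX itself: backend_specific_translations here is a local defaultdict, Primitive is a placeholder class, my_primitive and my_translation_rule play the roles of YOUR_PRIMITIVE and YOUR_TRANSLATION_RULE, and lower is an invented helper that looks up and calls the rule roughly the way a dispatcher might. The point is that the rule is stored per platform and invoked without any platform keyword, which is why a single-backend rule needs no partial.

```python
from collections import defaultdict

# Hypothetical stand-in for the translation registry: outer key is the
# platform name, inner key is the primitive object.
backend_specific_translations = defaultdict(dict)


class Primitive:
    """Placeholder for a JAX primitive; only carries a name here."""

    def __init__(self, name):
        self.name = name


my_primitive = Primitive("my_custom_call")  # hypothetical primitive


# Hypothetical single-backend rule: first argument stands in for the
# XLA builder, and there is no `platform` keyword.
def my_translation_rule(builder, operand):
    # A real rule would emit a custom call through `builder`;
    # this sketch only returns a description of that call.
    return ("custom_call", "cpu_target", operand)


# Registration is a plain assignment, mirroring the answer's one-liner.
backend_specific_translations["cpu"][my_primitive] = my_translation_rule


def lower(platform, prim, *args):
    """Look up the rule for (platform, prim) and call it with args."""
    rule = backend_specific_translations[platform].get(prim)
    if rule is None:
        raise NotImplementedError(f"no {platform} rule for {prim.name}")
    return rule(*args)
```

Calling lower("cpu", my_primitive, None, "x") returns the tuple produced by my_translation_rule, while asking for "gpu" raises NotImplementedError because no rule was registered for that platform.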