PJRT C API does not support GetDefaultLayout, yet it is always called by DLPackManagedTensorToBuffer
Opened this issue · 4 comments
Hello!

I'm trying to write a simple CUDA/Python module (coincidentally using https://github.com/wjakob/nanobind to provide the DLPack integration), and I have a method that returns a CUDA-allocated array to JAX via DLPack. However, when JAX tries to construct an ndarray from the DLPack object returned by nanobind, we hit an exception:
UNIMPLEMENTED: PJRT C API does not support GetDefaultLayout
I traced the flow of code as follows -
jax.dlpack calls
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend))
in its dlpack.py, which seems innocuous enough.
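For context, the interop I'm attempting boils down to something like this sketch (using a NumPy array as a hypothetical stand-in for the nanobind-produced tensor, since both speak DLPack):

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical stand-in for the nanobind extension's output: NumPy
# arrays also implement the DLPack protocol (__dlpack__), so the same
# consumer-side call applies.
x = np.arange(8, dtype=np.float32)

# Consumer-side entry point; on a PJRT C API (plugin) backend this is
# the path that ends in the GetDefaultLayout error above.
y = jnp.from_dlpack(x)
print(float(y.sum()))
```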
dlpack_managed_tensor_to_buffer is a wrapper around DLPackManagedTensorToBuffer, which is defined in the xla/xla/python/dlpack.cc file of the openxla project. Sadly, that function validates that the strides you passed in are equal to the default ones, just as a sanity check. How does it work out the default strides? Why, it calls (always!) GetDefaultLayout, here:

xla/xla/python/dlpack.cc, lines 418 to 423 at ac8fb1f

which the PJRT C API client leaves unimplemented:

xla/xla/pjrt/pjrt_c_api_client.h, lines 286 to 290 at 886a191
Given that this is hit directly from jax.dlpack, I'm surprised it's not implemented, and/or that I haven't found an existing issue for this.
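For what it's worth, the "default strides" being compared against are just the dense row-major (major-to-minor) strides of the shape; a rough Python sketch of that computation (my own illustration, not the actual XLA code):

```python
def default_strides(shape):
    # Dense row-major (major-to-minor) strides, in elements: the
    # innermost dimension is contiguous, and each outer stride is the
    # product of all inner dimension sizes.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

# A C-contiguous (2, 3, 4) array:
print(default_strides((2, 3, 4)))  # [12, 4, 1]
```

Any DLPack tensor whose strides don't match this pattern would fail the sanity check anyway; the issue here is that even matching strides can't be verified, because computing the baseline goes through the unimplemented GetDefaultLayout.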
I'm afraid I'm fairly new to JAX, DLPack, and nanobind, so forgive me if I'm making some dumb mistake. How would you suggest I proceed? Regardless of this particular code path, the problem I'm trying to solve is returning a CUDA array to JAX via DLPack, ideally using nanobind as a wrapper over DLPack. If there's another way to do that, that would work too :)
The TODO is attributed to @skye, so forgive me for @'ing you.
How are you installing jax? If you use the standard cuda install (https://github.com/google/jax?tab=readme-ov-file#instructions), that does not use the PJRT C API and should not have this problem.
I'm also working on fixing this in the PJRT C API so we can switch to a cuda plugin that does use the C API, but you shouldn't need to wait for that.
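One quick way to sanity-check what JAX actually loaded (just an inspection snippet; it won't by itself prove whether the PJRT C API plugin is in use, but it confirms which platform and devices JAX picked up):

```python
import jax

# Which platform JAX selected ("cpu", "gpu", or "tpu") and the devices
# it can see; with the standard cuda install you should get CUDA
# devices here served by the built-in (non-plugin) backend.
print(jax.default_backend())
print(jax.devices())
```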
Ah ok, [cuda12] is the plugin install (we kept it simple because we want this to be the one true cuda install once we fix these C API issues). Try:

pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

(cuda12_pip instead of just cuda12)
Closing this issue. Please reopen if you're still having issues!
EDIT: jk, I don't have permissions to close openxla issues :) Please let me know if there's still an issue, otherwise I think we can go ahead and close it!