openxla/xla

PJRT C API does not support GetDefaultLayout, however it is called (always) by DLPackManagedTensorToBuffer

Opened this issue · 4 comments

Hello!
I'm trying to write a simple CUDA/Python module (coincidentally using https://github.com/wjakob/nanobind to provide the DLPack integration), and I have a method that returns a CUDA-allocated array to JAX via DLPack. However, when JAX tries to construct an ndarray from the DLPack object returned by nanobind, we hit an exception:
UNIMPLEMENTED: PJRT C API does not support GetDefaultLayout
I traced the code flow as follows.
jax.dlpack calls

    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
        dlpack, cpu_backend, gpu_backend))

in its dlpack.py, which seems innocuous enough.
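For context, the DLPack hand-off I'm attempting looks roughly like this. This is a sketch using NumPy's DLPack support as a hypothetical stand-in for the nanobind-wrapped CUDA array (the real failing case is on a GPU device, which is exactly where the PJRT C API path kicks in):

```python
import numpy as np

# Hypothetical stand-in: NumPy arrays speak the same DLPack protocol
# (__dlpack__ / __dlpack_device__) that a nanobind-wrapped CUDA array
# would expose to JAX.
src = np.arange(6, dtype=np.float32).reshape(2, 3)

# Consumers such as jax.dlpack.from_dlpack (or np.from_dlpack here)
# import the buffer through the protocol without copying.
dst = np.from_dlpack(src)

print(dst.shape)  # the imported array keeps the original shape
```

On a CPU backend this works; the same pattern with a CUDA producer is what trips the GetDefaultLayout check below.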
dlpack_managed_tensor_to_buffer is a wrapper around DLPackManagedTensorToBuffer, which is defined in xla/xla/python/dlpack.cc in the openxla project. Sadly, that function validates that the strides you passed in are equal to the default ones, as a sanity check. How does it work out the default strides? Why, it calls GetDefaultLayout (always!), here

xla/xla/python/dlpack.cc

Lines 418 to 423 in ac8fb1f

// Raise an error if the resulting PjRtBuffer would have a non-default layout.
// TODO(skyewm): we do this because JAX doesn't currently have good support
// for non-default layouts, and will return wrong results if a non-default
// layout is passed to a computation expecting default layouts. Remove this
// special case when non-default layouts are better supported by JAX.
TF_ASSIGN_OR_RETURN(Layout default_layout, device->client()->GetDefaultLayout(
which leads to this definition:
StatusOr<Layout> GetDefaultLayout(PrimitiveType element_type,
                                  absl::Span<const int64_t> dims) override {
  // TODO(skyewm): implement
  return Unimplemented("PJRT C API does not support GetDefaultLayout");
}
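For what it's worth, the "default" strides that the sanity check compares against are just the row-major (major-to-minor) strides for the tensor's shape. A minimal sketch of that computation in Python (my own illustration of the idea, not the actual XLA code):

```python
def default_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for a given shape.

    This mirrors what the DLPack import path treats as the "default"
    layout when it validates the strides of an incoming tensor.
    """
    strides = [1] * len(shape)
    # Walk from the second-to-last dimension outward, accumulating sizes.
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

# A (2, 3, 4) array laid out row-major has strides [12, 4, 1] in elements.
print(default_strides([2, 3, 4]))
```

The frustrating part is that even when the incoming strides are exactly these defaults, the check can't confirm it, because computing the reference layout goes through the unimplemented GetDefaultLayout.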
'TODO: implement'. Considering that this function is called whenever jax.dlpack constructs an ndarray, I'm surprised it's not implemented, and that I haven't found an existing issue for this.

I'm afraid I'm fairly new to JAX, DLPack, and nanobind, so forgive me if I'm making some dumb mistake. How would you suggest I proceed? Regardless of this particular code path, the problem I'm trying to solve is returning a CUDA array to JAX via DLPack, ideally using nanobind as a wrapper over DLPack. If there's another way to do that, that would work too :)

The TODO is attributed to @skye, so forgive me for @'ing you.

skye commented

How are you installing jax? If you use the standard CUDA install (https://github.com/google/jax?tab=readme-ov-file#instructions), that does not use the PJRT C API and should not have this problem.

I'm also working on fixing this in the PJRT C API so we can switch to a CUDA plugin that does use the C API, but you shouldn't need to wait for that.

skye commented

Ah ok, [cuda12] is the plugin install (we kept it simple because we want this to be the one true cuda install once we fix these C API issues). Try:

pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

(cuda12_pip instead of just cuda12)

skye commented

Closing this issue. Please reopen if you're still having issues!

EDIT: jk, I don't have permissions to close openxla issues :) Please let me know if there's still an issue, otherwise I think we can go ahead and close it!