pabloferz/DLPack.jl

`CuArray` has no field storage (PyCall, CUDA, Jax)

samuelfneumann opened this issue · 1 comments

Hi,

I'm trying to convert a Julia CuArray to a Jax array using DLPack.jl. Here is a minimal working example of what I am doing:

using DLPack, CUDA, PyCall

dl = pyimport("jax.dlpack")
to_jax(o) = DLPack.share(o, dl.from_dlpack)

arr = cu(rand(10))
arr |> to_jax

When I run exactly this code, I get the following error message:

ERROR: type CuArray has no field storage
Stacktrace:
  [1] getproperty
    @ ./Base.jl:37 [inlined]
  [2] dldevice(B::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ DLPack ~/.julia/packages/DLPack/SUhao/src/cuda.jl:12
  [3] unsafe_share(A::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ DLPack ~/.julia/packages/DLPack/SUhao/src/DLPack.jl:253
  [4] share(A::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ DLPack ~/.julia/packages/DLPack/SUhao/src/cuda.jl:4
  [5] share(A::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, #unused#::Type{PyObject}, from_dlpack::PyObject)
    @ DLPack ~/.julia/packages/DLPack/SUhao/src/pycall.jl:99
  [6] share(A::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, from_dlpack::PyObject)
    @ DLPack ~/.julia/packages/DLPack/SUhao/src/pycall.jl:89
  [7] to_jax(o::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Main ./REPL[3]:1
  [8] |>(x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, f::typeof(to_jax))
    @ Base ./operators.jl:907
  [9] top-level scope
    @ REPL[10]:1
 [10] top-level scope
    @ ~/.julia/packages/CUDA/YIj5X/src/initialization.jl:208

The error appears to originate on line 12 of src/cuda.jl. After some experimenting, I was able to convert the CuArray to a Jax array by overriding the dldevice(B::CUDA.StridedCuArray) method with:

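# Read the buffer type from CuArray's third type parameter instead of a field.
buftype(::CuArray{F,I,B}) where {F,I,B} = B

function DLPack.dldevice(B::CUDA.StridedCuArray)
    buf = buftype(B)

    # Map the CUDA.jl buffer type to the matching DLPack device type.
    dldt = if buf isa Type{CUDA.Mem.DeviceBuffer}
        DLPack.kDLCUDA
    elseif buf isa Type{CUDA.Mem.HostBuffer}
        DLPack.kDLCUDAHost
    elseif buf isa Type{CUDA.Mem.UnifiedBuffer}
        DLPack.kDLCUDAManaged
    else
        # Without this branch, dldt would silently be `nothing`.
        error("unsupported buffer type: $buf")
    end

    return DLPack.DLDevice(dldt, CUDA.device(B))
end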
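
The same mapping can also be expressed with dispatch instead of an if/elseif chain. Here is a minimal sketch of that idea, using stand-in types so it runs without CUDA.jl or a GPU. `MockCuArray`, the mock buffer structs, and `devicekind` are hypothetical names and are not part of CUDA.jl or DLPack.jl. The symbols only stand in for the `DLPack.kDLCUDA*` constants.

```julia
# Stand-in buffer types (hypothetical mocks of the CUDA.Mem buffer types).
abstract type AbstractBuffer end
struct DeviceBuffer <: AbstractBuffer end
struct HostBuffer <: AbstractBuffer end
struct UnifiedBuffer <: AbstractBuffer end

# Mock array whose third type parameter is the buffer type,
# mirroring the CuArray{T,N,B} layout seen in the stack trace.
struct MockCuArray{T,N,B<:AbstractBuffer}
    data::Array{T,N}
end

# Recover the buffer type from the type parameter rather than from a field.
buftype(::MockCuArray{T,N,B}) where {T,N,B} = B

# One method per buffer type; an unsupported type raises a MethodError.
devicekind(::Type{DeviceBuffer}) = :kDLCUDA
devicekind(::Type{HostBuffer}) = :kDLCUDAHost
devicekind(::Type{UnifiedBuffer}) = :kDLCUDAManaged

arr = MockCuArray{Float32,1,DeviceBuffer}(rand(Float32, 10))
devicekind(buftype(arr))  # :kDLCUDA
```

With this shape, a buffer type that has no `devicekind` method fails immediately with a MethodError rather than passing `nothing` further down.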

Looking through older versions of CUDA.jl, it appears that the storage field was removed from CuArray, which would explain the error. I'm not completely certain that my workaround is correct, though. Could someone point me in the right direction and let me know whether I should open a PR?

I'm using the following package versions:

[438e738f] PyCall v1.96.3
[53c2dc0f] DLPack v0.1.2
[052768ef] CUDA v5.1.1

Hi @samuelfneumann. Thanks for the report! I have a fix for this in #34.