`CuArray` has no field storage (PyCall, CUDA, Jax)
samuelfneumann opened this issue · 1 comment
Hi,
I'm trying to convert a Julia CuArray to a JAX array using DLPack.jl. Here is a minimal example that reproduces the problem:
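using DLPack, CUDA, PyCall
dl = pyimport("jax.dlpack")                  # JAX's DLPack interop module
to_jax(o) = DLPack.share(o, dl.from_dlpack)  # hand the GPU buffer to JAX via DLPack
arr = cu(rand(10))                           # 10-element CuArray{Float32} on the GPU
arr |> to_jax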
When I run exactly this code, I get the following error:
ERROR: type CuArray has no field storage
Stacktrace:
[1] getproperty
@ ./Base.jl:37 [inlined]
[2] dldevice(B::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ DLPack ~/.julia/packages/DLPack/SUhao/src/cuda.jl:12
[3] unsafe_share(A::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ DLPack ~/.julia/packages/DLPack/SUhao/src/DLPack.jl:253
[4] share(A::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ DLPack ~/.julia/packages/DLPack/SUhao/src/cuda.jl:4
[5] share(A::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, #unused#::Type{PyObject}, from_dlpack::PyObject)
@ DLPack ~/.julia/packages/DLPack/SUhao/src/pycall.jl:99
[6] share(A::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, from_dlpack::PyObject)
@ DLPack ~/.julia/packages/DLPack/SUhao/src/pycall.jl:89
[7] to_jax(o::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Main ./REPL[3]:1
[8] |>(x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, f::typeof(to_jax))
@ Base ./operators.jl:907
[9] top-level scope
@ REPL[10]:1
[10] top-level scope
@ ~/.julia/packages/CUDA/YIj5X/src/initialization.jl:208
The error appears to originate at line 12 of src/cuda.jl. After some experimentation, I was able to convert the CuArray to JAX by overriding DLPack.dldevice(B::CUDA.StridedCuArray) with:
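# Read the buffer type from CuArray's third type parameter,
# since the `storage` field no longer exists.
buftype(::CuArray{F,I,B}) where {F,I,B} = B
function DLPack.dldevice(B::CUDA.StridedCuArray)
    # StridedCuArray also covers views/reshapes; unwrap to the underlying CuArray
    A = parent(B)
    buf = buftype(A)
    dldt = if buf isa Type{CUDA.Mem.DeviceBuffer}
        DLPack.kDLCUDA
    elseif buf isa Type{CUDA.Mem.HostBuffer}
        DLPack.kDLCUDAHost
    elseif buf isa Type{CUDA.Mem.UnifiedBuffer}
        DLPack.kDLCUDAManaged
    else
        error("unsupported CUDA buffer type: $buf")
    end
    return DLPack.DLDevice(dldt, CUDA.device(A))
end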
Looking through older versions of CUDA.jl, it appears that the storage field was removed from CuArray, which would explain this error. I'm not completely certain that my fix is correct, though, so I'd appreciate it if someone could point me in the right direction and let me know whether I should open a PR.
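With the storage field gone, the workaround instead recovers the buffer kind from the third type parameter of CuArray (visible in the stack trace as CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}). The CPU-only sketch below illustrates that dispatch pattern without needing a GPU. FakeGPUArray, DeviceBuf, HostBuf, UnifiedBuf and devicekind are hypothetical stand-ins, not CUDA.jl or DLPack.jl names. It also shows a design alternative to the if/elseif chain: one method per buffer kind, so an unsupported kind raises a MethodError rather than producing nothing.

```julia
# Hypothetical marker types standing in for CUDA.Mem.DeviceBuffer etc.
abstract type AbstractBuffer end
struct DeviceBuf <: AbstractBuffer end
struct HostBuf <: AbstractBuffer end
struct UnifiedBuf <: AbstractBuffer end

# Hypothetical array type whose third parameter records the buffer kind,
# mirroring the CuArray{T,N,B} layout seen in the stack trace.
struct FakeGPUArray{T,N,B<:AbstractBuffer} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(A::FakeGPUArray) = size(A.data)
Base.getindex(A::FakeGPUArray, i::Int...) = A.data[i...]

# Recover the buffer type from the type parameters, not from a field.
buftype(::FakeGPUArray{T,N,B}) where {T,N,B} = B

# One method per buffer kind (hypothetical device tags).
devicekind(::Type{DeviceBuf}) = :cuda
devicekind(::Type{HostBuf}) = :cuda_host
devicekind(::Type{UnifiedBuf}) = :cuda_managed
devicekind(A::FakeGPUArray) = devicekind(buftype(A))

A = FakeGPUArray{Float32,1,DeviceBuf}(rand(Float32, 10))
devicekind(A)  # :cuda
```

Because the buffer kind lives in the type, this lookup is resolved by dispatch and does not depend on the struct's internal field names, which is what broke in the original dldevice.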
I'm using the following package versions:
[438e738f] PyCall v1.96.3
[53c2dc0f] DLPack v0.1.2
[052768ef] CUDA v5.1.1
Hi @samuelfneumann. Thanks for the report! I have a fix for this on #34.