JuliaGPU/Metal.jl

Port `accumulate!` and `findall` from CUDA.jl


Hi there, this is a really great package. Thanks for your great efforts!

My question is:
Are there any plans to support functions like accumulate!, cumsum, and findall in Metal.jl, as in CUDA.jl?
Or could you point me to any resources that I can follow to write these functions in Metal.jl?

Thanks!

Hi zhenwu0728, thank you for your interest!

The Metal Performance Shaders library seems to provide functionality for cumsum. I have a local branch with the library wrappers implemented, but I won't be returning to it until I finish my thesis, so feel free to attempt it yourself. There are a few recent PRs that you could use as a template.

Another option for the cumsum and the others is to adapt the CUDA.jl implementations for Metal.jl.

Or, alternatively, port them to GPUArrays.jl which would make them available to all back-ends.

@maleadt Any ideas on when this will happen? I've seen similar code for accumulate in CUDA.jl and AMDGPU.jl.

Any ideas when this will happen?

Probably only after GPUArrays.jl migrates to KernelAbstractions.jl, which is still some weeks to months off. So feel free to take a stab at a native (i.e. in Metal.jl) implementation first if you require this functionality.
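
For anyone wanting to take a stab at that, here is a very rough, purely illustrative sketch of what a hand-written Metal.jl kernel for a row-wise cumsum (dims=2) could look like; the names cumsum2_kernel! and cumsum2_native are made up, and the launch configuration is only a guess:

using Metal

# Hypothetical kernel: each thread owns one row and scans it serially.
function cumsum2_kernel!(B, A)
    row = thread_position_in_grid_1d()
    if row <= size(A, 1)
        acc = zero(eltype(B))
        for k in 1:size(A, 2)
            @inbounds acc += A[row, k]
            @inbounds B[row, k] = acc
        end
    end
    return
end

function cumsum2_native(A::MtlMatrix{Float32})
    B = similar(A)
    nrows = size(A, 1)
    groupsize = min(nrows, 256)      # threads per threadgroup
    ngroups = cld(nrows, groupsize)  # enough groups to cover every row
    @metal threads=groupsize groups=ngroups cumsum2_kernel!(B, A)
    return B
end

# Usage: cumsum2_native(MtlArray(randn(Float32, 101, 8)))

A serial scan per row like this is only reasonable when there are many more rows than columns; a proper implementation would use a parallel scan within each threadgroup.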

If anyone else encounters this problem and really needs cumsum, I implemented a very simple version using GPUArrays.jl that does cumsum(x; dims=2) (the only version I needed, but it could easily be modified for dims=1; a sketch of that variant follows the snippet below):

# cumsum(x; dims=2), thanks to https://pde-on-gpu.vaw.ethz.ch/lecture10/
using GPUArrays

function cumsum2(A::AnyGPUMatrix{T}) where {T}
    B = similar(A)
    # One thread per row; each thread scans its row serially.
    gpu_call(B, A; name="cumsum!", elements=size(A, 1)) do ctx, B, A
        idx = @cartesianidx B
        i, j = Tuple(idx)  # only the row index i is used
        cur_val = zero(T)
        for k in 1:size(A, 2)
            @inbounds cur_val += A[i, k]
            @inbounds B[i, k] = cur_val
        end
        return
    end
    return B
end
# Potential improvements: use shared memory and tune block/grid sizes
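
For completeness, the dims=1 variant mentioned above could look roughly like this (a minimal sketch along the same lines, with the made-up name cumsum1): one thread per column, scanning down the rows.

# Hypothetical dims=1 variant: one thread per column, scanning down the rows.
function cumsum1(A::AnyGPUMatrix{T}) where {T}
    B = similar(A)
    gpu_call(B, A; name="cumsum!", elements=size(A, 2)) do ctx, B, A
        j = linear_index(ctx)
        j > size(A, 2) && return  # guard against extra threads in the launch
        cur_val = zero(T)
        for k in 1:size(A, 1)
            @inbounds cur_val += A[k, j]
            @inbounds B[k, j] = cur_val
        end
        return
    end
    return B
end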

At least in my tests, it seems to be comparable to CUDA's cumsum and multiple times faster than doing gpu(cumsum(cpu(x))).

# Benchmarking
using BenchmarkTools, CUDA, Statistics
# `gpu`/`cpu` here are assumed to be Flux's device-transfer helpers.
b1 = @benchmark CUDA.@sync cumsum2($A)                  # My cumsum
b2 = @benchmark CUDA.@sync cumsum($A; dims=2)           # CUDA
b3 = @benchmark CUDA.@sync gpu(cumsum(cpu($A); dims=2)) # Current workaround
speedup_compared_to_cuda = mean(b2.times) / mean(b1.times)
speedup_compared_to_current = mean(b3.times) / mean(b1.times)

# A = randn(Float32, 101, 8) |> gpu
# speedup_compared_to_cuda = 1.1598880188052891
# speedup_compared_to_current = 1.7993942371142466

# A = randn(Float32, 10_001, 800) |> gpu
# speedup_compared_to_cuda = 3.420802824285882
# speedup_compared_to_current = 89.42437204572116

This is also an issue for oneAPI.jl, but the GPUArrays.jl solution should apply to both.

Resolved by #377 and #382.

Can you register a new version of the package? I would like to use these two functions. Thanks!