JuliaGPU/Metal.jl

Unable to compile trig functions through ForwardDiff

Status: closed · 4 comments

MWE:

using Metal, KernelAbstractions, ForwardDiff

X = Metal.MtlArray(fill(0.1f0, 128))
Y = copy(X)

@kernel function mwe_kernel(out, a)
    I = @index(Global, Linear)
    out[I] = ForwardDiff.derivative(sin, a[I])
end

kernel = mwe_kernel(Metal.MetalBackend())
kernel(Y, X, ndrange = size(Y))

Output:

ERROR: InvalidIRError: compiling MethodInstance for gpu_mwe_kernel(::KernelAbstractions.CompilerMetadata{KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicCheck, Nothing, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, KernelAbstractions.NDIteration.NDRange{1, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicSize, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}}}, ::MtlDeviceVector{Float32, 1}, ::MtlDeviceVector{Float32, 1}) resulted in invalid LLVM IR
Reason: unsupported call to an unknown function (call to gpu_malloc)
Stacktrace:
  [1] malloc
    @ ~/.julia/packages/GPUCompiler/YO8Uj/src/runtime.jl:88
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/YO8Uj/src/runtime.jl:183
  [3] macro expansion
    @ ./none:0
  [4] box
    @ ./none:0
  [5] box_int64
    @ ~/.julia/packages/GPUCompiler/YO8Uj/src/runtime.jl:212
  [6] indexed_iterate
    @ ./tuple.jl:97
  [7] sin
    @ ~/Developer/jl-forward-diff/dev/ForwardDiff/src/dual.jl:697
  [8] derivative
    @ ~/Developer/jl-forward-diff/dev/ForwardDiff/src/derivative.jl:14
  [9] macro expansion
    @ ~/Developer/jl-forward-diff/mwe.jl:8
 [10] gpu_mwe_kernel
    @ ~/.julia/packages/KernelAbstractions/cWlFz/src/macros.jl:90
 [11] gpu_mwe_kernel
    @ ./none:0

I don't know the first thing about how the GPU compiler works (though I'm keen to learn), but I had a look with Cthulhu.jl all the same. Using @device_code_llvm, I noticed that when the kernel contains just sincos(a[I])[1] (without the ForwardDiff.derivative call), I get

%29 = call float @air.sincos.f32(float %28)

But with ForwardDiff.derivative(sin, a[I]), using code_typed(err; interactive = true), I see:

%3 = call [2 x float] @j_sincos_4567(float %2) #0

which looks to me like it's calling the Julia sincos function instead of the Metal (AIR) one?

I don't know how to diagnose this further or how to fix it, but I'm happy to help in any way I can!

Status `~/Developer/jl-forward-diff/Project.toml`
  [f6369f11] ForwardDiff v0.10.35
  [63c18a36] KernelAbstractions v0.9.8
  [dde4c033] Metal v0.5.0

Julia Version 1.9.2
Commit e4ee485e909 (2023-07-05 09:39 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 8 × Apple M1 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, apple-m1)
  Threads: 12 on 6 virtual cores
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 

Edit: the problem is specifically the sincos call inside ForwardDiff's sin rule. If I modify ForwardDiff to call sin and cos separately, everything works fine.
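For illustration, a minimal sketch with a hypothetical hand-rolled dual number (MiniDual, not ForwardDiff's actual implementation): sin_via_sincos destructures a single sincos call, the pattern the stack trace points to (sin → indexed_iterate), while sin_separately calls sin and cos on their own, as in the workaround described above. On the CPU both give the same value and derivative; only the first depends on sincos returning a tuple.

```julia
# Hypothetical minimal dual number: a value plus its derivative.
struct MiniDual{T<:Real}
    val::T
    der::T
end

# Shaped like the failing path: one sincos call, destructured into s and c.
# This destructuring is what broke on the GPU when sincos returned a scalar.
function sin_via_sincos(d::MiniDual)
    s, c = sincos(d.val)
    return MiniDual(s, c * d.der)
end

# Shaped like the workaround: sin and cos computed separately, no destructuring.
sin_separately(d::MiniDual) = MiniDual(sin(d.val), cos(d.val) * d.der)

# Seed the derivative with 1 to get d/dx sin(x) = cos(x) at x = 0.1f0.
x = MiniDual(0.1f0, 1.0f0)
a = sin_via_sincos(x)
b = sin_separately(x)
```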

I guess this is essentially a duplicate of #69. The question is why an exception is being generated here, as we do support sincos:

@device_override FastMath.sincos_fast(x::Float32) = ccall("extern air.fast_sincos.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.sincos(x::Float32) = ccall("extern air.sincos.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.sincos(x::Float16) = ccall("extern air.sincos.f16", llvmcall, Float16, (Float16,), x)

Can you use Cthulhu to see how sincos is invoked?

Ah, sorry, I didn't notice this the first time. It turns out it's related to indexed_iterate:

New MWE:

using Metal, KernelAbstractions

X = Metal.MtlArray(fill(0.3f0, 128))
Y = copy(X)

@kernel function mwe_kernel_sincos(out, a)
    I = @index(Global, Linear)
    s, c = sincos(a[I])
    out[I] = s + c
end

kernel = mwe_kernel_sincos(Metal.MetalBackend())
kernel(Y, X, ndrange = size(Y))
ERROR: InvalidIRError: compiling MethodInstance for gpu_mwe_kernel_sincos(::KernelAbstractions.CompilerMetadata{KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicCheck, Nothing, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, KernelAbstractions.NDIteration.NDRange{1, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicSize, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}}}, ::MtlDeviceVector{Float32, 1}, ::MtlDeviceVector{Float32, 1}) resulted in invalid LLVM IR
Reason: unsupported call to an unknown function (call to gpu_malloc)
Stacktrace:
 [1] malloc
   @ ~/.julia/packages/GPUCompiler/YO8Uj/src/runtime.jl:88
 [2] macro expansion
   @ ~/.julia/packages/GPUCompiler/YO8Uj/src/runtime.jl:183
 [3] macro expansion
   @ ./none:0
 [4] box
   @ ./none:0
 [5] box_int64
   @ ~/.julia/packages/GPUCompiler/YO8Uj/src/runtime.jl:212
 [6] indexed_iterate                                                            <--
   @ ./tuple.jl:97
 [7] macro expansion
   @ ~/Developer/jl-forward-diff/mwe.jl:9
 [8] gpu_mwe_kernel_sincos
   @ ~/.julia/packages/KernelAbstractions/cWlFz/src/macros.jl:90
 [9] gpu_mwe_kernel_sincos
   @ ./none:0

This throws what is ostensibly the same error as above. So, instead, trying

@kernel function mwe_kernel_sincos(out, a)
    I = @index(Global, Linear)
    k = sincos(a[I])
    out[I] = k[1] + k[2]
end

throws no error, but only k[1] is non-zero; that is, k[2] doesn't seem to have a value at all?

It seems to me that the Metal sincos only returns a single float, the sine part? The external calls shown by @*code_warntype seem to confirm this:

Edit: some examples

Kernel:

@kernel function mwe_kernel_sincos(out, a)
    I = @index(Global, Linear)
    k = sincos(a[I])
    out[I] = k[1]
end
41 ─ %119 = $(Expr(:foreigncall, "extern air.sincos.f32", Float32, svec(Float32), 0, :(:llvmcall), :(%116), :(%116)))::Float32
└───        goto #43 if not true
42 ─        nothing::Nothing
43 ┄        goto #44
44 ─        goto #49 if not true
45 ─ %124 = Core.tuple(%92)::Tuple{UInt32}
│    %125 = Base.getfield(out, :shape)::Tuple{Int64}
│    %126 = Base.getfield(%125, 1, true)::Int64

Kernel:

@kernel function mwe_kernel_sincos(out, a)
    I = @index(Global, Linear)
    k = sincos(a[I])
    out[I] = k[2]
end
41 ─ %119 = $(Expr(:foreigncall, "extern air.sincos.f32", Float32, svec(Float32), 0, :(:llvmcall), :(%116), :(%116)))::Float32
└───        goto #43 if not true
42 ─        Metal.throw(Metal.nothing)::Union{}
└───        unreachable
43 ─        goto #44
44 ─        goto #49 if not true
45 ─ %125 = Core.tuple(%92)::Tuple{UInt32}
│    %126 = Base.getfield(out, :shape)::Tuple{Int64}
│    %127 = Base.getfield(%126, 1, true)::Int64
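Both symptoms can be reproduced on the CPU without a GPU. Because the old override declared a single Cfloat return value, k was a plain Float32 rather than a tuple. A minimal sketch, using sin(x) as a hypothetical stand-in for that scalar: k[1] silently returns the scalar, while k[2] and s, c = k both throw a BoundsError. In device code that exception cannot be thrown normally, which appears to be why it shows up as Metal.throw(nothing) in the typed IR above, and as a boxed index (box_int64 → gpu_malloc) under indexed_iterate in the stack traces.

```julia
# Stand-in for what the old override produced: a single Float32 instead of a
# (sin, cos) tuple. (Hypothetical: the real value came from air.sincos.f32.)
x = 0.3f0
k = sin(x)

# A Number acts like a one-element collection, so k[1] simply returns k ...
@assert k[1] == k

# ... but k[2] is out of bounds and throws a BoundsError.
index_threw = try
    k[2]
    false
catch err
    err isa BoundsError
end
@assert index_threw

# Destructuring iterates k; the second iterate call returns nothing, so
# Base.indexed_iterate throws a BoundsError for index 2.
destructure_threw = try
    first_val, second_val = k
    false
catch err
    err isa BoundsError
end
@assert destructure_threw

# With a genuine (sin, cos) tuple, both forms work.
s, c = sincos(x)
@assert s ≈ sin(x) && c ≈ cos(x)
```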

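From the Metal developer API: in the Metal Shading Language, sincos is declared as T sincos(T x, thread T &cosval), i.e. it returns the sine and writes the cosine through a reference argument.

[Screenshot (2023-08-01) of the Metal sincos documentation]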

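So changing the Float32 override to

@device_override function Base.sincos(x::Float32)
    # air.sincos.f32 returns the sine and writes the cosine through the pointer argument
    c = Ref{Cfloat}()
    s = ccall("extern air.sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c)
    (s, c[])  # return a (sin, cos) tuple, like Base.sincos
end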

fixes everything.

I will open a PR with the fixes :)
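The out-parameter convention behind the fix can be tried on the CPU as well. A minimal sketch, assuming a hypothetical Julia function sincos_outparam (exposed through @cfunction) as a stand-in for air.sincos.f32, which only exists inside Metal kernels. The wrapper sincos_tuple mirrors the fixed override: it passes a Ref{Cfloat} that ccall converts to a Ptr{Cfloat} and keeps alive for the duration of the call, then returns both results as a tuple.

```julia
# Hypothetical CPU stand-in for air.sincos.f32 (which only exists in Metal
# kernels): it returns the sine and stores the cosine through a pointer.
function sincos_outparam(x::Cfloat, cosval::Ptr{Cfloat})::Cfloat
    unsafe_store!(cosval, cos(x))
    return sin(x)
end

# C-callable pointer to the stand-in, so it can be called via ccall with the
# same argument types as the fixed override: (Cfloat, Ptr{Cfloat}).
const sincos_outparam_ptr = @cfunction(sincos_outparam, Cfloat, (Cfloat, Ptr{Cfloat}))

# Same shape as the fixed override: a Ref receives the out-parameter, and the
# two results are returned as a tuple so that `s, c = ...` destructuring works.
function sincos_tuple(x::Float32)
    c = Ref{Cfloat}()
    s = ccall(sincos_outparam_ptr, Cfloat, (Cfloat, Ptr{Cfloat}), x, c)
    return (s, c[])
end

s, c = sincos_tuple(0.3f0)
```

Returning a real tuple is what lets callers such as ForwardDiff's sin rule destructure the result.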