sx-aurora-dev/VectorEngine.jl

Returning variables passed on stack not working


On the VE (Vector Engine), kernel arguments can be passed by reference on the kernel's stack. If an argument is marked "INOUT", the stack is transferred back to the VH (Vector Host) after execution, and the original variables on the VH receive the updated values.
This does not work in VectorEngine.jl, mainly because LLVM does not consider the store to the argument that should be passed back worth keeping, and removes it. It might need something like an argument attribute (sret?).
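To make the INOUT semantics concrete, here is a host-only sketch in plain Julia that emulates copy-in/copy-out. The kernel works on a private copy, which stands in for the VE stack, and that copy is written back to the caller's object only when the argument is treated as INOUT. `Counter`, `bump!` and `run_with_inout` are hypothetical names; nothing here uses VEDA or the VectorEngine.jl API.

```julia
# Hypothetical stand-in for a struct passed on the kernel's stack.
mutable struct Counter
    m::Int64
end

# The "kernel": it mutates whatever object it is handed.
function bump!(c::Counter)
    c.m += 1
    return nothing
end

# Emulated stack-argument passing: copy in, run, optionally copy out.
function run_with_inout(kernel!, x::T; inout::Bool) where {T}
    stack_copy = deepcopy(x)          # copy-in: the kernel gets a private copy
    kernel!(stack_copy)               # only the copy is modified
    if inout
        # copy-out: write every field of the copy back to the caller's object
        for f in fieldnames(T)
            setfield!(x, f, getfield(stack_copy, f))
        end
    end
    return x
end

a = Counter(1)
run_with_inout(bump!, a; inout = false)
println(a.m)   # 1: the update stayed in the private copy
run_with_inout(bump!, a; inout = true)
println(a.m)   # 2: the update was copied back
```

With `inout = true` but a kernel whose store has been optimized away, the copy-back would simply return the unmodified value, so the caller never sees the update.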

An example of this is examples/pass_struct_on_stack.jl.

This is a "nice-to-have" feature; other accelerators don't support it.

A workaround for this issue is shown in examples/pass_struct_on_stack2.jl.
There, the kernel function is just a wrapper around an ordinary, non-inlined function, and the store to the modified variable inside that function is not optimized away.

The relevant code is:

mutable struct xm
    x::Int32
    m::Int64
end

function pass_struct!(r::xm)
    r.m = r.m + 1
    return
end
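The wrapper workaround described above can be sketched on the host as follows. This is a minimal illustration under assumptions, not the contents of examples/pass_struct_on_stack2.jl: `pass_struct_impl!` and `pass_struct_wrapper!` are hypothetical names, and on the VE the wrapper would be the function launched with @veda. `@noinline` keeps the worker function out of line, which is what the workaround relies on; whether the store actually survives still depends on the optimizer not seeing through the call.

```julia
mutable struct xm
    x::Int32
    m::Int64
end

# The real work lives in an ordinary function that is kept out of line,
# so the store to `r.m` happens inside a separate call.
@noinline function pass_struct_impl!(r::xm)
    r.m = r.m + 1
    return nothing
end

# Hypothetical kernel entry point: only a thin wrapper around the worker.
function pass_struct_wrapper!(r::xm)
    pass_struct_impl!(r)
    return nothing
end

a = xm(1, 1)
pass_struct_wrapper!(a)   # on the host this simply mutates `a`
println(a.m)              # 2
```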

The generated LLVM code for the kernel (@device_code_llvm optimize=false @veda pass_struct!(a)):

define void @_Z23julia_pass_struct__21542xm({}* nonnull byval({}) align 8 dereferenceable(16) %0) local_unnamed_addr #0 {
top:
  %r = alloca {}*, align 8
  %1 = call {}*** @julia.get_pgcstack()
  store {}* null, {}** %r, align 8
  %2 = bitcast {}*** %1 to {}**
  %current_task = getelementptr inbounds {}*, {}** %2, i64 -12
  %3 = bitcast {}** %current_task to i64*
  %world_age = getelementptr inbounds i64, i64* %3, i64 13
  store {}* %0, {}** %r, align 8
;  @ REPL[6]:2 within `pass_struct!`
; ┌ @ Base.jl:38 within `getproperty`
   %4 = load {}*, {}** %r, align 8
   %5 = bitcast {}* %4 to i8*
   %6 = getelementptr inbounds i8, i8* %5, i64 8
   %7 = bitcast i8* %6 to i64*
   %8 = load i64, i64* %7, align 8
; └
; ┌ @ int.jl:87 within `+`
   %9 = add i64 %8, 1
; └
; ┌ @ Base.jl:39 within `setproperty!`
   %10 = load {}*, {}** %r, align 8
   %11 = bitcast {}* %10 to i8*
   %12 = getelementptr inbounds i8, i8* %11, i64 8
   %13 = bitcast i8* %12 to i64*
   store i64 %9, i64* %13, align 8
; └
;  @ REPL[6]:3 within `pass_struct!`
  ret void
}

Notably, the argument is passed byval: the callee receives its own copy, so the compiler is free to simply throw away any change made to it. With optimizations enabled, the store is indeed gone:

julia> @device_code_llvm @veda pass_struct!(a)
; CompilerJob of kernel pass_struct!(xm) for GPUCompiler.VECompilerTarget
;  @ REPL[6]:1 within `pass_struct!`
define void @_Z23julia_pass_struct__21832xm({}* nonnull byval({}) align 8 dereferenceable(16) %0) local_unnamed_addr #0 {
top:
;  @ REPL[6]:3 within `pass_struct!`
  ret void
}

For reference, the relevant GPUCompiler code is at https://github.com/JuliaGPU/GPUCompiler.jl/blob/85fa183ebbede40a03c8ac35cfb309a165acaeec/src/gcn.jl#L63-L66

and https://github.com/JuliaGPU/GPUCompiler.jl/blob/3e9b441259c69026b000224639d1a003d9908ac3/src/interface.jl#L186

For comparison, the same function compiled by Julia proper:

julia> @code_llvm optimize=false pass_struct!(xm(1, 1))
;  @ REPL[2]:1 within `pass_struct!`
define nonnull {}* @"japi1_pass_struct!_190"({}* %0, {}** %1, i32 %2) #0 {
top:
  %3 = alloca {}**, align 8
  store volatile {}** %1, {}*** %3, align 8
  %4 = call {}*** @julia.get_pgcstack()
  %5 = bitcast {}*** %4 to {}**
  %current_task = getelementptr inbounds {}*, {}** %5, i64 2305843009213693940
  %6 = bitcast {}** %current_task to i64*
  %world_age = getelementptr inbounds i64, i64* %6, i64 13
  %7 = getelementptr inbounds {}*, {}** %1, i64 0
  %8 = load {}*, {}** %7, align 8
;  @ REPL[2]:2 within `pass_struct!`
; ┌ @ Base.jl:42 within `getproperty`
   %9 = bitcast {}* %8 to i8*
   %10 = getelementptr inbounds i8, i8* %9, i64 8
   %11 = bitcast i8* %10 to i64*
   %12 = load i64, i64* %11, align 8
; └
; ┌ @ int.jl:87 within `+`
   %13 = add i64 %12, 1
; └
; ┌ @ Base.jl:43 within `setproperty!`
   %14 = bitcast {}* %8 to i8*
   %15 = getelementptr inbounds i8, i8* %14, i64 8
   %16 = bitcast i8* %15 to i64*
   store i64 %13, i64* %16, align 8
   call void ({}*, ...) @julia.write_barrier({}* %8)
; └
;  @ REPL[2]:3 within `pass_struct!`
  ret {}* inttoptr (i64 139781959577608 to {}*)
}

So this is not something that GPUCompiler actively supports. I think Enzyme.jl gets around it by never compiling anything as a kernel. cc @maleadt

So you could try changing the kernel flag from true to false in

source = FunctionSpec(f, tt, true, name)

and https://github.com/JuliaGPU/GPUCompiler.jl/blob/3e9b441259c69026b000224639d1a003d9908ac3/src/interface.jl#L186

That logic should only set byval for bitstypes, so this should already work for mutable types. I'm not sure where that byval comes from, but it shouldn't warrant setting kernel=false (which disables other things, like validation, kernel state, etc.).
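The rule stated above can be checked on the host with Julia's built-in `isbitstype`. This sketch assumes that GPUCompiler's byval decision reduces to that predicate; `wants_byval` and `ImmutableXm` are hypothetical names, not part of GPUCompiler.

```julia
# Hypothetical predicate mirroring the rule: only isbits types get byval.
wants_byval(T::Type) = isbitstype(T)

mutable struct xm
    x::Int32
    m::Int64
end

# Hypothetical immutable counterpart with the same fields.
struct ImmutableXm
    x::Int32
    m::Int64
end

println(wants_byval(Int64))        # true
println(wants_byval(ImmutableXm))  # true: immutable, all fields are isbits
println(wants_byval(xm))           # false: mutable structs are never isbits
```

Since xm is mutable, an argument of that type should not be marked byval under this rule, which is consistent with suspecting a bug behind the byval({}) on the {}* argument in the kernel IR.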

I left kernel=true but replaced process_entry!(). Closing the issue for now, although I'm not 100% happy with the way it is solved.

> replaced process_entry!()

Again, the existing logic should work for your case as you describe it (only touch bitstypes, don't touch mutable ones), so if it fails there's presumably a bug in GPUCompiler we should fix instead of disabling that logic here.