Unsupported `scatter_reduce` operators in CUDA and LLVM backends
WeiPhil opened this issue · 4 comments
Hi,
I've come across two potential issues when performing `scatter_reduce` with the `dr.ReduceOp.Max` or `dr.ReduceOp.Min` operators on the CUDA backend. Here is a minimal reproducer:
```python
import drjit as dr
from drjit.cuda import Float, UInt32
# from drjit.llvm import Float, UInt32

shape = 8
a = dr.zeros(Float, shape=shape)
b = dr.linspace(Float, start=1.0, stop=8.0, num=shape)
print("a", a)
print("b", b)

idx = dr.arange(UInt32, 0, shape)
print("idx", idx)

dr.scatter_reduce(dr.ReduceOp.Max, a, b, idx)
# dr.scatter_reduce(dr.ReduceOp.Min, a, b, idx)
print("result", a)
```
With the LLVM backend, this prints the expected result:

```
a [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
b [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
idx [0, 1, 2, 3, 4, 5, 6, 7]
result [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
```
With the CUDA backend, however, I get the following error:
Critical Dr.Jit compiler failure: jit_cuda_compile(): compilation failed. Please see the PTX assembly listing and error message below:
.version 7.1
.target sm_86
.address_size 64
.entry drjit_56971614ee2ebd5608102657da246170(.param .align 8 .b8 params[32]) {
.reg.b8 %b <9>; .reg.b16 %w<9>; .reg.b32 %r<9>;
.reg.b64 %rd<9>; .reg.f32 %f<9>; .reg.f64 %d<9>;
.reg.pred %p <9>;
mov.u32 %r0, %ctaid.x;
mov.u32 %r1, %ntid.x;
mov.u32 %r2, %tid.x;
mad.lo.u32 %r0, %r0, %r1, %r2;
ld.param.u32 %r2, [params];
setp.ge.u32 %p0, %r0, %r2;
@%p0 bra done;
mov.u32 %r3, %nctaid.x;
mul.lo.u32 %r1, %r3, %r1;
body: // sm_86
ld.param.u64 %rd4, [params+8];
ld.param.u64 %rd0, [params+16];
mad.wide.u32 %rd0, %r0, 4, %rd0;
ld.global.cs.f32 %f5, [%rd0];
ld.param.u64 %rd0, [params+24];
mad.wide.u32 %rd0, %r0, 4, %rd0;
ld.global.cs.u32 %r6, [%rd0];
mov.pred %p7, 0x1;
mad.wide.u32 %rd3, %r6, 4, %rd4;
{
.visible .func reduce_max_f32(.param .u64 ptr, .param .f32 value);
call reduce_max_f32, (%rd3, %f5);
}
add.u32 %r0, %r0, %r1;
setp.ge.u32 %p0, %r0, %r2;
@!%p0 bra body;
done:
ret;
}
.visible .func reduce_max_f32(.param .u64 ptr,
.param .f32 value) {
.reg .pred %p<14>;
.reg .f32 %q<19>;
.reg .b32 %r<41>;
.reg .b64 %rd<2>;
ld.param.u64 %rd0, [ptr];
ld.param.f32 %q3, [value];
activemask.b32 %r1;
match.any.sync.b64 %r2, %rd0, %r1;
setp.eq.s32 %p1, %r2, -1;
@%p1 bra.uni fast_path;
brev.b32 %r10, %r2;
bfind.shiftamt.u32 %r40, %r10;
shf.l.wrap.b32 %r12, -2, -2, %r40;
and.b32 %r39, %r2, %r12;
setp.ne.s32 %p2, %r39, 0;
vote.sync.any.pred %p3, %p2, %r1;
@!%p3 bra maybe_scatter;
mov.b32 %r5, %q3;
slow_path_repeat:
brev.b32 %r14, %r39;
bfind.shiftamt.u32 %r15, %r14;
shfl.sync.idx.b32 %r17, %r5, %r15, 31, %r1;
mov.b32 %q6, %r17;
@%p2 max.f32 %q3, %q3, %q6;
shf.l.wrap.b32 %r19, -2, -2, %r15;
and.b32 %r39, %r39, %r19;
setp.ne.s32 %p2, %r39, 0;
vote.sync.any.pred %p3, %p2, %r1;
@!%p3 bra maybe_scatter;
bra.uni slow_path_repeat;
fast_path:
mov.b32 %r22, %q3;
shfl.sync.down.b32 %r26, %r22, 16, 31, %r1;
mov.b32 %q7, %r26;
max.f32 %q8, %q7, %q3;
mov.b32 %r27, %q8;
shfl.sync.down.b32 %r29, %r27, 8, 31, %r1;
mov.b32 %q9, %r29;
max.f32 %q10, %q8, %q9;
mov.b32 %r30, %q10;
shfl.sync.down.b32 %r32, %r30, 4, 31, %r1;
mov.b32 %q11, %r32;
max.f32 %q12, %q10, %q11;
mov.b32 %r33, %q12;
shfl.sync.down.b32 %r34, %r33, 2, 31, %r1;
mov.b32 %q13, %r34;
max.f32 %q14, %q12, %q13;
mov.b32 %r35, %q14;
shfl.sync.down.b32 %r37, %r35, 1, 31, %r1;
mov.b32 %q15, %r37;
max.f32 %q3, %q14, %q15;
mov.u32 %r40, 0;
maybe_scatter:
mov.u32 %r38, %laneid;
setp.ne.s32 %p13, %r40, %r38;
@%p13 bra done;
red.max.f32 [%rd0], %q3;
done:
ret;
}
ptxas application ptx input, line 107; error : Operation .max requires .u32 or .s32 or .u64 or .s64 type for instruction 'red'
ptxas fatal : Ptx assembly aborted due to errors
I'm running on Windows with an NVIDIA RTX A1000 card, and my CUDA compiler reports: Cuda compilation tools, release 11.8, V11.8.89, Build cuda_11.8.r11.8/compiler.31833905_0.
It also seems like `dr.ReduceOp.Mul` is not supported on either backend; on the CUDA backend it fails with:
ptxas application ptx input, line 107; error : Unknown modifier '.mul'
ptxas application ptx input, line 107; error : Illegal operation '' for instruction 'red'
ptxas application ptx input, line 107; error : Operation requires type for instruction 'red'
ptxas application ptx input, line 107; error : Reduction operation is required for instruction 'red'
ptxas fatal : Ptx assembly aborted due to errors
and on the LLVM backend with:
drjit_bac90354b603f82059ca1b030eeec1f7:58:14: error: expected binary operation in atomicrmw
atomicrmw fmul ptr %ptr_0, float %sum monotonic
Are these known limitations/issues of the `scatter_reduce` operator?
Best,
Philippe
Hi @WeiPhil,
I wasn't aware of these fine details of `scatter_reduce`. After digging a bit, here's what I found:
In CUDA:

- `ReduceOp::Add`: supported on integer and single-precision floating point types
- `ReduceOp::Mul`: not supported at all
- `ReduceOp::Min`: only supported on integer types
- `ReduceOp::Max`: only supported on integer types
- `ReduceOp::And`: bitwise operation (supports anything)
- `ReduceOp::Or`: bitwise operation (supports anything)

These are limitations of the `red` instruction in PTX (source).
In LLVM (assuming LLVM 16):

- `ReduceOp::Add`: supported on integer and floating point types
- `ReduceOp::Mul`: not supported at all
- `ReduceOp::Min`: supported on integer and floating point types
- `ReduceOp::Max`: supported on integer and floating point types
- `ReduceOp::And`: bitwise operation (supports anything)
- `ReduceOp::Or`: bitwise operation (supports anything)

These are mostly restricted by the `atomicrmw` LLVM IR instruction (source).
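As for `ReduceOp.Mul`, which is unsupported on both backends: if all values are strictly positive, a scatter-multiply can be emulated as a scatter-add in log space, since floating-point `ReduceOp.Add` is supported everywhere. A minimal sketch under that positivity assumption:

```python
import drjit as dr
from drjit.llvm import Float, UInt32

# log(1) == 0, so zero-initialization matches the multiplicative identity:
# untouched entries come out as exp(0) == 1
target = dr.zeros(Float, shape=4)
values = Float(2.0, 3.0, 4.0, 5.0)
idx = UInt32(0, 0, 1, 1)

# Accumulate logs with the supported ReduceOp.Add, then exponentiate
dr.scatter_reduce(dr.ReduceOp.Add, target, dr.log(values), idx)
print(dr.exp(target))  # [6, 20, 1, 1]
```

This trades some numerical precision for portability; it is a stopgap rather than a substitute for a real atomic multiply.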
This has got me wondering why `ReduceOp.Mul` was added...
I'll keep this issue open until I figure out what we actually want to support. At the very least, what you are seeing now is "expected" behavior. I believe we could fully support this set of operations with integer and floating point types, but it would require some more work (basically, manually adding some synchronization points). This was either never done because we have only ever needed `ReduceOp.Add`, or because there is some other limitation I'm currently unaware of.
`ReduceOp.Mul` is there because the plan was also to use this enumeration internally for horizontal reductions (exposed as `drjit.prod` in the upcoming nanobind rewrite). I agree that it's pretty weird for atomics.
I will explain these limitations in the documentation. Can the issue be closed?
Sounds good, thank you!