mitsuba-renderer/drjit

Unsupported `scatter_reduce` operators in CUDA and LLVM backends

WeiPhil opened this issue

Hi,
I've come across two potential issues with scatter_reduce. The first occurs when using the dr.ReduceOp.Max or dr.ReduceOp.Min operators with the CUDA backend. Here is a minimal reproducer:

import drjit as dr

from drjit.cuda import Float, UInt32
# from drjit.llvm import Float, UInt32

shape = 8
a = dr.zeros(Float, shape=shape)
b = dr.linspace(Float, start=1.0, stop=8.0, num=shape)

print("a", a)
print("b", b)
idx = dr.arange(UInt32, 0, shape)
print("idx", idx)

dr.scatter_reduce(dr.ReduceOp.Max, a, b, idx)
# dr.scatter_reduce(dr.ReduceOp.Min, a, b, idx)
print("result", a)
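For reference, scatter_reduce with ReduceOp.Max should update a[idx[i]] = max(a[idx[i]], b[i]) for every i. The sketch below spells out those semantics sequentially in plain Python. It is not the Dr.Jit API: REDUCE_FNS and scatter_reduce_ref are hypothetical names, and a real backend performs these updates in parallel, which is why it has to rely on atomic instructions.

```python
import operator

# Hypothetical names: REDUCE_FNS and scatter_reduce_ref are illustrative only.
# Keys mirror the dr.ReduceOp member names discussed in this thread.
REDUCE_FNS = {
    "Add": operator.add,
    "Mul": operator.mul,
    "Min": min,
    "Max": max,
    "And": operator.and_,  # bitwise, so only meaningful for integers in this sketch
    "Or": operator.or_,
}

def scatter_reduce_ref(op, target, value, index):
    """Sequentially apply target[index[i]] = op(target[index[i]], value[i]) in place."""
    fn = REDUCE_FNS[op]
    for v, i in zip(value, index):
        target[i] = fn(target[i], v)
    return target

# Same inputs as the reproducer: every index is distinct, so Max simply copies b into a.
a = [0.0] * 8
b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
print(scatter_reduce_ref("Max", a, b, list(range(8))))
# [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]

# With repeated indices, several values are combined into the same slot.
print(scatter_reduce_ref("Max", [0.0, 0.0], [3.0, 5.0, 2.0], [1, 1, 0]))
# [2.0, 5.0]
```

The repeated-index case is the one that matters on a GPU: two threads writing to the same slot must combine their values atomically.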

With the LLVM backend, this prints the expected result:

a [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
b [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
idx [0, 1, 2, 3, 4, 5, 6, 7]
result [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]

but with the CUDA backend, I get the following error:

Critical Dr.Jit compiler failure: jit_cuda_compile(): compilation failed. Please see the PTX assembly listing and error message below:

.version 7.1
.target sm_86
.address_size 64

.entry drjit_56971614ee2ebd5608102657da246170(.param .align 8 .b8 params[32]) {
    .reg.b8   %b <9>; .reg.b16 %w<9>; .reg.b32 %r<9>;
    .reg.b64  %rd<9>; .reg.f32 %f<9>; .reg.f64 %d<9>;
    .reg.pred %p <9>;

    mov.u32 %r0, %ctaid.x;
    mov.u32 %r1, %ntid.x;
    mov.u32 %r2, %tid.x;
    mad.lo.u32 %r0, %r0, %r1, %r2;
    ld.param.u32 %r2, [params];
    setp.ge.u32 %p0, %r0, %r2;
    @%p0 bra done;

    mov.u32 %r3, %nctaid.x;
    mul.lo.u32 %r1, %r3, %r1;

body: // sm_86
    ld.param.u64 %rd4, [params+8];
    ld.param.u64 %rd0, [params+16];
    mad.wide.u32 %rd0, %r0, 4, %rd0;
    ld.global.cs.f32 %f5, [%rd0];
    ld.param.u64 %rd0, [params+24];
    mad.wide.u32 %rd0, %r0, 4, %rd0;
    ld.global.cs.u32 %r6, [%rd0];
    mov.pred %p7, 0x1;
    mad.wide.u32 %rd3, %r6, 4, %rd4;
    {
        .visible .func reduce_max_f32(.param .u64 ptr, .param .f32 value);
        call reduce_max_f32, (%rd3, %f5);
    }

    add.u32 %r0, %r0, %r1;
    setp.ge.u32 %p0, %r0, %r2;
    @!%p0 bra body;

done:
    ret;
}

.visible .func reduce_max_f32(.param .u64 ptr,
                              .param .f32 value) {
    .reg .pred %p<14>;
    .reg .f32 %q<19>;
    .reg .b32 %r<41>;
    .reg .b64 %rd<2>;

    ld.param.u64 %rd0, [ptr];
    ld.param.f32 %q3, [value];
    activemask.b32 %r1;
    match.any.sync.b64 %r2, %rd0, %r1;
    setp.eq.s32 %p1, %r2, -1;
    @%p1 bra.uni fast_path;

    brev.b32 %r10, %r2;
    bfind.shiftamt.u32 %r40, %r10;
    shf.l.wrap.b32 %r12, -2, -2, %r40;
    and.b32 %r39, %r2, %r12;
    setp.ne.s32 %p2, %r39, 0;
    vote.sync.any.pred %p3, %p2, %r1;
    @!%p3 bra maybe_scatter;
    mov.b32 %r5, %q3;

slow_path_repeat:
    brev.b32 %r14, %r39;
    bfind.shiftamt.u32 %r15, %r14;
    shfl.sync.idx.b32 %r17, %r5, %r15, 31, %r1;
    mov.b32 %q6, %r17;
    @%p2 max.f32 %q3, %q3, %q6;
    shf.l.wrap.b32 %r19, -2, -2, %r15;
    and.b32 %r39, %r39, %r19;
    setp.ne.s32 %p2, %r39, 0;
    vote.sync.any.pred %p3, %p2, %r1;
    @!%p3 bra maybe_scatter;
    bra.uni slow_path_repeat;

fast_path:
    mov.b32 %r22, %q3;
    shfl.sync.down.b32 %r26, %r22, 16, 31, %r1;
    mov.b32 %q7, %r26;
    max.f32 %q8, %q7, %q3;
    mov.b32 %r27, %q8;
    shfl.sync.down.b32 %r29, %r27, 8, 31, %r1;
    mov.b32 %q9, %r29;
    max.f32 %q10, %q8, %q9;
    mov.b32 %r30, %q10;
    shfl.sync.down.b32 %r32, %r30, 4, 31, %r1;
    mov.b32 %q11, %r32;
    max.f32 %q12, %q10, %q11;
    mov.b32 %r33, %q12;
    shfl.sync.down.b32 %r34, %r33, 2, 31, %r1;
    mov.b32 %q13, %r34;
    max.f32 %q14, %q12, %q13;
    mov.b32 %r35, %q14;
    shfl.sync.down.b32 %r37, %r35, 1, 31, %r1;
    mov.b32 %q15, %r37;
    max.f32 %q3, %q14, %q15;
    mov.u32 %r40, 0;

maybe_scatter:
    mov.u32 %r38, %laneid;
    setp.ne.s32 %p13, %r40, %r38;
    @%p13 bra done;
    red.max.f32 [%rd0], %q3;

done:
    ret;
}


ptxas application ptx input, line 107; error   : Operation .max requires .u32 or .s32 or .u64 or .s64 type for instruction 'red'
ptxas fatal   : Ptx assembly aborted due to errors

I'm running on Windows with an NVIDIA RTX A1000 card, and my CUDA compiler reports: Cuda compilation tools, release 11.8, V11.8.89, Build cuda_11.8.r11.8/compiler.31833905_0.

It also seems that dr.ReduceOp.Mul is not supported on either backend. On the CUDA backend it fails with:

ptxas application ptx input, line 107; error   : Unknown modifier '.mul'
ptxas application ptx input, line 107; error   : Illegal operation '' for instruction 'red'
ptxas application ptx input, line 107; error   : Operation  requires  type for instruction 'red'
ptxas application ptx input, line 107; error   : Reduction operation is required for instruction 'red'
ptxas fatal   : Ptx assembly aborted due to errors

and on the LLVM backend with:

drjit_bac90354b603f82059ca1b030eeec1f7:58:14: error: expected binary operation in atomicrmw
   atomicrmw fmul ptr %ptr_0, float %sum monotonic

Are those known limitations/issues of the scatter_reduce operator?

Best,
Philippe

Hi @WeiPhil

I wasn't aware of these fine details of scatter_reduce. After digging a bit, here's what I found:

In CUDA:

  • ReduceOp::Add: Supported on integer and single-precision floating point types
  • ReduceOp::Mul: Not supported at all
  • ReduceOp::Min: Only supported on integer types
  • ReduceOp::Max: Only supported on integer types
  • ReduceOp::And: Bitwise operation (supports anything)
  • ReduceOp::Or: Bitwise operation (supports anything)

These are limitations of the red instruction in PTX (source).
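Because PTX red only offers min/max on integer types, one general workaround (a common bit trick, not something proposed in this thread or done by Dr.Jit) is to map each float32 to an unsigned integer key whose ordering matches the float ordering, reduce the keys with an integer max/min, and map the result back. The sketch below shows only the mapping, in plain Python; the helper names are hypothetical, NaNs are ignored, and on the GPU the scatter target would also have to be stored in key form.

```python
import struct

# Hypothetical helpers illustrating a generic bit trick; NaNs are not handled.
def float_to_ordered_u32(x):
    """Map a float32 to a uint32 key whose unsigned order matches the float order."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if bits & 0x80000000:      # negative: flip every bit, so larger magnitudes sort lower
        return bits ^ 0xFFFFFFFF
    return bits | 0x80000000   # non-negative: set the top bit, so it sorts above all negatives

def ordered_u32_to_float(key):
    """Invert float_to_ordered_u32."""
    if key & 0x80000000:       # key came from a non-negative float
        bits = key ^ 0x80000000
    else:                      # key came from a negative float
        bits = key ^ 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]

values = [-3.5, 2.0, -0.25, 7.0, 0.0]
keys = [float_to_ordered_u32(v) for v in values]
# An integer max/min over the keys (what an integer-only atomic can do) recovers the float max/min.
print(ordered_u32_to_float(max(keys)))  # 7.0
print(ordered_u32_to_float(min(keys)))  # -3.5
```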

In LLVM (assuming LLVM 16):

  • ReduceOp::Add: Supported on integer and floating point types
  • ReduceOp::Mul: Not supported at all
  • ReduceOp::Min: Supported on integer and floating point types
  • ReduceOp::Max: Supported on integer and floating point types
  • ReduceOp::And: Bitwise operation (supports anything)
  • ReduceOp::Or: Bitwise operation (supports anything)

These are mostly restrictions of the LLVM IR atomicrmw instruction (source).
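The two lists above can be transcribed into a lookup table, for example to fail early with a readable message instead of a ptxas or LLVM compile error. This is a hypothetical helper, not part of Dr.Jit; it encodes the lists exactly as stated in this comment, using "int", "f32" and "f64" as coarse type categories.

```python
# Hypothetical lookup table transcribing the two support lists in this comment
# (PTX `red` for CUDA, LLVM 16 `atomicrmw` for LLVM). Not part of Dr.Jit.
INT, F32, F64 = "int", "f32", "f64"
ANY = {INT, F32, F64}  # And/Or are bitwise, so the element type does not matter

SCATTER_REDUCE_SUPPORT = {
    "cuda": {
        "Add": {INT, F32},  # integer and single-precision float, as listed above
        "Mul": set(),       # no multiply reduction at all
        "Min": {INT},
        "Max": {INT},
        "And": ANY,
        "Or": ANY,
    },
    "llvm": {
        "Add": {INT, F32, F64},
        "Mul": set(),
        "Min": {INT, F32, F64},
        "Max": {INT, F32, F64},
        "And": ANY,
        "Or": ANY,
    },
}

def is_supported(backend, op, kind):
    """Return True if ReduceOp.<op> on values of category `kind` is listed for `backend`."""
    return kind in SCATTER_REDUCE_SUPPORT[backend][op]

def check_scatter_reduce(backend, op, kind):
    """Raise a readable error instead of letting the backend compiler fail."""
    if not is_supported(backend, op, kind):
        raise ValueError(
            f"scatter_reduce: ReduceOp.{op} on {kind} is not supported by the {backend} backend"
        )

check_scatter_reduce("llvm", "Max", F32)  # passes silently
```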

This has got me wondering why ReduceOp.Mul was added...

I'll keep this issue open until I figure out what we actually want to support. At the very least, what you are seeing now is "expected" behavior. I believe we could fully support this set of operations with both integer and floating-point types, but it would require some more work (basically, manually adding some synchronization points). Either this was never done because we have only ever needed ReduceOp.Add, or there is some other limitation I'm currently unaware of.
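The comment does not say which synchronization would be used. One standard way to build an atomic operation that the hardware lacks is a compare-and-swap (CAS) retry loop: read the current value, compute the new one, and store it only if nobody changed the slot in the meantime, otherwise retry. CUDA exposes this primitive as atomicCAS and LLVM IR as cmpxchg. The sketch below emulates it with Python threads; AtomicCell and atomic_update are hypothetical names, and this is not necessarily how Dr.Jit would implement it.

```python
import operator
import threading

class AtomicCell:
    """Toy stand-in for a memory word with a hardware compare-and-swap.
    The lock only makes each single CAS atomic; it is never held across a whole update."""

    def __init__(self, value):
        self._value = value
        self._lock = threading.Lock()

    def load(self):
        return self._value

    def compare_and_swap(self, expected, new):
        """Store `new` iff the cell still holds `expected`; return the value that was seen.
        (Hardware CAS compares bit patterns; this toy uses ==, so avoid NaN.)"""
        with self._lock:
            old = self._value
            if old == expected:
                self._value = new
            return old

def atomic_update(cell, fn, operand):
    """Apply cell = fn(cell, operand) atomically via a CAS retry loop."""
    seen = cell.load()
    while True:
        desired = fn(seen, operand)
        actual = cell.compare_and_swap(seen, desired)
        if actual == seen:  # no interference: the update was committed
            return desired
        seen = actual       # another thread won the race: retry with its value

# Sixteen threads multiply into one shared slot; the final product is order-independent.
cell = AtomicCell(1)
factors = [2, 3, 5, 7] * 4
threads = [threading.Thread(target=atomic_update, args=(cell, operator.mul, f)) for f in factors]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(cell.load())  # 1944810000, i.e. 210**4
```

The same loop works for a float max or min, at the cost of retries when many threads hit the same slot.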

wjakob commented

ReduceOp.Mul is there because the plan was also to use this enumeration internally for horizontal reductions (exposed as drjit.prod in the upcoming nanobind rewrite). I agree that it's pretty weird for atomics.
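To illustrate the difference: a horizontal reduction folds a whole array into a single value, so it can be computed as an ordinary (sequential or tree-shaped) reduction without atomic memory updates, and Mul is then just as easy as Add. A plain-Python sketch with hypothetical names, not the drjit.prod implementation:

```python
import functools
import operator

# Hypothetical table reusing the ReduceOp member names for horizontal reductions.
HORIZONTAL_FNS = {"Add": operator.add, "Mul": operator.mul, "Min": min, "Max": max}

def horizontal_reduce(op, values):
    """Collapse `values` into one scalar by folding with the chosen operation."""
    return functools.reduce(HORIZONTAL_FNS[op], values)

print(horizontal_reduce("Mul", [1.0, 2.0, 3.0, 4.0]))  # 24.0
```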

I will explain these limitations in the documentation. Can the issue be closed?

Sounds good, thank you!