Initialization of reduction output may need to be predicated
naoyam opened this issue · 0 comments
Currently, the initialization of a reduction output buffer is not guarded by a thread predicate.
https://github.com/csarofeen/pytorch/blob/devel/third_party/nvfuser/csrc/lower_unroll.cpp#L73-L75
const auto thread_pred = isReductionInitExpr(expr)
    ? GpuLower::current()->kernel()->trueVal()
    : GpuLower::current()->threadPredMap().getPredicate(out_tv);
This is not always safe. Example:
auto tv0 = makeSymbolicTensor(1);
fusion.addInput(tv0);
auto tv1 = sum(tv0, {0});
fusion.addOutput(tv1);
auto tv2 = makeSymbolicTensor(1);
fusion.addInput(tv2);
auto tv3 = exp(tv2);
fusion.addOutput(tv3);
tv3->split(0, 32);
tv3->axis(-2)->parallelize(ParallelType::BIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
The code we generate:
__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T2, Tensor<float, 0> T1, Tensor<float, 1> T3) {
  int64_t i54;
  i54 = ((nvfuser_index_t)blockIdx.x) * 32;
  bool b64;
  b64 = (0 == (-((nvfuser_index_t)blockIdx.x))) && (((nvfuser_index_t)threadIdx.x) == 0);
  T1[0] = 0.00000000000000000e+00;
  #pragma unroll 1
  for(nvfuser_index_t i10 = 0; i10 < T0.size[0]; ++i10) {
    if (b64) {
      T1[0]
        = T1[0]
        + T0[(T0.stride[0] * i10)];
    }
  }
  if ((((nvfuser_index_t)threadIdx.x) < (T2.size[0] - i54))) {
    T3[(i54 + ((nvfuser_index_t)threadIdx.x))]
      = expf(T2[(((T2.stride[0] * ((nvfuser_index_t)blockIdx.x)) * 32) + (T2.stride[0] * ((nvfuser_index_t)threadIdx.x)))]);
  }
}
The issue is the initialization of T1. T3 is there only to force launching multiple thread blocks. T1 is in global memory, and the actual reduction is guarded by b64, which is (0 == (-((nvfuser_index_t)blockIdx.x))) && (((nvfuser_index_t)threadIdx.x) == 0), but the initialization is not guarded. Every thread of every block therefore writes zero to T1[0], so any other block can overwrite the result with zero, possibly after block 0 has already accumulated into it.
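To make the ordering problem concrete, here is a minimal host-side sketch in plain C++ that models the kernel above. It is not nvfuser code: runBlock, simulate, and the guard_init flag are hypothetical names introduced only for this illustration, and the model assumes blocks run one after another with each block executing its statements in lockstep, which is a simplification of real GPU scheduling.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model of one thread block of the generated kernel.
// guard_init == false mirrors the current codegen (unpredicated init);
// guard_init == true applies the reduction predicate b64 to the init too.
void runBlock(int64_t block_idx, int64_t threads_per_block,
              const std::vector<float>& t0, float& t1, bool guard_init) {
  // Phase 1: initialization of T1[0] by each thread of the block.
  for (int64_t tid = 0; tid < threads_per_block; ++tid) {
    const bool b64 = (block_idx == 0) && (tid == 0);
    if (!guard_init || b64) {
      t1 = 0.0f;
    }
  }
  // Phase 2: the serial reduction, which is always guarded by b64.
  for (int64_t tid = 0; tid < threads_per_block; ++tid) {
    const bool b64 = (block_idx == 0) && (tid == 0);
    if (b64) {
      for (float v : t0) {
        t1 = t1 + v;
      }
    }
  }
}

// Runs the blocks sequentially in the given order, standing in for one
// schedule the hardware might pick. T1 starts with a garbage value.
float simulate(const std::vector<int64_t>& block_order,
               int64_t threads_per_block, const std::vector<float>& t0,
               bool guard_init) {
  assert(threads_per_block > 0);
  float t1 = -1.0f;
  for (int64_t block_idx : block_order) {
    runBlock(block_idx, threads_per_block, t0, t1, guard_init);
  }
  return t1;
}
```

In this model, with T0 = {1, 2, 3} and the unguarded initialization, running block 1 before block 0 leaves 6 in T1, while running block 0 before block 1 leaves 0, because block 1 re-zeroes the finished sum. With guard_init set, both orders leave 6.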
We usually don't put reduction output tensors in global memory, so I think that's why we haven't been hit by this simple bug.
The same problem also happens with shared memory. I will be more conservative about omitting thread predicates for buffer initializations.