csarofeen/pytorch

Initialization of reduction output may need to be predicated

naoyam opened this issue · 0 comments

Currently, the initialization of a reduction output is not guarded by a thread predicate.

https://github.com/csarofeen/pytorch/blob/devel/third_party/nvfuser/csrc/lower_unroll.cpp#L73-L75

const auto thread_pred = isReductionInitExpr(expr)
        ? GpuLower::current()->kernel()->trueVal()
        : GpuLower::current()->threadPredMap().getPredicate(out_tv);

This is not always safe. Example:

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);

  auto tv1 = sum(tv0, {0});
  fusion.addOutput(tv1);

  auto tv2 = makeSymbolicTensor(1);
  fusion.addInput(tv2);

  auto tv3 = exp(tv2);
  fusion.addOutput(tv3);

  tv3->split(0, 32);
  tv3->axis(-2)->parallelize(ParallelType::BIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

The code we generate:

__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T2, Tensor<float, 0> T1, Tensor<float, 1> T3) {
  int64_t i54;
  i54 = ((nvfuser_index_t)blockIdx.x) * 32;
  bool b64;
  b64 = (0 == (-((nvfuser_index_t)blockIdx.x))) && (((nvfuser_index_t)threadIdx.x) == 0);
  T1[0] = 0.00000000000000000e+00;
  #pragma unroll 1
  for(nvfuser_index_t i10 = 0; i10 < T0.size[0]; ++i10) {
    if (b64) {
      T1[0]
        = T1[0]
        + T0[(T0.stride[0] * i10)];
    }
  }
  if ((((nvfuser_index_t)threadIdx.x) < (T2.size[0] - i54))) {
    T3[(i54 + ((nvfuser_index_t)threadIdx.x))]
       = expf(T2[(((T2.stride[0] * ((nvfuser_index_t)blockIdx.x)) * 32) + (T2.stride[0] * ((nvfuser_index_t)threadIdx.x)))]);
  }
}

The issue is the initialization of T1; T3 is there only to force launching multiple thread blocks. T1 is in global memory. The actual reduction is guarded by b64, which is (0 == (-((nvfuser_index_t)blockIdx.x))) && (((nvfuser_index_t)threadIdx.x) == 0), but the initialization is not, so any other block could overwrite T1 with zero.
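To make the failure concrete, here is a minimal host-side sketch in plain C++, not nvfuser code. It models only the T1 part of the generated kernel and runs each (block, thread) pair to completion, one after another. Running the blocks serially is one schedule CUDA is allowed to pick, so a wrong result here is a result the real kernel could also produce. The function name runSerialized and the guard_init flag are hypothetical, introduced only for this illustration.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical model of the T1 portion of the generated kernel.
// Each (block, thread) pair executes the kernel body in turn.
float runSerialized(const std::vector<float>& t0, int64_t num_blocks,
                    int64_t threads_per_block, bool guard_init) {
  float t1 = -1.0f;  // stand-in for uninitialized global memory
  for (int64_t block = 0; block < num_blocks; ++block) {
    for (int64_t thread = 0; thread < threads_per_block; ++thread) {
      // Same condition as b64 in the generated code.
      const bool b64 = (block == 0) && (thread == 0);
      // guard_init == false: every thread of every block writes the
      // initial value (current behavior).
      // guard_init == true: only the thread that reduces initializes.
      if (!guard_init || b64) {
        t1 = 0.0f;
      }
      for (float v : t0) {
        if (b64) {
          t1 = t1 + v;
        }
      }
    }
  }
  return t1;
}
```

With the input {1, 2, 3}, two blocks, and 32 threads per block, the unguarded variant returns 0 because later threads reset T1 after block 0, thread 0 has finished summing. The guarded variant returns 6.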

We usually don't put reduction output tensors on global memory, so I think that's why we haven't been hit by this simple bug.

The same problem also happens with shared memory. I will be more conservative when omitting thread predicates for buffer initializations.
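One possible conservative rule, sketched as a standalone model: skip the predicate only when the buffer is private to a thread. The enum MemoryKind and the function canOmitInitPredicate are hypothetical stand-ins that loosely mirror nvfuser concepts. They are not the library's types, and this is not necessarily the fix that should be adopted.

```cpp
// Hypothetical stand-in for a tensor's memory space.
enum class MemoryKind { Local, Shared, Global };

// Returns true when a reduction-init write may skip the thread predicate.
// Assumption: only thread-private (Local) buffers are safe, because no
// other thread or block can overwrite them. Shared and Global buffers
// keep the predicate, matching the two failing cases described above.
bool canOmitInitPredicate(MemoryKind kind) {
  return kind == MemoryKind::Local;
}
```

This is stricter than necessary in cases where every participating thread is known to write the same location before any reduction starts, but it avoids both failure modes reported here.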