rachtsingh/lgamma

Test fails for FloatTensor

Opened this issue · 5 comments

Awesome project, but here's a bug! If I change the TensorType to torch.cuda.FloatTensor in test.py then the gradient check test fails for me (works fine w/ DoubleTensor). So it seems like the implementation is correct only for the DoubleTensor type.

The dependence on DoubleTensor makes the library difficult to use, particularly since beta, lgamma, and digamma can very quickly run into numerical-stability issues near the bounds of float precision.

Also, the test fails if you change the dimensions of the input tensors -- for example, a 50x100 tensor.

I managed to track things down to the point where the tensors passed to the beta or digamma functions need to be contiguous, but there is still an issue where some elements of the tensor are not updated (due to some shape/stride issue), which means you can often end up with infs or nans.

Edit: this only fails for the cuda versions; the cpu DoubleTensor version appears to work fine.
For the cuda version, this is the error I get when running the test (with tensors of size 50x100):

THCudaCheck FAIL file=/<...>/pytorch-src/torch/lib/THC/generic/THCTensorCopy.c line=65 error=77 : an illegal memory access was encountered
E
======================================================================
ERROR: test_many_times (__main__.TestBetaGrads)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test.py", line 21, in test_many_times
    result = gradcheck(Beta(), (a, b), eps=1e-6, atol=1e-3)
  File "/<...>/.virtualenvs/pytorch/lib/python2.7/site-packages/torch/autograd/gradcheck.py", line 143, in gradcheck
    numerical = get_numerical_jacobian(fn, inputs, inputs, eps)
  File "/<...>/.virtualenvs/pytorch/lib/python2.7/site-packages/torch/autograd/gradcheck.py", line 73, in get_numerical_jacobian
    outa.copy_(fn(input))
RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at /<...>/pytorch-src/torch/lib/THC/generic/THCTensorCopy.c:65

----------------------------------------------------------------------
Ran 1 test in 1.441s

FAILED (errors=1)

Edit2: I managed to track down the shape/stride error. The issue was the explicit stride-based addressing used when computing the functions, even though the cuda thread traversal already provides a simple linear addressing scheme that the pure C version doesn't have.

// in the functions_cuda_kernel.cu file, for the different kernels
// old
int index = threadIdx.x;
int stride = blockDim.x;
for (int i = index; i < width * height; i += stride) {
  int x = i / width;
  int y = i % width;
  int out_address = x * output_swidth + y * output_sheight;
  int in_address = x * input_swidth + y * input_sheight;
  output_data[out_address] = polygamma_impl_dbl(n, input_data[in_address]);
}

// new
for (int addr = threadIdx.x; addr < width * height; addr += blockDim.x)
    output_data[addr] = polygamma_impl_dbl(n, input_data[addr]);

This fixes the errors for arbitrarily sized tensors, and also the nan/inf issues caused by elements that were never written and so kept whatever values they were initialised with.
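To illustrate the kind of failure the old addressing can cause, here is a small sketch in plain Python. The stride values (swidth = 1, sheight = width) are an assumption for illustration only, chosen to mix up the two dimensions of a contiguous, row-major 50x100 tensor the way a shape/stride confusion would:

```python
# Hypothetical reconstruction of the old per-element addressing for a
# contiguous, row-major 50x100 tensor (the size that crashed the test).
height, width = 50, 100
numel = height * width                 # 5000 elements in the buffer

# Assumed (mixed-up) strides, for illustration only:
swidth, sheight = 1, width

old_addrs = [(i // width) * swidth + (i % width) * sheight
             for i in range(numel)]

# Addresses past the end of the buffer explain the illegal memory access;
# in-bounds elements that are never written keep stale values, which shows
# up as infs/nans.
print(max(old_addrs))                     # 49*1 + 99*100 = 9949, buffer ends at 4999
print(sum(a < numel for a in old_addrs))  # only 2500 of 5000 addresses in bounds
```

By contrast, the new linear addressing touches exactly the numel contiguous elements, so it works for any shape as long as the tensors are contiguous.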

So, I initially thought the issue was with the polygamma and zeta implementations, but that doesn't look likely now. I tried out the cpu versions (torch.FloatTensor and torch.DoubleTensor) with the GSL polygamma implementation from here, and the float version still fails. I suspect it might even be an issue in pytorch's gradcheck, given that it seems to expect DoubleTensors, at least in the numerical Jacobian construction here.

Awesome! Thanks for noting the issue @bogdanstate and for fixing the CUDA striding issue @iffsid (CUDA is better designed than I realized). I've pushed a fix for the striding, but the FloatTensor issue is more difficult. I tried changing the implementation of gradcheck to use type(input[0].data)(output_size) to see if that was the issue, but it doesn't help. Also, note that I had to increase atol to 1e-3 to get gradcheck passing, which is less than ideal.

I suspect the issue is more inherent: because of the high curvature of polygamma, calculating the gradient to float precision requires higher-than-float precision in the inputs. I also don't think the polygamma/zeta implementations are ideal.
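That suspicion is easy to check numerically outside of torch. The sketch below (plain Python; x = 0.5 and eps = 1e-6 are illustrative choices, eps matching the gradcheck call above) estimates digamma = d/dx lgamma by central differences, once in double precision and once with every intermediate rounded to float32. At x = 0.5 the spacing between adjacent float32 values is about 6e-8, so an eps of 1e-6 is only ~17 ulps wide, and the step actually taken differs from the nominal one by over 1%:

```python
import math
import struct

def f32(x):
    # round a Python float (double) to the nearest float32 value
    return struct.unpack('f', struct.pack('f', x))[0]

def digamma_fd64(x, eps=1e-6):
    # central difference of lgamma in double precision
    return (math.lgamma(x + eps) - math.lgamma(x - eps)) / (2 * eps)

def digamma_fd32(x, eps=1e-6):
    # the same computation with every intermediate rounded to float32
    x, eps = f32(x), f32(eps)
    hi = f32(math.lgamma(f32(x + eps)))
    lo = f32(math.lgamma(f32(x - eps)))
    return f32(f32(hi - lo) / f32(2 * eps))

true_val = -1.9635100260214235   # digamma(0.5) = -euler_gamma - 2*ln(2)

# double precision: agrees with the true value to well under 1e-6
print(abs(digamma_fd64(0.5) - true_val))

# float32: the rounded step f32(0.5 + 1e-6) - f32(0.5 - 1e-6) is ~1.3%
# larger than the nominal 2e-6, which alone biases the estimate by ~0.026
print(abs(digamma_fd32(0.5) - true_val))
print((f32(0.5 + 1e-6) - f32(0.5 - 1e-6)) / 2e-6)   # ~1.013
```

This is consistent with needing atol = 1e-3 for the FloatTensor gradcheck: with float inputs, the numerical Jacobian itself can't be trusted much beyond the second or third decimal place, independent of how polygamma/zeta are implemented.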

My earliest version just copied to CPU/numpy and used SciPy's implementations to calculate gradients, so I can compare with that. Unfortunately I'm in NY this summer and GPU access is difficult, so it might be a few weeks until I can look at it in more detail.

Cool. One thing to note is that pytorch (the master branch on github) now natively provides the lgamma function for cuda; the only thing missing is the polygamma function.