zdevito/pytorch

Sort out whatever is happening with bernoulli, and regularize

zdevito opened this issue · 1 comment

It is the only place that refers to THFloatTensor and THDoubleTensor directly, and it is not clear what is even happening in this file.

#define THCudaDoubleTensor_BERNOULLI_TENSOR THCudaDoubleTensor_bernoulli_DoubleTensor
#define THCudaTensor_BERNOULLI_TENSOR THCudaTensor_bernoulli_FloatTensor

[[
  name: bernoulli
  defined_if: CUDA_FLOAT || CUDA_DOUBLE
  types:
    - Float
    - Double
  processors:
    - CUDA
  return: argument 0
  variants:
    - method
    - function
  cname: BERNOULLI_TENSOR
  before_call:
    THTensor_(resizeAs)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, ((THPTensor*)$arg1)->cdata);
  arguments:
    - arg: THTensor* output
      output: True
    - THTensor* self
]]

#undef THCudaDoubleTensor_BERNOULLI_TENSOR
#undef THCudaTensor_BERNOULLI_TENSOR

[[
  name: bernoulli_
  defined_if: CUDA_FLOAT || CUDA_DOUBLE || CUDA_HALF
  types:
    - floating_point
  processors:
    - CUDA
  return: self
  options:
    - cname: bernoulli
      arguments:
        - THTensor* self
        - arg: double p
          default: 0.5
    - cname: bernoulli_FloatTensor
      arguments:
        - THTensor* self
        - THCudaTensor* float_p
    - cname: bernoulli_DoubleTensor
      arguments:
        - THTensor* self
        - THCudaDoubleTensor* float_p
]]

@zdevito the bernoulli function has three variants:

x = torch.Tensor(10)
x.bernoulli_(0.5) # sample from a bernoulli distribution with p=0.5 for every element of x

y_float = torch.Tensor(10).uniform_()
x.bernoulli_(y_float) # sample from a bernoulli distribution with p sourced from the corresponding value in y_float at each index

y_double = torch.DoubleTensor(10).uniform_()
x.bernoulli_(y_double) # variant of the above, with a double-precision tensor as the source for p

The 2nd and 3rd variants explicitly take float and double tensor arguments, respectively.

Now, for whatever reason (I didn't look into why), consider the following:

x = torch.DoubleTensor(10)
y = torch.Tensor(10)
z = torch.Tensor(10).uniform_()

torch.bernoulli(z, out=y) # ok!
torch.bernoulli(z, out=x) # invalid args, z/x must have same type

So essentially the macros make the functions generated for the first declaration call the float-tensor bernoulli function when the type is CUDA_FLOAT, and the double-tensor bernoulli function when the type is CUDA_DOUBLE. In the second declaration, where there is no output argument, this works. It's unclear to me without further digging why this is so; it seems we should be able to use any output tensor type for bernoulli...
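The asymmetry between the two declarations can be sketched in plain Python (this is an illustration of the dispatch the generated wrappers appear to implement, not the real binding code; the function names and type strings are hypothetical):

```python
def bernoulli_out(p_type, out_type):
    """Mimic the first declaration: one cname per type, so the probability
    tensor and the output tensor must have the same type."""
    if p_type != out_type:
        raise TypeError("invalid args: p and out must have the same type")
    # For Float the tensor class is just THCudaTensor, not THCudaFloatTensor.
    prefix = "THCudaTensor" if out_type == "Float" else f"THCuda{out_type}Tensor"
    return f"{prefix}_bernoulli_{p_type}Tensor"

def bernoulli_inplace(self_type, p_type):
    """Mimic the second declaration: dispatch only on the probability
    tensor's type, so self can be any supported type."""
    if p_type not in ("Float", "Double"):
        raise TypeError("p must be a Float or Double tensor")
    return f"bernoulli_{p_type}Tensor"
```

Under this model `bernoulli_out("Float", "Double")` raises, matching the `torch.bernoulli(z, out=x)` failure above, while `bernoulli_inplace` accepts any `self_type`, matching the in-place variants.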