
Python where op float x float promotes to float64


import torch
from torch.testing import make_tensor
# Imports assumed for this repro: the frontend classes come from torch._C._nvfuser in this fork.
from torch._C._nvfuser import Fusion, FusionDefinition, DataType

pred = make_tensor((5,), device='cuda', dtype=torch.bool)

fs = Fusion()
with FusionDefinition(fs) as fd:
    nv_pred = fd.define_tensor(sizes=pred.shape, strides=pred.stride(), dtype=DataType.Bool)
    five = fd.define_constant(5.)
    three = fd.define_constant(3.)

    result = fd.ops.where(nv_pred, five, three)

    fd.add_output(result)

nv_result = fs.execute((pred,))[0]
print(f"nv_result={nv_result}")
# nv_result=tensor([5., 3., 3., 3., 5.], device='cuda:0', dtype=torch.float64)

torch_result = torch.where(pred, 5., 3.)
print(f"torch_result={torch_result}")
# torch_result=tensor([5., 3., 3., 3., 5.], device='cuda:0')

The solution is to change the define_constant API to allow type specification, since Python floating-point numbers are inferred to be double.
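
As a sketch of what that could look like (the DataType argument is the proposed addition; the same pattern appears in the PR example later in this thread):

five_f64 = fd.define_constant(5.)                   # Python float literal is inferred as DataType.Double today
five_f32 = fd.define_constant(5., DataType.Float)   # proposed: explicit single-precision constant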

And define_scalar, too?

define_scalar already allows you to specify a type. I think we are okay there unless you saw an issue?
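
For reference, a rough sketch of that usage (assuming the DataType is passed directly to define_scalar; the scalar then becomes a runtime input to the fusion):

s0 = fd.define_scalar(DataType.Float)  # single-precision runtime scalar input (assumed positional dtype)
t1 = fd.ops.where(nv_pred, s0, s0)
fd.add_output(t1)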


My mistake, I didn't realize

We currently don't have single-precision scalars in either Python or C++. I hit some errors trying to add them (see #2403), but it could probably be done. However, it may be simpler to add a dtype argument to the where op that defaults to DataType.Float. The effect of the argument would be to insert a cast op after the where.
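
As a hedged sketch of that idea: the dtype keyword on where below is the proposed argument (it does not exist today), and it would lower to a cast inserted after the existing where op:

# Hypothetical: 'dtype' is the proposed argument, not part of the current API.
t1 = fd.ops.where(nv_pred, five, three, dtype=DataType.Float)

# Roughly equivalent lowering: the existing where followed by an explicit cast
# (assuming the cast op takes (input, dtype) in that order).
t1 = fd.ops.cast(fd.ops.where(nv_pred, five, three), DataType.Float)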

@mruberry the following lines in the PR branch above allow you to force the constant DataTypes to Float, which for where is enough to ensure a float32-valued output:

# t0 is the boolean predicate tensor defined earlier with fd.define_tensor(...)
c0f = fd.define_constant(3.0, DataType.Float)
c1f = fd.define_constant(5.0, DataType.Float)
t1f = fd.ops.where(t0, c0f, c1f)  # DataType.Float
fd.add_output(t1f)

Would that sufficiently address this issue? Note that we haven't changed the promotion rules for nvfuser: if an op receives only scalar floating-point arguments, we do not fall back to the default floating-point type as PyTorch does; instead we use the highest-precision type among the given arguments.
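
For example, under that rule (a sketch, not output from the PR branch):

c_f32 = fd.define_constant(3.0, DataType.Float)
c_f64 = fd.define_constant(5.0)       # Python float literal -> Double
t2 = fd.ops.where(t0, c_f32, c_f64)   # highest precision among the inputs -> DataType.Double
# whereas torch.where(pred, 3., 5.) produces float32, PyTorch's default floating type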


Yes, I think that would address the issue.