Python where op: float × float promotes to float64
```python
import torch
from torch.testing import make_tensor
# Legacy nvFuser Python frontend; the module path may vary across PyTorch versions
from torch._C._nvfuser import Fusion, FusionDefinition, DataType

pred = make_tensor((5,), device='cuda', dtype=torch.bool)

fs = Fusion()
with FusionDefinition(fs) as fd:
    nv_pred = fd.define_tensor(sizes=pred.shape, strides=pred.stride(), dtype=DataType.Bool)
    # Python floats are inferred as double-precision constants
    five = fd.define_constant(5.)
    three = fd.define_constant(3.)
    result = fd.ops.where(nv_pred, five, three)
    fd.add_output(result)

nv_result = fs.execute((pred,))[0]
print(f"nv_result={nv_result}")
# nv_result=tensor([5., 3., 3., 3., 5.], device='cuda:0', dtype=torch.float64)

torch_result = torch.where(pred, 5., 3.)
print(f"torch_result={torch_result}")
# torch_result=tensor([5., 3., 3., 3., 5.], device='cuda:0')
```
The proposed solution is to extend the define_constant API to allow an explicit type specification, since Python floating-point numbers are inferred to be double.
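A hedged sketch of what that extended API might look like; the dtype argument to define_constant here is the proposal, not the existing signature:

```python
# Hypothetical: define_constant with an explicit dtype (proposed, not current API)
five = fd.define_constant(5., DataType.Float)
three = fd.define_constant(3., DataType.Float)
# with float32 constants, where would produce a float32-valued output
result = fd.ops.where(nv_pred, five, three)
```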
And define_scalar, too?
define_scalar already allows you to specify a type. I think we are okay there unless you saw an issue?
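For reference, a sketch of the existing pattern; this assumes define_scalar takes a DataType and declares a runtime scalar input, and the exact signature may differ across frontend versions:

```python
# Sketch: define_scalar declares a typed scalar input to the fusion
s = fd.define_scalar(DataType.Float)
result = fd.ops.where(nv_pred, s, s)
fd.add_output(result)
# the concrete value would then be supplied at execution time, e.g.
# fs.execute((pred, 5.0))
```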
My mistake, I didn't realize it already supported that.
We currently don't have single-precision scalars, in either Python or C++. I hit some errors trying to add them (see #2403), but it could probably be done. However, it may be simpler to add a dtype argument to the where op that defaults to DataType.Float. The effect of the argument would be to insert a cast op after where.
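A minimal sketch of that alternative, assuming a hypothetical dtype argument on fd.ops.where (not the current signature), which would simply be sugar for an explicit cast:

```python
# Hypothetical: where with a dtype argument defaulting to DataType.Float
result = fd.ops.where(nv_pred, five, three, dtype=DataType.Float)

# which would be equivalent to casting the (double) where result explicitly
result = fd.ops.cast(fd.ops.where(nv_pred, five, three), DataType.Float)
```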
@mruberry the following lines in the PR branch above allow you to force the constant DataTypes to Float, which for where is enough to ensure a float32-valued output:

pytorch/third_party/nvfuser/python_tests/test_python_frontend.py, lines 715 to 718 in 15035c2

Would that sufficiently address this issue? Note that we haven't changed the promotion rules for nvfuser: if an op receives only scalar floating-point arguments, we do not use the default floating type as is done in pytorch, but rather the highest-precision type of the given arguments.
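To illustrate that promotion rule (a sketch, assuming define_constant accepts the dtype argument from the PR branch): combining scalar arguments of different precisions yields the highest-precision type, rather than pytorch's default float type:

```python
# Sketch: scalar-only promotion in nvfuser picks the highest-precision argument
a = fd.define_constant(5., DataType.Float)   # float32 constant
b = fd.define_constant(3., DataType.Double)  # float64 constant
c = fd.ops.add(a, b)  # promotes to Double, not to pytorch's default float32
```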
Yes, I think that would address the issue.