Operators accept the `dtype` argument but ignore it
lljbash opened this issue
Affected operators include, for example, `sum` and `prod`.
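To check which operators are affected on a given backend, a probe along these lines may help. It is a minimal sketch: `find_ops_ignoring_dtype` is a hypothetical helper, and the operator list (`sum`, `prod`, `nansum`, `mean`, `cumsum`, `cumprod`, each of which accepts a `dtype` keyword in PyTorch) is an assumption about what is worth testing, not the full set of affected operators.

```python
import torch

# Reductions/scans that take a `dtype` keyword in PyTorch. The cumulative
# ops require a dimension, so the lambdas fix dim=0 for them.
OPS = {
    "sum": lambda t, d: torch.sum(t, dtype=d),
    "prod": lambda t, d: torch.prod(t, dtype=d),
    "nansum": lambda t, d: torch.nansum(t, dtype=d),
    "mean": lambda t, d: torch.mean(t, dtype=d),
    "cumsum": lambda t, d: torch.cumsum(t, dim=0, dtype=d),
    "cumprod": lambda t, d: torch.cumprod(t, dim=0, dtype=d),
}


def find_ops_ignoring_dtype(device, dtypes=(torch.float16, torch.float32)):
    """Hypothetical helper: return (op, input dtype, requested dtype, result dtype)
    for every combination whose result dtype differs from the requested one."""
    failures = []
    for name, op in OPS.items():
        for in_dtype in dtypes:
            x = torch.tensor([2, 3, 6, 9, 8], dtype=in_dtype, device=device)
            for out_dtype in dtypes:
                got = op(x, out_dtype).dtype
                if got != out_dtype:
                    failures.append((name, in_dtype, out_dtype, got))
    return failures
```

On the affected device, `find_ops_ignoring_dtype(device)` lists every offending combination; stock PyTorch on CPU is expected to return an empty list, for example with `dtypes=(torch.float32, torch.float64)`.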
Test code
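```python
import torch
from torch.testing._internal.common_device_type import dtypes


# Method of a device-generic TestCase: `device` and `dtype` are supplied when
# the class is instantiated with instantiate_device_type_tests.
@dtypes(torch.float16, torch.float32)
def test_prod_gpu(self, device, dtype):
    x = torch.tensor([2, 3, 6, 9, 8], dtype=dtype, device=device)
    # Check all combinations: fp16 input - fp16 output, fp16 input - fp32
    # output, fp32 input - fp16 output, fp32 input - fp32 output.
    # TestCase.assertEqual also compares dtypes by default (exact_dtype=True).
    for dtype_output in [torch.float16, torch.float32]:
        result_expected = torch.tensor(2592, dtype=dtype_output, device=device)
        output = torch.prod(x, dtype=dtype_output)
        self.assertEqual(output, result_expected)
        output = x.prod(dtype=dtype_output)
        self.assertEqual(output, result_expected)
```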
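Expected behavior
The operation runs in the data type given by `dtype`: the input is cast to `dtype` before the operation is performed, and the result has that dtype.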
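PyTorch documents `dtype` for `torch.sum` and `torch.prod` as the desired data type of the returned tensor, with the input cast to `dtype` before the operation. That gives a simple reference: `op(x, dtype=d)` should match `op(x.to(d))` in both dtype and value. A minimal sketch, where `honors_dtype` is a hypothetical helper name and the CPU run below only shows the reference holding on stock PyTorch:

```python
import torch


def honors_dtype(op, x, dtype):
    """Hypothetical helper: compare op(x, dtype=...) against the cast-then-reduce
    reference op(x.to(dtype)); both the dtype and the values must match."""
    out = op(x, dtype=dtype)
    ref = op(x.to(dtype))
    # The dtype check short-circuits before torch.equal, which needs equal dtypes.
    return out.dtype == ref.dtype == dtype and torch.equal(out.cpu(), ref.cpu())


x = torch.tensor([2, 3, 6, 9, 8], dtype=torch.float32)
results = {
    (name, dtype): honors_dtype(op, x, dtype)
    for name, op in (("sum", torch.sum), ("prod", torch.prod))
    for dtype in (torch.float32, torch.float64)
}
print(results)
```

On the affected backend, building `x` on that device should make the `(…, dtype)` entries whose dtype differs from the input's come out `False`.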
Actual behavior
The operation runs in the input's data type, and the `dtype` argument is ignored.
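Ignoring `dtype` can change results, not just the result type. When fp16 input is reduced with `dtype=torch.float32`, the result should be fp32; if the kernel stays in fp16, any result above the fp16 maximum finite value (65504) overflows to `inf`. A small illustration on CPU, where stock PyTorch provides the expected result and a plain fp16 `sum` stands in for a backend that ignores `dtype` (the input values are an arbitrary choice for the example):

```python
import torch

# 60000 is exactly representable in fp16, but the sum (120000) is not.
x = torch.full((2,), 60000.0, dtype=torch.float16)

expected = x.sum(dtype=torch.float32)  # input cast to fp32 first, then reduced
ignored = x.sum()                      # result kept in the input dtype (fp16)

print(expected.dtype, expected.item())  # torch.float32 120000.0
print(ignored.dtype, ignored.item())    # torch.float16 inf
```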