DeepLink-org/deeplink.framework

Operators accept the dtype argument but ignore it

lljbash opened this issue

Affected operators include, for example, sum and prod.
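One quick way to see which reductions are affected is to compare the dtype of each result with the requested one. The sketch below is a minimal check that assumes a helper named `dtype_is_honored` (hypothetical, not part of PyTorch or DIPU) and uses the CPU device as a stand-in. To reproduce the issue, substitute the device of the backend under test. Like the test code, it checks both the functional form and the method form.

```python
import torch


def dtype_is_honored(op_name, x, dtype):
    # Hypothetical helper: True only if both torch.<op>(x, dtype=...) and
    # x.<op>(dtype=...) return a tensor of the requested dtype.
    functional = getattr(torch, op_name)(x, dtype=dtype)
    method = getattr(x, op_name)(dtype=dtype)
    return functional.dtype == dtype and method.dtype == dtype


device = "cpu"  # assumption: replace with the device of the backend under test
x = torch.tensor([2, 3, 6, 9, 8], dtype=torch.float32, device=device)

for name in ("sum", "prod"):
    for dtype in (torch.float32, torch.float64):
        print(name, dtype, dtype_is_honored(name, x, dtype))
```

On stock PyTorch's CPU backend, every line prints True. A backend that ignores dtype would report False whenever the requested type differs from the input's float32.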

Test code

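import torch
from torch.testing._internal.common_device_type import dtypes, instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests


# Device-type test class: PyTorch generates one concrete test per device and dtype.
class TestReductions(TestCase):
    @dtypes(torch.float16, torch.float32)
    def test_prod_gpu(self, device, dtype):
        x = torch.tensor([2, 3, 6, 9, 8], dtype=dtype, device=device)

        # Check all combinations: fp16 input - fp16 output, fp16 input - fp32
        # output, fp32 input - fp16 output, fp32 input - fp32 output
        for dtype_output in [torch.float16, torch.float32]:
            result_expected = torch.tensor(2592, dtype=dtype_output, device=device)
            output = torch.prod(x, dtype=dtype_output)
            self.assertEqual(output, result_expected)

            output = x.prod(dtype=dtype_output)
            self.assertEqual(output, result_expected)


instantiate_device_type_tests(TestReductions, globals())

if __name__ == "__main__":
    run_tests()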

Expected behavior

The operator runs in the data type specified by dtype.
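PyTorch documents the `dtype` argument of `torch.sum` and `torch.prod` as the desired type of the returned tensor. The input is cast to that type before the operation runs, which helps prevent overflow. The sketch below shows why ignoring the argument matters for float16 input, whose largest finite value is 65504. The input values are chosen for illustration only, and the code runs on CPU.

```python
import torch

# 300 * 300 = 90000, which exceeds float16's largest finite value (65504).
x = torch.tensor([300.0, 300.0], dtype=torch.float16)

# Expected behavior: the input is cast to float32 before the reduction,
# so the product fits and the result is a float32 tensor.
out = torch.prod(x, dtype=torch.float32)
print(out.dtype, out.item())

# Equivalent explicit form: cast first, then reduce.
ref = torch.prod(x.to(torch.float32))
print(torch.equal(out, ref))

# What a float16 result would hold: 90000 is not representable and rounds to inf.
print(torch.tensor(90000.0).to(torch.float16))
```

The first two prints show a float32 result of 90000.0 that matches the explicit cast-then-reduce form. The last print shows that the same value becomes inf in float16, which is what a backend computing in the input type would end up with.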

Actual behavior

The operator runs in the input's data type; the dtype argument is ignored.