llvm/torch-mlir

[Linalg] linalg.matmul with f16 operands is converted to accumulate in f32

Closed this issue · 0 comments

An `aten.mm` op with f16 element type, such as:

func.func @matmul(%arg0: !torch.vtensor<[1500,1024],f16>, %arg1: !torch.vtensor<[1024,1024],f16>) -> (!torch.vtensor<[1500,1024],f16>) {
  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1500,1024],f16>, !torch.vtensor<[1024,1024],f16> -> !torch.vtensor<[1500,1024],f16>
  return %0 : !torch.vtensor<[1500,1024],f16>
}

is converted to:

module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @matmul(%arg0: tensor<1500x1024xf16>, %arg1: tensor<1024x1024xf16>) -> tensor<1500x1024xf16> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<1500x1024xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1500x1024xf32>) -> tensor<1500x1024xf32>
    %2 = linalg.matmul ins(%arg0, %arg1 : tensor<1500x1024xf16>, tensor<1024x1024xf16>) outs(%1 : tensor<1500x1024xf32>) -> tensor<1500x1024xf32>
    %3 = tensor.empty() : tensor<1500x1024xf16>
    %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<1500x1024xf32>) outs(%3 : tensor<1500x1024xf16>) {
    ^bb0(%in: f32, %out: f16):
      %5 = arith.truncf %in : f32 to f16
      linalg.yield %5 : f16
    } -> tensor<1500x1024xf16>
    return %4 : tensor<1500x1024xf16>
  }
}

The element type of the `linalg.matmul` result is f32, and a separate `linalg.generic` then truncates it back down to f16.

The code is:

Type elementType = resultType.getElementType();
auto accumulatorDType =
getDefaultAccType(rewriter, lhsType.getElementType());
if (accumulatorDType != resultType.getElementType()) {
elementType = accumulatorDType;
}

Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
if (inputType.isF16())
return rewriter.getF32Type();

I don't know why it's designed that way — presumably accumulating in f32 avoids the precision loss of summing a long reduction in f16 (matching how most hardware matmul units accumulate half-precision products in f32), but it would be good to have that rationale confirmed and documented.