[Linalg] linalg.matmul f16 type convert to f32
Closed this issue · 0 comments
CoTinker commented
An `aten.mm` with f16 element type, such as:
```mlir
func.func @matmul(%arg0: !torch.vtensor<[1500,1024],f16>, %arg1: !torch.vtensor<[1024,1024],f16>) -> (!torch.vtensor<[1500,1024],f16>) {
  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1500,1024],f16>, !torch.vtensor<[1024,1024],f16> -> !torch.vtensor<[1500,1024],f16>
  return %0 : !torch.vtensor<[1500,1024],f16>
}
```
is converted to:
```mlir
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @matmul(%arg0: tensor<1500x1024xf16>, %arg1: tensor<1024x1024xf16>) -> tensor<1500x1024xf16> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<1500x1024xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1500x1024xf32>) -> tensor<1500x1024xf32>
    %2 = linalg.matmul ins(%arg0, %arg1 : tensor<1500x1024xf16>, tensor<1024x1024xf16>) outs(%1 : tensor<1500x1024xf32>) -> tensor<1500x1024xf32>
    %3 = tensor.empty() : tensor<1500x1024xf16>
    %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<1500x1024xf32>) outs(%3 : tensor<1500x1024xf16>) {
    ^bb0(%in: f32, %out: f16):
      %5 = arith.truncf %in : f32 to f16
      linalg.yield %5 : f16
    } -> tensor<1500x1024xf16>
    return %4 : tensor<1500x1024xf16>
  }
}
```
The element type of the `linalg.matmul` result is f32, even though both operands are f16; a separate `linalg.generic` then truncates the result back to f16. The relevant code is:
torch-mlir/lib/Conversion/TorchToLinalg/Linear.cpp
Lines 151 to 156 in 9a6fe58
torch-mlir/lib/Dialect/Torch/Utils/Utils.cpp
Lines 611 to 613 in 9a6fe58
I don't understand why it is designed this way.
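One plausible motivation (a numerics sketch, not taken from the torch-mlir sources): accumulating a K = 1024 reduction directly in f16 rounds every partial sum to f16 and compounds the error, whereas accumulating in f32 and truncating once at the end, as the generated `arith.truncf` does, only pays one f16 rounding. A small NumPy illustration of the difference:

```python
import numpy as np

# Illustrative sketch only (not torch-mlir code): compare summing f16
# products with an f16 running sum versus an f32 running sum that is
# truncated to f16 once at the end, mirroring the generated IR above.
rng = np.random.default_rng(0)
k = 1024  # reduction dimension of the matmul in the issue
a = rng.standard_normal(k).astype(np.float16)
b = rng.standard_normal(k).astype(np.float16)

# Reference: the same f16 inputs, accumulated in float64.
ref = np.dot(a.astype(np.float64), b.astype(np.float64))

# Strategy 1: keep the running sum in f16 the whole time.
acc16 = np.float16(0.0)
for x, y in zip(a, b):
    acc16 = np.float16(acc16 + x * y)  # every partial sum rounds to f16

# Strategy 2: accumulate in f32, truncate to f16 once at the end
# (what the linalg.matmul + arith.truncf lowering does).
acc32 = np.float32(0.0)
for x, y in zip(a, b):
    acc32 = acc32 + np.float32(x) * np.float32(y)
trunc = np.float16(acc32)

err16 = abs(float(acc16) - ref)
err32 = abs(float(trunc) - ref)
print(f"f16-accumulated error: {err16:.4f}")
print(f"f32-accumulated error: {err32:.4f}")
```

With a long enough reduction, the f32-accumulated result is typically much closer to the reference, which would explain why the lowering widens the `outs` operand to f32 rather than keeping f16 throughout.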