Linalg to XSMM lowering not detecting broadcast as XSMM
Closed this issue · 4 comments
rengolin commented
Reproduction:
./build/bin/tpp-opt ./test/Passes/DefaultPipeline/default-tpp-passes.mlir -default-tpp-passes="linalg-to-xsmm" -split-input-file
Expected behaviour (@mlp function):
// Broadcast before GEMM lowered
%0 = call @xsmm_unary_dispatch(%c1_i64, %c1_i64, %c128_i64, %c512_i64, %c512_i64, %c512_i64, %c4_i64) : (i64, i64, i64, i64, i64, i64, i64) -> i64
%intptr = memref.extract_aligned_pointer_as_index %arg2 : memref<512xf32> -> index
%1 = arith.index_cast %intptr : index to i64
%2 = llvm.inttoptr %1 : i64 to !llvm.ptr<f32>
%intptr_0 = memref.extract_aligned_pointer_as_index %alloc : memref<128x512xf32> -> index
%3 = arith.index_cast %intptr_0 : index to i64
%4 = llvm.inttoptr %3 : i64 to !llvm.ptr<f32>
call @xsmm_unary_invoke(%c1_i64, %0, %2, %c0, %4, %c0) : (i64, i64, !llvm.ptr<f32>, index, !llvm.ptr<f32>, index) -> ()
Actual behaviour (@mlp function):
// Broadcast before GEMM untouched
linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg2 : memref<512xf32>) outs(%alloc : memref<128x512xf32>) {
^bb0(%in: f32, %out: f32):
  linalg.yield %in : f32
}
Everything else works in that test when moving to the new lowering.
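For reference, the missed conversion would be expected to produce something like the following at the xsmm dialect level, before lowering to the runtime calls above. This is only a sketch modelled on the xsmm ops that appear later in this thread; the [m, n, ldi, ldo] values are read off the xsmm_unary_dispatch call in the expected output and may differ from the pass's exact choices:

// Hypothetical xsmm-dialect form of the 512 -> 128x512 bias broadcast.
// bcast_col replicates the 1-d input across rows; dims are illustrative.
%d = xsmm.unary.dispatch identity [128, 512, 512, 512] flags = (bcast_col) data_type = f32
xsmm.unary identity(data_type = f32, %d, %arg2, %alloc) : (i64, memref<512xf32>, memref<128x512xf32>) -> ()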
rengolin commented
Looking at the actual test, this is not an MLP and it makes no sense.
When changing it to an actual MLP, I get the following (expected) behaviour, with the broadcast still not expanded (but in the right place).
From:
#map0 = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @mlp(%arg0: tensor<128x256xf32>,
               %arg1: tensor<256x512xf32>,
               %arg2: tensor<512xf32>,
               %output: tensor<128x512xf32>) -> tensor<128x512xf32> {
  // Bias broadcast
  %bias = tensor.empty() : tensor<128x512xf32>
  %1 = linalg.generic {
    indexing_maps = [#map0, #map1],
    iterator_types = ["parallel", "parallel"]
  }
  ins(%arg2 : tensor<512xf32>)
  outs(%bias : tensor<128x512xf32>) {
  ^bb0(%arg9: f32, %arg10: f32):
    linalg.yield %arg9 : f32
  } -> tensor<128x512xf32>

  // Matmul
  %2 = linalg.generic {
    indexing_maps = [#map2, #map3, #map4],
    iterator_types = ["parallel", "parallel", "reduction"]
  }
  ins(%arg0, %arg1 : tensor<128x256xf32>, tensor<256x512xf32>)
  outs(%output : tensor<128x512xf32>) {
  ^bb0(%arg9: f32, %arg10: f32, %arg11: f32):
    %16 = arith.mulf %arg9, %arg10 : f32
    %17 = arith.addf %arg11, %16 : f32
    linalg.yield %17 : f32
  } -> tensor<128x512xf32>

  // Bias Add
  %3 = linalg.generic {
    indexing_maps = [#map1, #map1],
    iterator_types = ["parallel", "parallel"]
  }
  ins(%1 : tensor<128x512xf32>)
  outs(%2 : tensor<128x512xf32>) {
  ^bb0(%arg9: f32, %arg10: f32):
    %16 = arith.addf %arg9, %arg10 : f32
    linalg.yield %16 : f32
  } -> tensor<128x512xf32>

  // ReLU
  %c0 = arith.constant 0.0 : f32
  %4 = linalg.generic {
    indexing_maps = [#map1],
    iterator_types = ["parallel", "parallel"]
  }
  outs(%3 : tensor<128x512xf32>) {
  ^bb0(%arg9: f32):
    %16 = arith.maxf %arg9, %c0 : f32
    linalg.yield %16 : f32
  } -> tensor<128x512xf32>

  return %4 : tensor<128x512xf32>
}
To:
// -----// IR Dump After TileConsumerAndFuseProducers (tile-consumer-and-fuse-producers) //----- //
func.func @mlp(%arg0: tensor<128x256xf32>, %arg1: tensor<256x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<128x512xf32>) -> tensor<128x512xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<4x16x32x32xf32>
  %expanded = tensor.expand_shape %arg2 [[0, 1]] : tensor<512xf32> into tensor<16x32xf32>
  %1 = tensor.empty() : tensor<4x8x32x32xf32>
  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<128x256xf32> -> tensor<4x8x32x32xf32>
  %2 = tensor.empty() : tensor<16x8x32x32xf32>
  %pack_0 = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<256x512xf32> -> tensor<16x8x32x32xf32>
  %pack_1 = tensor.pack %arg3 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<128x512xf32> -> tensor<4x16x32x32xf32>
  %3 = scf.forall (%arg4, %arg5) in (4, 16) shared_outs(%arg6 = %pack_1) -> (tensor<4x16x32x32xf32>) {
    %extracted_slice = tensor.extract_slice %expanded[%arg5, 0] [1, 32] [1, 1] : tensor<16x32xf32> to tensor<32xf32>
    %4 = tensor.empty() : tensor<32x32xf32>
    %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<32xf32>) outs(%4 : tensor<32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<32x32xf32>
    %extracted_slice_2 = tensor.extract_slice %pack[%arg4, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<8x32x32xf32>
    %extracted_slice_3 = tensor.extract_slice %pack_0[%arg5, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] : tensor<16x8x32x32xf32> to tensor<8x32x32xf32>
    %extracted_slice_4 = tensor.extract_slice %arg6[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<32x32xf32>
    %6 = linalg.batch_reduce_matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<8x32x32xf32>, tensor<8x32x32xf32>) outs(%extracted_slice_4 : tensor<32x32xf32>) -> tensor<32x32xf32>
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<32x32xf32>) outs(%6 : tensor<32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %9 = arith.addf %in, %out : f32
      linalg.yield %9 : f32
    } -> tensor<32x32xf32>
    %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%7 : tensor<32x32xf32>) {
    ^bb0(%out: f32):
      %9 = arith.maxf %out, %cst : f32
      linalg.yield %9 : f32
    } -> tensor<32x32xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %8 into %arg6[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<4x16x32x32xf32>
    }
  }
  %unpack = tensor.unpack %3 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg3 : tensor<4x16x32x32xf32> -> tensor<128x512xf32>
  return %unpack : tensor<128x512xf32>
}
rengolin commented
Actually, looking further down, everything looks good once we have a real MLP:
// -----// IR Dump After ConvertLinalgToXsmm (convert-linalg-to-xsmm) //----- //
...
scf.forall (%arg4, %arg5) in (4, 16) {
  %subview = memref.subview %expand_shape[%arg5, 0] [1, 32] [1, 1] : memref<16x32xf32> to memref<32xf32, strided<[1], offset: ?>>
  %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
  %0 = xsmm.unary.dispatch identity [32, 32, 1, 32] flags = (bcast_col) data_type = f32
  xsmm.unary identity(data_type = f32, %0, %subview, %alloc_2) : (i64, memref<32xf32, strided<[1], offset: ?>>, memref<32x32xf32>) -> ()
  %subview_3 = memref.subview %alloc_0[%arg4, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>
  %subview_4 = memref.subview %alloc_1[%arg5, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] : memref<16x8x32x32xf32> to memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>
  %subview_5 = memref.subview %alloc[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
  %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32
  xsmm.brgemm(data_type = f32, %1, %subview_3, %subview_4, %subview_5, %c8_i64) : (i64, memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> ()
  %2 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (none) data_type = f32
  xsmm.binary add(data_type = f32, %2, %alloc_2, %subview_5, %subview_5) : (i64, memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
  %3 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32
  xsmm.unary relu(data_type = f32, %3, %subview_5, %subview_5) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
  memref.dealloc %alloc_2 : memref<32x32xf32>
}
...
chelini commented
The xsmm.unary.dispatch identity here is a fill; it is not the 'bias broadcast' linalg.generic above.
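To make the distinction concrete, here is a minimal sketch of the two patterns side by side (hypothetical IR with made-up value names, not taken from the test). A fill writes one scalar everywhere and takes no tensor input, whereas the bias broadcast replicates a 1-d operand across rows:

// Fill: a single scalar written to every element.
linalg.fill ins(%cst : f32) outs(%buf : memref<32x32xf32>)

// Bias broadcast: a 1-d vector replicated along rows; this is the shape
// of computation an identity unary with a bcast_col flag would express.
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%bias : memref<512xf32>) outs(%out : memref<128x512xf32>) {
^bb0(%in: f32, %o: f32):
  linalg.yield %in : f32
}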
rengolin commented
> The xsmm.unary.dispatch identity here is a fill; it is not the 'bias broadcast' linalg.generic above.
I see, no worries. We can have another test later. Thanks!