
Linalg-to-XSMM lowering does not convert a broadcast `linalg.generic` into an XSMM op

./build/bin/tpp-opt ./test/Passes/DefaultPipeline/default-tpp-passes.mlir -default-tpp-passes="linalg-to-xsmm" -split-input-file

Expected behaviour (@mlp function):

    // Broadcast before GEMM lowered
    %0 = call @xsmm_unary_dispatch(%c1_i64, %c1_i64, %c128_i64, %c512_i64, %c512_i64, %c512_i64, %c4_i64) : (i64, i64, i64, i64, i64, i64, i64) -> i64
    %intptr = memref.extract_aligned_pointer_as_index %arg2 : memref<512xf32> -> index
    %1 = arith.index_cast %intptr : index to i64
    %2 = llvm.inttoptr %1 : i64 to !llvm.ptr<f32>
    %intptr_0 = memref.extract_aligned_pointer_as_index %alloc : memref<128x512xf32> -> index
    %3 = arith.index_cast %intptr_0 : index to i64
    %4 = llvm.inttoptr %3 : i64 to !llvm.ptr<f32>
    call @xsmm_unary_invoke(%c1_i64, %0, %2, %c0, %4, %c0) : (i64, i64, !llvm.ptr<f32>, index, !llvm.ptr<f32>, index) -> ()

Actual behaviour (@mlp function):

    // Broadcast before GEMM untouched
    linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg2 : memref<512xf32>) outs(%alloc : memref<128x512xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32

Everything else works in that test when moving to the new lowering.


Looking at the actual test, the `@mlp` function is not really an MLP, and the test as written makes no sense.

When I change it to an actual MLP, I get the following (expected) behaviour: the broadcast is still not expanded, but it is now in the right place.


// One MLP layer (bias broadcast -> matmul -> bias add -> ReLU) written with
// linalg.generic ops; intended as input for
// -default-tpp-passes="linalg-to-xsmm".
#map0 = affine_map<(d0, d1) -> (d1)>          // 1-D bias, indexed by d1 only
#map1 = affine_map<(d0, d1) -> (d0, d1)>      // 2-D identity
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>  // matmul LHS (M x K)
#map3 = affine_map<(d0, d1, d2) -> (d2, d1)>  // matmul RHS (K x N)
#map4 = affine_map<(d0, d1, d2) -> (d0, d1)>  // matmul result (M x N)

func.func @mlp(%arg0: tensor<128x256xf32>,
               %arg1: tensor<256x512xf32>,
               %arg2: tensor<512xf32>,
               %output: tensor<128x512xf32>) -> tensor<128x512xf32> {

  // Bias broadcast: copy the 512-element bias into every one of the 128 rows.
  %bias = tensor.empty() : tensor<128x512xf32>
  %1 = linalg.generic {
            indexing_maps = [#map0, #map1],
            iterator_types = ["parallel", "parallel"]
          }
          ins(%arg2 : tensor<512xf32>)
          outs(%bias : tensor<128x512xf32>) {
            ^bb0(%arg9: f32, %arg10: f32):
              linalg.yield %arg9 : f32
          } -> tensor<128x512xf32>

  // Matmul accumulated onto %output: %2 = %output + %arg0 * %arg1.
  %2 = linalg.generic {
            indexing_maps = [#map2, #map3, #map4],
            iterator_types = ["parallel", "parallel", "reduction"]
          }
          ins(%arg0, %arg1 : tensor<128x256xf32>, tensor<256x512xf32>)
          outs(%output : tensor<128x512xf32>) {
            ^bb0(%arg9: f32, %arg10: f32, %arg11: f32):
              %16 = arith.mulf %arg9, %arg10 : f32
              %17 = arith.addf %arg11, %16 : f32
              linalg.yield %17 : f32
          } -> tensor<128x512xf32>

  // Bias add: elementwise sum of the broadcast bias and the matmul result.
  %3 = linalg.generic {
            indexing_maps = [#map1, #map1],
            iterator_types = ["parallel", "parallel"]
          }
          ins(%1 : tensor<128x512xf32>)
          outs(%2 : tensor<128x512xf32>) {
          ^bb0(%arg9: f32, %arg10: f32):
              %16 = arith.addf %arg9, %arg10 : f32
              linalg.yield %16 : f32
          } -> tensor<128x512xf32>

  // ReLU: max(x, 0.0) applied elementwise to the bias-add result.
  %c0 = arith.constant 0.0 : f32
  %4 = linalg.generic {
            indexing_maps = [#map1],
            iterator_types = ["parallel", "parallel"]
          }
          outs(%3 : tensor<128x512xf32>) {
            ^bb0(%arg9: f32):
              %16 = arith.maxf %arg9, %c0 : f32
              linalg.yield %16 : f32
          } -> tensor<128x512xf32>

  return %4 : tensor<128x512xf32>
}


// -----// IR Dump After TileConsumerAndFuseProducers (tile-consumer-and-fuse-producers) //----- //
func.func @mlp(%arg0: tensor<128x256xf32>, %arg1: tensor<256x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<128x512xf32>) -> tensor<128x512xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<4x16x32x32xf32>
  %expanded = tensor.expand_shape %arg2 [[0, 1]] : tensor<512xf32> into tensor<16x32xf32>
  %1 = tensor.empty() : tensor<4x8x32x32xf32>
  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<128x256xf32> -> tensor<4x8x32x32xf32>
  %2 = tensor.empty() : tensor<16x8x32x32xf32>
  %pack_0 = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<256x512xf32> -> tensor<16x8x32x32xf32>
  %pack_1 = tensor.pack %arg3 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<128x512xf32> -> tensor<4x16x32x32xf32>
  %3 = scf.forall (%arg4, %arg5) in (4, 16) shared_outs(%arg6 = %pack_1) -> (tensor<4x16x32x32xf32>) {
    %extracted_slice = tensor.extract_slice %expanded[%arg5, 0] [1, 32] [1, 1] : tensor<16x32xf32> to tensor<32xf32>
    %4 = tensor.empty() : tensor<32x32xf32>
    %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<32xf32>) outs(%4 : tensor<32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<32x32xf32>
    %extracted_slice_2 = tensor.extract_slice %pack[%arg4, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<8x32x32xf32>
    %extracted_slice_3 = tensor.extract_slice %pack_0[%arg5, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] : tensor<16x8x32x32xf32> to tensor<8x32x32xf32>
    %extracted_slice_4 = tensor.extract_slice %arg6[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<32x32xf32>
    %6 = linalg.batch_reduce_matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<8x32x32xf32>, tensor<8x32x32xf32>) outs(%extracted_slice_4 : tensor<32x32xf32>) -> tensor<32x32xf32>
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<32x32xf32>) outs(%6 : tensor<32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %9 = arith.addf %in, %out : f32
      linalg.yield %9 : f32
    } -> tensor<32x32xf32>
    %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%7 : tensor<32x32xf32>) {
    ^bb0(%out: f32):
      %9 = arith.maxf %out, %cst : f32
      linalg.yield %9 : f32
    } -> tensor<32x32xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %8 into %arg6[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<4x16x32x32xf32>
  %unpack = tensor.unpack %3 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg3 : tensor<4x16x32x32xf32> -> tensor<128x512xf32>
  return %unpack : tensor<128x512xf32>

Actually, looking further down the pass pipeline, everything looks good once we have a real MLP:

// -----// IR Dump After ConvertLinalgToXsmm (convert-linalg-to-xsmm) //----- //
  scf.forall (%arg4, %arg5) in (4, 16) {
    %subview = memref.subview %expand_shape[%arg5, 0] [1, 32] [1, 1] : memref<16x32xf32> to memref<32xf32, strided<[1], offset: ?>>
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
    %0 = xsmm.unary.dispatch identity [32, 32, 1, 32] flags = (bcast_col) data_type = f32
    xsmm.unary identity(data_type = f32, %0, %subview, %alloc_2) : (i64, memref<32xf32, strided<[1], offset: ?>>, memref<32x32xf32>) -> ()
    %subview_3 = memref.subview %alloc_0[%arg4, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>
    %subview_4 = memref.subview %alloc_1[%arg5, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] : memref<16x8x32x32xf32> to memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>
    %subview_5 = memref.subview %alloc[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
    %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32
    xsmm.brgemm(data_type = f32, %1, %subview_3, %subview_4, %subview_5, %c8_i64) : (i64, memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> ()
    %2 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (none) data_type = f32
    xsmm.binary add(data_type = f32, %2, %alloc_2, %subview_5, %subview_5) : (i64, memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    %3 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32
    xsmm.unary relu(data_type = f32, %3, %subview_5, %subview_5) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    memref.dealloc %alloc_2 : memref<32x32xf32>

The `xsmm.unary.dispatch identity` here is a fill; it is not the 'bias broadcast' `linalg.generic` above.

I see, no worries. We can have another test later. Thanks!