plaidml/tpp-mlir

Linalg to XSMM lowering not detecting broadcast as XSMM

Closed this issue · 4 comments

Reproduction:

./build/bin/tpp-opt ./test/Passes/DefaultPipeline/default-tpp-passes.mlir -default-tpp-passes="linalg-to-xsmm" -split-input-file

Expected behaviour (@mlp function):

    // Broadcast before GEMM lowered
    %0 = call @xsmm_unary_dispatch(%c1_i64, %c1_i64, %c128_i64, %c512_i64, %c512_i64, %c512_i64, %c4_i64) : (i64, i64, i64, i64, i64, i64, i64) -> i64
    %intptr = memref.extract_aligned_pointer_as_index %arg2 : memref<512xf32> -> index
    %1 = arith.index_cast %intptr : index to i64
    %2 = llvm.inttoptr %1 : i64 to !llvm.ptr<f32>
    %intptr_0 = memref.extract_aligned_pointer_as_index %alloc : memref<128x512xf32> -> index
    %3 = arith.index_cast %intptr_0 : index to i64
    %4 = llvm.inttoptr %3 : i64 to !llvm.ptr<f32>
    call @xsmm_unary_invoke(%c1_i64, %0, %2, %c0, %4, %c0) : (i64, i64, !llvm.ptr<f32>, index, !llvm.ptr<f32>, index) -> ()

Actual behaviour (@mlp function):

    // Broadcast before GEMM untouched
    linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg2 : memref<512xf32>) outs(%alloc : memref<128x512xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    }

Everything else works in that test when moving to the new lowering.

@chelini

Looking at the actual test, this is not an MLP and it makes no sense.

When changing it to an actual MLP, I get the following (expected) behaviour, with the broadcast still not expanded (but in the right place).

From:

// Broadcast access: result element (d0, d1) reads the 1-D operand at (d1),
// i.e. the operand is replicated along d0.
#map0 = affine_map<(d0, d1) -> (d1)>
// Identity access for 2-D element-wise operations.
#map1 = affine_map<(d0, d1) -> (d0, d1)>
// Matmul access maps over the iteration space (d0, d1, d2):
// first input is indexed by (d0, d2), second input by (d2, d1),
// and the output by (d0, d1).
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1)>

// One MLP layer expressed as linalg.generic ops on tensors:
//   %1 = broadcast of %arg2 to 128x512
//   %2 = %output + %arg0 * %arg1   (matmul accumulating into %output)
//   %3 = %1 + %2                   (bias add)
//   %4 = max(%3, 0.0)              (ReLU)
// Returns %4.
func.func @mlp(%arg0: tensor<128x256xf32>,
               %arg1: tensor<256x512xf32>,
               %arg2: tensor<512xf32>,
               %output: tensor<128x512xf32>) -> tensor<128x512xf32> {

  // Bias broadcast: copy the 512-element %arg2 into every row of a fresh
  // 128x512 tensor (input read through #map0, output written through #map1).
  %bias = tensor.empty() : tensor<128x512xf32>
  %1 = linalg.generic {
            indexing_maps = [#map0, #map1],
            iterator_types = ["parallel", "parallel"]
          }
          ins(%arg2 : tensor<512xf32>)
          outs(%bias : tensor<128x512xf32>) {
            ^bb0(%arg9: f32, %arg10: f32):
              linalg.yield %arg9 : f32
          } -> tensor<128x512xf32>

  // Matmul: multiply-accumulate of %arg0 and %arg1 into %output; the third
  // iterator is the reduction dimension.
  %2 = linalg.generic {
            indexing_maps = [#map2, #map3, #map4],
            iterator_types = ["parallel", "parallel", "reduction"]
          }
          ins(%arg0, %arg1 : tensor<128x256xf32>, tensor<256x512xf32>)
          outs(%output : tensor<128x512xf32>) {
            ^bb0(%arg9: f32, %arg10: f32, %arg11: f32):
              %16 = arith.mulf %arg9, %arg10 : f32
              %17 = arith.addf %arg11, %16 : f32
              linalg.yield %17 : f32
          } -> tensor<128x512xf32>

  // Bias Add: element-wise add of the broadcast bias (%1) onto the matmul
  // result (%2), which is used as the output operand.
  %3 = linalg.generic {
            indexing_maps = [#map1, #map1],
            iterator_types = ["parallel", "parallel"]
          }
          ins(%1 : tensor<128x512xf32>)
          outs(%2 : tensor<128x512xf32>) {
          ^bb0(%arg9: f32, %arg10: f32):
              %16 = arith.addf %arg9, %arg10 : f32
              linalg.yield %16 : f32
          } -> tensor<128x512xf32>


  // ReLU: element-wise max against the constant 0.0, with %3 as the only
  // (output) operand; %c0 is captured from the enclosing function body.
  %c0 = arith.constant 0.0 : f32
  %4 = linalg.generic {
            indexing_maps = [#map1],
            iterator_types = ["parallel", "parallel"]
          }
          outs(%3 : tensor<128x512xf32>) {
            ^bb0(%arg9: f32):
              %16 = arith.maxf %arg9, %c0 : f32
              linalg.yield %16 : f32
          } -> tensor<128x512xf32>

  return %4 : tensor<128x512xf32>
}

To:

// -----// IR Dump After TileConsumerAndFuseProducers (tile-consumer-and-fuse-producers) //----- //
// @mlp in tiled form: the operands are packed into 32x32 tiles and the
// broadcast, batch-reduce matmul, bias add and ReLU are all computed per
// output tile inside a single scf.forall over a 4x16 grid.
func.func @mlp(%arg0: tensor<128x256xf32>, %arg1: tensor<256x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<128x512xf32>) -> tensor<128x512xf32> {
  // Zero constant used by the ReLU max below.
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<4x16x32x32xf32>
  // View the 512-element bias as 16 rows of 32 elements; one row is sliced
  // out per iteration of the second forall index.
  %expanded = tensor.expand_shape %arg2 [[0, 1]] : tensor<512xf32> into tensor<16x32xf32>
  %1 = tensor.empty() : tensor<4x8x32x32xf32>
  // Pack %arg0 (128x256) into a 4x8 grid of 32x32 tiles.
  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<128x256xf32> -> tensor<4x8x32x32xf32>
  %2 = tensor.empty() : tensor<16x8x32x32xf32>
  // Pack %arg1 (256x512) into 32x32 tiles with the outer dims permuted,
  // giving a 16x8 grid of tiles.
  %pack_0 = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<256x512xf32> -> tensor<16x8x32x32xf32>
  // Pack the output operand %arg3 (128x512) into a 4x16 grid of 32x32 tiles.
  %pack_1 = tensor.pack %arg3 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<128x512xf32> -> tensor<4x16x32x32xf32>
  // One iteration per output tile (%arg4, %arg5); %arg6 is the shared packed
  // output that each iteration updates in place.
  %3 = scf.forall (%arg4, %arg5) in (4, 16) shared_outs(%arg6 = %pack_1) -> (tensor<4x16x32x32xf32>) {
    // Bias row of 32 elements selected by %arg5.
    %extracted_slice = tensor.extract_slice %expanded[%arg5, 0] [1, 32] [1, 1] : tensor<16x32xf32> to tensor<32xf32>
    %4 = tensor.empty() : tensor<32x32xf32>
    // Bias broadcast, still a linalg.generic at this stage: the 32-element
    // slice is replicated along d0 into a 32x32 tile.
    %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<32xf32>) outs(%4 : tensor<32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<32x32xf32>
    // The 8 tiles of each packed input selected by %arg4 / %arg5, and the
    // 32x32 output tile at (%arg4, %arg5).
    %extracted_slice_2 = tensor.extract_slice %pack[%arg4, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<8x32x32xf32>
    %extracted_slice_3 = tensor.extract_slice %pack_0[%arg5, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] : tensor<16x8x32x32xf32> to tensor<8x32x32xf32>
    %extracted_slice_4 = tensor.extract_slice %arg6[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<32x32xf32>
    // Batch-reduce matmul over the two 8x32x32 slices into the output tile.
    %6 = linalg.batch_reduce_matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<8x32x32xf32>, tensor<8x32x32xf32>) outs(%extracted_slice_4 : tensor<32x32xf32>) -> tensor<32x32xf32>
    // Bias add: broadcast tile (%5) added element-wise onto the matmul result (%6).
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<32x32xf32>) outs(%6 : tensor<32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %9 = arith.addf %in, %out : f32
      linalg.yield %9 : f32
    } -> tensor<32x32xf32>
    // ReLU: element-wise max of the tile against %cst (0.0).
    %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%7 : tensor<32x32xf32>) {
    ^bb0(%out: f32):
      %9 = arith.maxf %out, %cst : f32
      linalg.yield %9 : f32
    } -> tensor<32x32xf32>
    // Write the finished tile back to position (%arg4, %arg5) of the shared output.
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %8 into %arg6[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<4x16x32x32xf32>
    }
  }
  // Unpack the tiled result back to the 128x512 layout, into %arg3.
  %unpack = tensor.unpack %3 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg3 : tensor<4x16x32x32xf32> -> tensor<128x512xf32>
  return %unpack : tensor<128x512xf32>
}

Actually, looking further down, all looks good when we actually have a real MLP:

// -----// IR Dump After ConvertLinalgToXsmm (convert-linalg-to-xsmm) //----- //
...
  scf.forall (%arg4, %arg5) in (4, 16) {
    %subview = memref.subview %expand_shape[%arg5, 0] [1, 32] [1, 1] : memref<16x32xf32> to memref<32xf32, strided<[1], offset: ?>>
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
    %0 = xsmm.unary.dispatch identity [32, 32, 1, 32] flags = (bcast_col) data_type = f32
    xsmm.unary identity(data_type = f32, %0, %subview, %alloc_2) : (i64, memref<32xf32, strided<[1], offset: ?>>, memref<32x32xf32>) -> ()
    %subview_3 = memref.subview %alloc_0[%arg4, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>
    %subview_4 = memref.subview %alloc_1[%arg5, 0, 0, 0] [1, 8, 32, 32] [1, 1, 1, 1] : memref<16x8x32x32xf32> to memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>
    %subview_5 = memref.subview %alloc[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
    %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32
    xsmm.brgemm(data_type = f32, %1, %subview_3, %subview_4, %subview_5, %c8_i64) : (i64, memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<8x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> ()
    %2 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (none) data_type = f32
    xsmm.binary add(data_type = f32, %2, %alloc_2, %subview_5, %subview_5) : (i64, memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    %3 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32
    xsmm.unary relu(data_type = f32, %3, %subview_5, %subview_5) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    memref.dealloc %alloc_2 : memref<32x32xf32>
  }
...

The xsmm.unary.dispatch identity here is a fill, it is not the 'bias broadcast' linalg.generic above.

> The xsmm.unary.dispatch identity here is a fill, it is not the 'bias broadcast' linalg.generic above.

I see, no worries. We can have another test later. Thanks!