plaidml/tpp-mlir

PyTorch MLP doesn't fuse XSMM

Closed this issue · 7 comments

Due to temporary buffers (as predicted), the XSMM pass does not fuse PyTorch inputs.

Inner loop for mlir-gen:

  scf.forall (%arg1, %arg2) in (8, 8) {
    %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
    %3 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32
    xsmm.unary zero(data_type = f32, %3, %cst, %subview) : (i64, f32, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    %subview_3 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>
    %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32
    xsmm.brgemm(data_type = f32, %4, %subview_3, %2, %subview, %c4_i64) : (i64, memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<4x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> ()
    %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in0) data_type = f32
    xsmm.binary add(data_type = f32, %5, %1, %subview, %subview) : (i64, memref<32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32
    xsmm.unary relu(data_type = f32, %6, %subview, %subview) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
  }

Fuses to:

  scf.forall (%arg1, %arg2) in (8, 8) {
    %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
    %3 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32
    xsmm.unary zero(data_type = f32, %3, %cst, %subview) : (i64, f32, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    %subview_3 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>
    %4 = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024][add,relu]  flags = (none)  binary_flags = (bcast_col_in0)  unary_flags = (none) data_type = f32
    xsmm.fused_brgemm(data_type = f32, %4, %subview_3, %2, %subview, %1, %c4_i64) : (i64, memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<4x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32xf32>, i64) -> ()
  }

Inner loop for PyTorch output:

  scf.forall (%arg1, %arg2) in (8, 32) {
    %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xbf16>
    %6 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = bf16
    xsmm.unary zero(data_type = bf16, %6, %cst, %alloc_3) : (i64, bf16, memref<32x32xbf16>) -> ()
    %subview = memref.subview %alloc_1[%arg1, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
    %7 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b) data_type = bf16
    xsmm.brgemm(data_type = bf16, %7, %subview, %5, %alloc_3, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16>, i64) -> ()
    %8 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in0) data_type = bf16
    xsmm.binary add(data_type = bf16, %8, %4, %alloc_3, %alloc_3) : (i64, memref<32xbf16>, memref<32x32xbf16>, memref<32x32xbf16>) -> ()
    %subview_4 = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
    %9 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16
    xsmm.unary relu(data_type = bf16, %9, %alloc_3, %subview_4) : (i64, memref<32x32xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
    memref.dealloc %alloc_3 : memref<32x32xbf16>
  }

This does not fuse. Note the extra %alloc_3 allocated inside the loop. It is used as the output of the brgemm and the add, but not of the relu, which writes to %subview_4 instead.

@KavithaTipturMadhu, this is the pattern we discussed that would happen and not trigger the current pass.

@alheinecke

The snippets are from file 048.mlir, produced by running the debug script on the tpp-mlir (mlir-gen) input and on the PyTorch file attached here.
pytorch-xsmm.mlir.txt
tpp-mlir-xsmm.mlir.txt
torch-dynamo-mlp-bf16-3x1024.mlir.txt

Reproduce:

# PyTorch
./scripts/debug/debug_all_passes.sh -i torch-dynamo-mlp-bf16-3x1024.mlir
# mlir-gen
./scripts/debug/debug_all_passes.sh

Inspect files 048.mlir (xsmm dialect pre-fusion) and 049.mlir (xsmm post fusion).

001.mlir.txt
017.mlir.txt
048.mlir.txt
I modified the PyTorch MLIR input's ReLU from a compare-select to a maximumf operation. This can be seen in 001.mlir and in 048.mlir, where the xsmm relu call now uses subviews for both arguments. The xsmm add operation, however, still uses one alloc and one subview operand, which is hampering fusion. This happens because the PyTorch model originally used the matmul named op instead of a mulf-addf generic operation, and it lowers to a tensor.empty, as can be seen in 017.mlir. There's nothing technically incorrect about writing the model with a matmul operation, or with a compare-select operation for ReLU. However, they introduce additional allocs that we can't get rid of, and these are blocking fusion. What do you recommend is the best course of action here? @rengolin @alheinecke

mlir-gen-vs-pytorch
I can see a few patterns here:

  1. PyTorch transposes the weight while mlir-gen (TensorFlow) does not. The result of the transpose goes into a tensor.empty, which is not a constant, while in the mlir-gen case the weight is a constant. This may affect read/write analysis.
  2. PyTorch uses linalg.matmul while mlir-gen uses linalg.generic(mulf+addf). There could be tiling semantics that are available to generics (because it was coded for them) and not named ops that also implement the TilingInterface.
  3. PyTorch lowers ReLU as linalg.generic(cmpf+select) while mlir-gen lowers to linalg.generic(maximumf). We convert this later (at XSMM level) to ReLU, but the tiling and bufferization is still different. Maybe multiple instructions make it harder to detect patterns in a generic?
  4. PyTorch uses the previous result in the def-use chain as input and the "shared" buffer as init while mlir-gen has no input and uses the previous value as init. This is bound to be the biggest problem for buffer allocation.
  5. PyTorch's use of the "shared" buffer (for DPS purposes) is on the tensor.empty, not on the linalg.fill. This potentially creates a data dependency problem: the former is uninitialized memory and the latter is not, so the memory contents differ. Perhaps using the result of linalg.fill as the init could alleviate the problem.

The course of action I would take is to slowly convert the PyTorch code into our generated code, one change at a time, and see what changes after tiling. This would give us an idea of where to look in the MLIR code (passes, dialects, interfaces) for clues about what is happening.

I'd do in this order:

  • Change the init for Add and ReLU to the results of linalg.fill instead of tensor.empty.
  • Change the "ins/outs" style of PyTorch to the "outs" style of mlir-gen, without changing anything else first.
  • Remove the transpose and make %cst_4 a direct argument to linalg.matmul (in its transposed shape).
  • Change linalg.matmul to linalg.generic(mulf+addf).
  • Change linalg.generic(cmpf+select) to linalg.generic(maximumf).

This is my uneducated guess as to what could be impacting the tiling pass the most, but you should try each change independently of the others, in case my order above is wrong.

Whichever change above has the most impact will dictate how we approach the problem. I foresee some changes to the tiling and bufferization passes upstream. Depending on the impact upstream, we may have to propose a more robust in-place/out-of-place mechanism for tiling and bufferization (e.g. remove the DPS and add some required attribute).

I have tried all the changes you recommended above. Together they result in exactly the same IR as mlir-gen, so fusion no longer has a problem. That brings us to the second part of the question: the original PyTorch IR is also valid IR that we ought to support.

I think the important question is: which effect did each change have, and why? We need to understand every suggestion I gave independently, to form a roadmap for fixing this upstream. Let's start a document where we gather all those results and see what are the parts of MLIR we need to fix.

  • Change the init for Add and ReLU to the results of linalg.fill instead of tensor.empty.
    In the former case (tensor.empty as init), the fills are DCE'd during tile-and-fuse and later passes, because the filled buffer is unused while the tensor.empty buffer is used. Therefore, relu's input is not initialized with an xsmm.unary zero call either. After the change, however, xsmm.add still has two alloc arguments, whereas relu has one alloc argument and one subview argument:
 scf.forall (%arg1, %arg2) in (8, 32) {
    %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
    %6 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32
    xsmm.unary zero(data_type = f32, %6, %cst, %alloc_3) : (i64, f32, memref<32x32xf32>) -> ()
    %subview = memref.subview %alloc_1[%arg1, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
    %7 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32
    xsmm.brgemm(data_type = f32, %7, %subview, %5, %alloc_3, %c32_i64) : (i64, memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32x32xf32>, memref<32x32xf32>, i64) -> ()
    %8 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in0) data_type = f32
    xsmm.binary add(data_type = f32, %8, %4, %alloc_3, %alloc_3) : (i64, memref<32xf32>, memref<32x32xf32>, memref<32x32xf32>) -> ()
    %subview_4 = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
    %9 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32
    xsmm.unary zero(data_type = f32, %9, %cst, %subview_4) : (i64, f32, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    %10 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32
    xsmm.unary relu(data_type = f32, %10, %alloc_3, %subview_4) : (i64, memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    memref.dealloc %alloc_3 : memref<32x32xf32>
  }
  • Changing the "ins/outs" style of PyTorch to the "outs" style of mlir-gen without changing anything else first.
    No difference with respect to the previous state, because CSE already chains the outputs in PyTorch the same way as in mlir-gen.

  • Remove the transpose and make %cst_4 a direct argument to linalg.matmul (in its transposed shape).
    Again, no differences

  • Change linalg.matmul to linalg.generic(mulf+addf).
    Before matmul is changed to a generic, the add is performed with an explicit broadcast. In that case, the tile-and-fuse pass inserts a tensor.empty and a linalg.fill operation to initialize a temporary buffer for the add, which sits between the matmul and relu operations. Once the matmul is converted to a generic, this buffer is no longer inserted, because the result is held in place.
    Before:

%2 = scf.forall (%arg1, %arg2) in (8, 32) shared_outs(%arg3 = %1) -> (tensor<8x32x32x32xf32>) {
    %5 = tensor.empty() : tensor<32x32xf32>
    %6 = linalg.fill ins(%cst_5 : f32) outs(%5 : tensor<32x32xf32>) -> tensor<32x32xf32>
    %extracted_slice = tensor.extract_slice %pack[%arg1, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<8x32x32x32xf32> to tensor<32x32x32xf32>
    %7 = linalg.batch_reduce_matmul ins(%extracted_slice, %cst_4 : tensor<32x32x32xf32>, tensor<32x32x32xf32>) outs(%6 : tensor<32x32xf32>) -> tensor<32x32xf32>
    %8 = tensor.empty() : tensor<32x32xf32>
    %9 = linalg.fill ins(%cst_5 : f32) outs(%8 : tensor<32x32xf32>) -> tensor<32x32xf32>
    %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_3, %7 : tensor<32xf32>, tensor<32x32xf32>) outs(%9 : tensor<32x32xf32>) {
    ^bb0(%in: f32, %in_7: f32, %out: f32):
      %13 = arith.addf %in, %in_7 : f32
      linalg.yield %13 : f32
    } -> tensor<32x32xf32>

After:

%2 = scf.forall (%arg1, %arg2) in (8, 32) shared_outs(%arg3 = %1) -> (tensor<8x32x32x32xf32>) {
    %5 = tensor.empty() : tensor<32x32xf32>
    %6 = linalg.fill ins(%cst_4 : f32) outs(%5 : tensor<32x32xf32>) -> tensor<32x32xf32>
    %extracted_slice = tensor.extract_slice %pack[%arg1, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<8x32x32x32xf32> to tensor<32x32x32xf32>
    %7 = linalg.batch_reduce_matmul ins(%extracted_slice, %cst_3 : tensor<32x32x32xf32>, tensor<32x32x32xf32>) outs(%6 : tensor<32x32xf32>) -> tensor<32x32xf32>
    %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_2 : tensor<32xf32>) outs(%7 : tensor<32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %11 = arith.addf %in, %out : f32
      linalg.yield %11 : f32
    } -> tensor<32x32xf32>
  • Change linalg.generic(cmpf+select) to linalg.generic(maximumf).
    This removes another tensor.empty/linalg.fill operation pair from the tile-and-fuse pass output:
 %2 = scf.forall (%arg1, %arg2) in (8, 32) shared_outs(%arg3 = %1) -> (tensor<8x32x32x32xf32>) {
    %extracted_slice = tensor.extract_slice %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<8x32x32x32xf32> to tensor<32x32xf32>
    %5 = linalg.fill ins(%cst_4 : f32) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
    %extracted_slice_5 = tensor.extract_slice %pack[%arg1, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<8x32x32x32xf32> to tensor<32x32x32xf32>
    %6 = linalg.batch_reduce_matmul ins(%extracted_slice_5, %cst_3 : tensor<32x32x32xf32>, tensor<32x32x32xf32>) outs(%5 : tensor<32x32xf32>) -> tensor<32x32xf32>
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_2 : tensor<32xf32>) outs(%6 : tensor<32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %9 = arith.addf %in, %out : f32
      linalg.yield %9 : f32
    } -> tensor<32x32xf32>
    %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%7 : tensor<32x32xf32>) {
    ^bb0(%out: f32):
      %9 = arith.maximumf %out, %cst_4 : f32
      linalg.yield %9 : f32
    } -> tensor<32x32xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %8 into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x32x32x32xf32>
    }
  }

This results in matching operands in add and relu in the input to CombineXsmmOpPass as follows:

 scf.forall (%arg1, %arg2) in (8, 32) {
    %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
    %5 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32
    xsmm.unary zero(data_type = f32, %5, %cst, %subview) : (i64, f32, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    %subview_3 = memref.subview %alloc_1[%arg1, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
    %6 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32
    xsmm.brgemm(data_type = f32, %6, %subview_3, %4, %subview, %c32_i64) : (i64, memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> ()
    %7 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in0) data_type = f32
    xsmm.binary add(data_type = f32, %7, %3, %subview, %subview) : (i64, memref<32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    %8 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32
    xsmm.unary relu(data_type = f32, %8, %subview, %subview) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
  }

And this fuses into a single fused_brgemm operation:

 scf.forall (%arg1, %arg2) in (8, 32) {
    %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
    %5 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32
    xsmm.unary zero(data_type = f32, %5, %cst, %subview) : (i64, f32, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    %subview_3 = memref.subview %alloc_1[%arg1, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
    %6 = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024][add,relu]  flags = (none)  binary_flags = (bcast_col_in0)  unary_flags = (none) data_type = f32
    xsmm.fused_brgemm(data_type = f32, %6, %subview_3, %4, %subview, %3, %c32_i64) : (i64, memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32xf32>, i64) -> ()
  }

Awesome! Ok, so this confirms our suspicions that the generic tiling is the culprit.

We have long discussed an ingress canonicalization pass, and I think it's time we have one. Not as a way to "work around" the issue, but to allow us to "work on" the issue in an orthogonal manner.

We can still fix the tiling problems, but we should not let them hamper our ability to transform IR into optimal XSMM calls in the meantime.

So, the "solution" for this issue is to add an early pass that:

  • Converts linalg.matmul into linalg.generic { mulf + addf } (I think such a conversion already exists upstream: linalg generalization)
  • Converts linalg.generic { cmpf + select } into linalg.generic { maximumf }

Then create two issues upstream (in llvm-project) to track fixing those problems in the upstream tiling pass.