plaidml/tpp-mlir

Canonicalize affine.min map when tensors are dynamic


When rewriting a batch matmul to a matmul we tile fully along the batch dimension. However, when the tensors are fully dynamic, the scf.forall parallelization introduces an affine.min map that prevents the rank reduction and hence the mapping to brgemm.

func.func @batch_matmul_rewrite(%arg0: tensor<?x?x?xf32>,
  %arg1: tensor<?x?x?xf32>, %dim0: index, %dim1: index, %batch: index) -> tensor<?x?x?xf32> {
  %0 = tensor.empty(%batch, %dim0, %dim1) : tensor<?x?x?xf32>
  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
                           outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
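
For context, tiling with tile size 1 along the batch dimension yields roughly the IR below. This is a hand-written sketch; the value names and the exact affine map are illustrative, not verbatim pass output:

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dimB = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
%dimM = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
%dimK = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
%2 = scf.forall (%iv) in (%dimB) shared_outs(%out = %0) -> (tensor<?x?x?xf32>) {
  // min(1, %dimB - %iv) is always 1 for %iv in [0, %dimB), but the result
  // is not folded to a constant when %dimB is dynamic.
  %ts = affine.min affine_map<(d0)[s0] -> (1, -d0 + s0)>(%iv)[%dimB]
  // Because %ts is dynamic, the batch size of the slice stays '?' rather
  // than '1', so the slice cannot be rank-reduced and the loop body keeps
  // a rank-3 linalg.batch_matmul instead of a linalg.matmul.
  %slice = tensor.extract_slice %arg0[%iv, 0, 0] [%ts, %dimM, %dimK] [1, 1, 1]
    : tensor<?x?x?xf32> to tensor<?x?x?xf32>
  ...
}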

To reproduce, run the example above through -rewrite-batch-matmul-to-matmul. The fix needs to happen upstream in the scf.forall tiling; a sketch of the expected fold follows.
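
Concretely, the canonicalization would use the loop bounds to fold the min away. This is a sketch of the expected fold, assuming the map has the min(tile_size, ub - iv) shape shown above:

// With 0 <= %iv < %dimB and unit step, %dimB - %iv >= 1 always holds, so:
%ts = affine.min affine_map<(d0)[s0] -> (1, -d0 + s0)>(%iv)[%dimB]
// can fold to the constant:
%c1 = arith.constant 1 : index

Once the size is the constant 1, the extract_slice result becomes tensor<1x?x?xf32>, the unit batch dimension can be dropped, and the rank reduction to a plain matmul (and hence the brgemm mapping) can proceed.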