plaidml/tpp-mlir

Implement unary transform TPP operations


We currently have only one transform, tpp.identity, which is a mixed bag. We decided to replace it with more explicit operations.

All TPP transforms are out-of-place operations, so their ins and outs must be distinct memrefs. The exception is tpp.zero, which uses the same buffer for both.
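A minimal sketch of the out-of-place rule, in Python rather than TPP code. The function names are hypothetical and 2-D buffers are modelled as nested lists: the copy rejects aliased ins/outs, while the zero op, the in-place case, writes into the one buffer it is given.

```python
from typing import List

Matrix = List[List[float]]


def tpp_copy(src: Matrix, dst: Matrix) -> None:
    """Out of place: ins and outs must be distinct buffers of equal shape."""
    if src is dst:
        raise ValueError("ins and outs must not be the same buffer")
    if len(src) != len(dst) or any(len(a) != len(b) for a, b in zip(src, dst)):
        raise ValueError("ins and outs must have the same shape")
    for i, row in enumerate(src):
        dst[i][:] = row  # copy the elements, not the row object


def tpp_zero(buf: Matrix) -> None:
    """In place: the same buffer is both input and output."""
    for row in buf:
        row[:] = [0.0] * len(row)
```

The identity check only catches passing the very same buffer twice; partially overlapping views would need a real aliasing analysis.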

Ops to create:

  tpp.zero ins(%0) outs(%0) : memref<MxNxTy> // XOR op

  tpp.copy ins(%0) outs(%1) : (memref<MxNxTy>, memref<MxNxTy>) // COPY

  tpp.broadcast ins(%0) outs(%1) : (memref<2x1xf32>, memref<2x4xf32>) // COPY + BCAST_COL
  tpp.broadcast ins(%0) outs(%1) : (memref<2xf32>, memref<4x2xf32>) // COPY + BCAST_ROW
  tpp.broadcast ins(%0) outs(%1) : (memref<1xf32>, memref<4x2xf32>) // COPY + BCAST_SCALAR
  tpp.broadcast ins(%0) outs(%1) : (memref<f32>, memref<4x2xf32>) // COPY + BCAST_SCALAR
  tpp.broadcast ins(%0) outs(%1) : (f32, memref<4x2xf32>) // COPY + BCAST_SCALAR

  tpp.reduce_add ins(%0) outs(%1) : (memref<MxNxTy>, memref<MxTy>) // REDUCE_X_OP_ADD + REDUCE_ROWS
  tpp.reduce_max ins(%0) outs(%1) : (memref<MxNxTy>, memref<NxTy>) // REDUCE_X_OP_MAX + REDUCE_COLS
  ...

  tpp.transpose ins(%0) outs(%1) : (memref<MxNxTy>, memref<NxMxTy>) // TRANSFORM_NORM_TO_NORMT

  tpp.vnni_pack ins(%0) outs(%1) : (memref<MxNxTy>, memref<(M/2)xNx2xTy>) // TRANSFORM_NORM_TO_VNNI2
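To make the intended shapes concrete, here is a minimal Python sketch of the element-level semantics implied by the types in the list above. It is not TPP or LIBXSMM code: buffers are modelled as row-major nested lists, every function name is hypothetical, and vnni2_pack assumes the conventional VNNI2 layout, in which pairs of consecutive rows are interleaved into an (M/2)xNx2 buffer.

```python
# Hypothetical reference semantics; names are illustrative, not TPP APIs.
from typing import List, Union

Vector = List[float]
Matrix = List[List[float]]


def broadcast_col(src: Matrix, n: int) -> Matrix:
    """(Mx1) -> (MxN): each row's single element is repeated N times."""
    return [[row[0]] * n for row in src]


def broadcast_row(src: Vector, m: int) -> Matrix:
    """(N) -> (MxN): the vector becomes each of the M rows."""
    return [list(src) for _ in range(m)]


def broadcast_scalar(src: Union[float, Vector], m: int, n: int) -> Matrix:
    """A plain float (0-d) or a 1-element list -> (MxN) filled with that value."""
    value = src[0] if isinstance(src, list) else src
    return [[value] * n for _ in range(m)]


def reduce_add_rows(src: Matrix) -> Vector:
    """(MxN) -> (M): sum of each row."""
    return [sum(row) for row in src]


def reduce_max_cols(src: Matrix) -> Vector:
    """(MxN) -> (N): maximum of each column."""
    return [max(col) for col in zip(*src)]


def transpose(src: Matrix) -> Matrix:
    """(MxN) -> (NxM): out[j][i] = in[i][j]."""
    return [list(col) for col in zip(*src)]


def vnni2_pack(src: Matrix) -> List[Matrix]:
    """(MxN) -> (M/2 x N x 2): out[k][j][v] = in[2*k + v][j]."""
    if len(src) % 2:
        raise ValueError("VNNI2 packing needs an even number of rows")
    n = len(src[0]) if src else 0
    return [[[src[2 * k + v][j] for v in range(2)] for j in range(n)]
            for k in range(len(src) // 2)]
```

For example, packing the 4x2 input `[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]` with vnni2_pack gives `[[[1.0, 3.0], [2.0, 4.0]], [[5.0, 7.0], [6.0, 8.0]]]`. These functions only pin down shapes and the element mapping; how each case maps onto LIBXSMM's flags (BCAST_COL vs BCAST_ROW, REDUCE_ROWS vs REDUCE_COLS) depends on LIBXSMM's own layout conventions and is not checked here.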

These are required to implement tensor.pack and tensor.unpack in TPP (#290).

@chelini @adam-smnk

Given the upstream work in https://reviews.llvm.org/D153422 and https://reviews.llvm.org/D153421, we may want to add all of these ops to Linalg instead.

Deprecated, as we don't have the TPP dialect any more.