plaidml/tpp-mlir

Create a "target description" class for target-specific decisions


Today, we make incorrect packing decisions based on types (e.g. bf16 always means VNNI) instead of on actual target support.

We also fix the packing shapes at compile time, which is likewise wrong in general.

Since we're lowering into library calls, we need a link into the target libraries to tell us what they support, not just what the target hardware supports.

A set of classes that return the supported packing types, shapes, and extensions would do the trick. Once we lower to our tpp dialect, the lowering directly matches the supported architecture.
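Something like the following rough sketch; the names `TargetInfo`, `PackingKind`, `getPreferredPackShape`, and `createTargetInfo` are made up for illustration and do not exist in the repo today:

```cpp
#include <array>
#include <cstdint>
#include <memory>
#include <string>

namespace tpp {

// Hypothetical packing kinds the compiler may pick between.
enum class PackingKind { None, VNNI2, VNNI4 };

// Hypothetical target-description interface queried during lowering.
class TargetInfo {
public:
  virtual ~TargetInfo() = default;

  // Whether the target (and the library we lower to) supports a given packing.
  virtual bool supportsPacking(PackingKind kind) const = 0;

  // Preferred pack/tile shape for a GEMM on this target (M, N, K blocks).
  virtual std::array<int64_t, 3> getPreferredPackShape() const = 0;

  // ISA extensions the lowering may rely on (e.g. "avx512f", "amx-bf16").
  virtual bool hasExtension(const std::string &ext) const = 0;
};

// Factory keyed by a -cpu style string; an empty string would mean
// "detect the host" (declaration only in this sketch).
std::unique_ptr<TargetInfo> createTargetInfo(const std::string &cpu = "");

} // namespace tpp
```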

After discussing this with @chelini and @hfp, the goal of this issue is to separate "libxsmm" from "x86" and "arm" for the compiler decisions (tiling, loop order, passes), and to pass that info down to the targets below. We do not, however, want to re-implement what LLVM already gives us.

LLVM has target descriptors for all supported CPUs and we already create them in tpp-run via the --cpu flag (like clang). But LLVM itself does not know how to detect the host's arch. That logic is performed by the Clang driver and we really do not want to replicate that.
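For reference, a minimal sketch of the LLVM side, assuming a recent LLVM (header locations and the exact `createTargetMachine` signature vary across versions); this is illustrative plumbing, not the actual tpp-run code:

```cpp
#include "llvm/ADT/StringRef.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Host.h"
#include <optional>
#include <string>

// Build a TargetMachine from a -cpu style string on the host triple.
static llvm::TargetMachine *createTM(llvm::StringRef cpu) {
  llvm::InitializeNativeTarget();
  const std::string triple = llvm::sys::getProcessTriple();
  std::string error;
  const llvm::Target *target =
      llvm::TargetRegistry::lookupTarget(triple, error);
  if (!target)
    return nullptr;
  // LLVM knows the subtarget features behind a named CPU; what it does not
  // do is pick that CPU for the host (-march=native style), which is Clang
  // driver logic.
  return target->createTargetMachine(triple, cpu, /*Features=*/"",
                                     llvm::TargetOptions(),
                                     /*RM=*/std::nullopt);
}
```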

LIBXSMM has target detection (via cpuid, etc.) and uses it for its runtime JIT compilation, but the API is limited and may not be available in a compiler build that does not use libxsmm (but lowers directly). It also may not map 1:1 onto what MLIR/LLVM represents, even for CPUs, so it wouldn't be a generic approach that could one day be upstreamed.
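For example, querying libxsmm's detection from C++ looks roughly like this (`libxsmm_init`, `libxsmm_get_target_archid`, and `libxsmm_get_target_arch` are the real entry points in libxsmm.h; the snippet itself is just a sketch):

```cpp
#include <cstdio>
#include <libxsmm.h>

int main() {
  libxsmm_init();
  // Numeric arch id and human-readable name of the detected (or overridden,
  // e.g. via the LIBXSMM_TARGET environment variable) target.
  const int archid = libxsmm_get_target_archid();
  const char *arch = libxsmm_get_target_arch();
  std::printf("libxsmm target: %s (id %d)\n", arch, archid);
  libxsmm_finalize();
  return 0;
}
```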

@chelini proposed we use https://github.com/pytorch/cpuinfo, which is what IREE uses and is a generic enough library that it could work. However, it would have the same problem as libxsmm regarding upstreaming. Since we already depend on libxsmm, we should try to use it until we can't.
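For comparison, the cpuinfo route would look roughly like this (`cpuinfo_initialize` and the `cpuinfo_has_*` / `cpuinfo_get_cores_count` queries are the library's documented API; the snippet is only a sketch):

```cpp
#include <cpuinfo.h>
#include <cstdio>

int main() {
  if (!cpuinfo_initialize())
    return 1;
  // Feature queries return false on the "wrong" architecture, so the same
  // code works on both x86 and Arm hosts.
  std::printf("cores: %u, avx2: %d, avx512f: %d, neon: %d\n",
              cpuinfo_get_cores_count(), cpuinfo_has_x86_avx2(),
              cpuinfo_has_x86_avx512f(), cpuinfo_has_arm_neon());
  cpuinfo_deinitialize();
  return 0;
}
```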

But this is just the host-detection part. We also need a way to convey target information (via -cpu flags) and pass it down to the MLIR compiler (for tiling), the LLVM compiler (for vectorization), and the library implementation (for example, MKL) if we intermix compiler-generated code with ninja-written library kernels.

I strongly agree with:

However, it would have the same problem as libxsmm regarding upstreaming. Since we already depend on libxsmm, we should try to use it until we can't.

That means libxsmm should return tiling information etc., as even a split into x86/arm is not sufficient at all: there is AVX2, AVX512 (with one or two FMA units), AMX, and core count to consider. Similar nuances exist on Arm, so until we have code generation capable of matching libxsmm performance, we don't need to cross the bridge of generalization.
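In other words, something like this purely hypothetical libxsmm-backed implementation of the target-description class, where only the two `libxsmm_get_target_*` calls are real API and the class, methods, and numbers are placeholders:

```cpp
#include <array>
#include <cstdint>
#include <string>
#include <libxsmm.h>

// Hypothetical: a libxsmm-backed target description.
class LibxsmmTargetInfo {
public:
  LibxsmmTargetInfo()
      : archId(libxsmm_get_target_archid()),
        archName(libxsmm_get_target_arch()) {}

  // Placeholder blocking. The point here is that libxsmm itself should
  // supply numbers like these per detected target (AVX2 vs. AVX512 with one
  // or two FMA units vs. AMX, core count), instead of the compiler
  // hard-coding them at build time.
  std::array<int64_t, 3> getPreferredPackShape() const { return {32, 32, 32}; }

  const std::string &getArchName() const { return archName; }
  int getArchId() const { return archId; }

private:
  int archId;
  std::string archName;
};
```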