tlc-pack/relax

[Discuss] Relax Layout Transformation

psrivas2 opened this issue · 9 comments

Motivation & Goals

Tensor data layout describes how the data is laid out in memory. It determines the memory access pattern and it can significantly impact performance and memory efficiency. Global layout planning thus becomes an important optimization to achieve performance on various hardware backends. The goal of this document is to present the design of global layout planning in Relax.

Terminology

To classify operators from a layout perspective, TVM has defined the following terms. Note that we do not rely on these terms in this document, but they can be used in discussions.

  • Layout agnostic: relu, pow etc. Data layout affects these operators neither functionally nor performance-wise.
  • Lightly-layout sensitive: pad, concatenate, reduce ops like sum etc. These operators have some attributes that are functionally affected if we do a layout transformation before them. However, the performance difference is not significant. For these operators, it is beneficial to simply adapt to the output data layout of the preceding operator.
  • Heavily-layout sensitive: conv2d, conv2d_transpose etc. Data layout heavily affects these operators, both functionally and performance-wise. They also have data layout as an op attribute. For these operations, the performance benefit of a layout transformation outweighs the runtime cost of performing it. Thus, it is beneficial to modify the input data layouts of these operators (if they are not already in a performant layout), while the remaining layout-agnostic and lightly-layout sensitive operators adapt to the layout governed by the output of these heavily-layout sensitive operators.

We introduce and clarify the following terms used in this doc.

  • Layout-critical operations: These correspond to heavily-layout sensitive operations in existing terminology. These operations are highly performance sensitive to operand/result layouts. For example, convolution operations.
  • Frozen layout: Fixed layout decision for an operator. Operators with frozen layouts seed the layout-flowing mechanism described below. Typically, heavily-layout sensitive operators will have frozen layouts.
  • Flow layout constraints: When layout-agnostic and lightly-layout sensitive operators are connected to operators with frozen layouts, they can take those frozen layouts as constraints and adjust themselves without any noticeable performance impact. We call this process “flowing layout constraints”.
  • Layout rewrite: Explicit conversion for layout when two adjacent operators have different layout constraints.
  • Merge/Reduce/Cancel/Fold layout rewrites: When a layout rewrite reverts its adjacent layout rewrite (e.g., NCHW->NHWC followed by NHWC->NCHW), we can safely eliminate both.

Prior Art

Relay has the ConvertLayout pass, which first applies the user-specified layouts to the operations in the Relay graph. For example, the user may specify that all convolution operations should use the NHWC layout. The ConvertLayout pass then transforms the IRModule so that the user-desired layouts are honored, using a minimal number of data layout transforms.

It achieves this by requiring each Relay operator to have the InferCorrectLayout property, which the pass uses for layout inference. Given the original input layouts and the new input layouts, the InferCorrectLayout property tells how the operator needs to be modified to accommodate the new input layouts, and what the new output layouts should be. Layout transforms are inserted where the new input layouts differ from the incoming layouts. This is done operator by operator in sequence: ConvertLayout keeps passing the new layouts to the next operator's property, eventually modifying the whole graph.

Relay has a long list of InferCorrectLayout methods attached to operators. The logic for a number of operators is quite complex, error-prone and a huge burden to maintain. In some cases, Relay operators cannot be modified to accommodate tiling layout transforms.

Relay also has the AlterOpLayout pass, which conditionally applies additional layout transformations (e.g., the Winograd transformation) depending on the operator, target, etc. Since it mainly serves as a secondary optimization that can be applied on top of global layout planning, we will cover it in a separate design doc.

Relay Example

Let’s take a look at an example. The Relay graph below performs a convolution followed by a reduction across the H axis.

fn (%x: Tensor[(32, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32]) {
  %0 = nn.conv2d(%x, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);
  sum(%0, axis=[2])
}

If we modify the input and output layouts of the convolution operation, layout_transform operations are inserted into the graph to maintain correctness. At the same time, the InferCorrectLayout property of the sum operation changes its axis attribute from axis = [2] to axis = [1] to accommodate the change from NCHW to NHWC.

fn (%x: Tensor[(32, 64, 56, 56), float32] /* ty=Tensor[(32, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(32, 64, 56), float32] {
  %0 = layout_transform(%x, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(32, 56, 56, 64), float32] */;
  %1 = layout_transform(%weight, src_layout="OIHW", dst_layout="HWIO") /* ty=Tensor[(3, 3, 64, 64), float32] */;
  %2 = nn.conv2d(%0, %1, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(32, 56, 56, 64), float32] */;
  %3 = sum(%2, axis=[1]) /* ty=Tensor[(32, 56, 64), float32] */;
  layout_transform(%3, src_layout="NWC", dst_layout="NCW") /* ty=Tensor[(32, 64, 56), float32] */
} /* ty=fn (Tensor[(32, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32]) -> Tensor[(32, 64, 56), float32] */
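The axis rewrite that InferCorrectLayout performs for sum can be sketched in plain Python. This is a simplified illustration built around a hypothetical helper, remap_reduce_axes, and is not Relay's implementation. It assumes keepdims=False and plain single-letter layout strings without tiling (e.g., no NCHW4c).

```python
def remap_reduce_axes(axes, src_layout, dst_layout):
    """Map reduce axes from src_layout to dst_layout and derive the result layout.

    Hypothetical helper: assumes keepdims=False and plain (untiled) layouts.
    """
    reduced = {src_layout[a] for a in axes}  # e.g. axis 2 of "NCHW" -> {"H"}
    new_axes = sorted(dst_layout.index(d) for d in reduced)
    result_layout = "".join(d for d in dst_layout if d not in reduced)
    return new_axes, result_layout


print(remap_reduce_axes([2], "NCHW", "NHWC"))  # ([1], 'NWC')
```

The derived result layout, NWC, is the src_layout of the trailing layout_transform in the transformed graph above. The string-based dimension naming used here is what later makes Relay's approach fragile.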

Global Layout Planning in Relax

In Relax, we perform layout planning very differently from Relay. In Relay, layout rewrites happen at the Relay operator level (e.g., relay.split, relay.sum). In Relax, we apply layout rewrites after lowering Relax operators to TIR. Having access to the underlying compute of an operation in TIR makes it much easier to flow layout constraints in Relax.

Here is a reduction operation in NCHW layout.

@T.prim_func
def sum(input: T.Buffer[(32, 64, 56, 56), "float32"], output: T.Buffer[(32, 64), "float32"]) -> None:
    for i0, i1, i2, i3 in T.grid(32, 64, 56, 56):
        with T.block("reduce"):
            ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3])
            T.reads(input[ax0, ax1, k2, k3])
            T.writes(output[ax0, ax1])
            with T.init():
                output[ax0, ax1] = T.float32(0)
            output[ax0, ax1] = output[ax0, ax1] + input[ax0, ax1, k2, k3]

To transform the above operation from NCHW to NHWC layout, we can simply transform the buffer and block layouts using tir.schedule.transform_layout and tir.schedule.transform_block_layout primitives.

Here is the same reduce operation in NHWC layout.

@T.prim_func
def sum(input: T.Buffer[(32, 56, 56, 64), "float32"], output: T.Buffer[(32, 64), "float32"]) -> None:
    for ax0, ax1, ax2, ax3 in T.grid(32, 56, 56, 64):
        with T.block("reduce"):
            v0, v1, v2, v3 = T.axis.remap("SRRS", [ax0, ax1, ax2, ax3])
            T.reads(input[v0, v1, v2, v3])
            T.writes(output[v0, v3])
            with T.init():
                output[v0, v3] = T.float32(0)
            output[v0, v3] = output[v0, v3] + input[v0, v1, v2, v3]
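The following pure-Python sketch shows why transforming both the buffer and the block preserves semantics. It needs no TVM, and the small shapes were chosen only for illustration. The sketch applies the index map lambda n, c, h, w: (n, h, w, c) to a dict-backed NCHW buffer. It then checks that the "SRRS" reduction over the NHWC buffer matches the original "SSRR" reduction.

```python
import itertools
import math
import random

N, C, H, W = 2, 3, 4, 5  # small illustrative shapes


def to_nhwc(n, c, h, w):
    """IndexMap-style layout rewrite: NCHW -> NHWC."""
    return (n, h, w, c)


def sum_nchw(buf):
    # Block "reduce" before the rewrite: axes remapped as "SSRR".
    out = {(n, c): 0.0 for n in range(N) for c in range(C)}
    for n, c, h, w in itertools.product(range(N), range(C), range(H), range(W)):
        out[n, c] += buf[n, c, h, w]
    return out


def sum_nhwc(buf):
    # Block "reduce" after the rewrite: iteration space permuted, axes "SRRS".
    out = {(n, c): 0.0 for n in range(N) for c in range(C)}
    for v0, v1, v2, v3 in itertools.product(range(N), range(H), range(W), range(C)):
        out[v0, v3] += buf[v0, v1, v2, v3]
    return out


random.seed(0)
x_nchw = {idx: random.random()
          for idx in itertools.product(range(N), range(C), range(H), range(W))}
# "transform_layout" on the input buffer: every element moves to its new coordinate.
x_nhwc = {to_nhwc(*idx): v for idx, v in x_nchw.items()}

ref, new = sum_nchw(x_nchw), sum_nhwc(x_nhwc)
print(all(math.isclose(ref[k], new[k]) for k in ref))  # True
```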

In an ideal scenario, layout planning would figure out the most efficient layout for all operations to minimize the end-to-end cost of the graph. However, we are not there yet. For now, we need help from either the user or an auto-tuning system like Metaschedule to identify layout-critical operations in the graph and greedily choose the best layout for them. This results in a graph where the layouts of some operations are frozen (i.e., must not be modified by other passes), and layout rewrite operations are added on the operand and result edges of these layout-critical operations. Next, an optimization pass flows these layout rewrites through the graph to reduce the overall cost of such copy operations.

Thus, in Relax we break layout planning down into two subproblems, which are described separately below.

  • Tune layout-critical operations and freeze layouts. A pass identifies layout-critical operations and greedily finds the best schedule for them. It marks these operations as having “frozen layouts” and adds layout rewrite operations on operands and results to maintain correctness. Such a pass could work with tuning frameworks like Metaschedule to find the best schedule for layout-critical operations.
  • Minimize layout rewrite cost in graph. An optimization pass to reduce the cost of layout rewrites in the graph. This would be achieved by flowing layout rewrites in the graph to facilitate fusing/folding them with other layout rewrites, operators, or constants.

Tune Layout-Critical Operations & Freeze Layouts

This can be done in conjunction with tuning frameworks like Metaschedule, or by a pass that allows the user to annotate the layouts of specific operations.

  • Using Metaschedule: Metaschedule could be used to greedily find the best schedule for layout-critical operations. It would add layout rewrite blocks in the prologue and epilogue of the TIR function for the operation. A HoistLayoutRewritePass (discussed in the appendix) would be added as part of this proposal. It would pull the layout rewrites out into separate TIR functions and mark the layout-critical TIR function with a “frozen layout” attribute.
  • User-annotated layout: A pass could take a user specification of preferred operand/result layouts for Relax operations. Where these layouts conflict with the existing ones, the pass would insert layout rewrite operations on the operand and result edges of these operations. It would also mark these operations as having frozen layouts. Such a pass would be implemented as part of layout planning in Relax. However, it is out of scope for this design doc, which focuses on minimizing the cost of layout rewrites in the graph. A separate design document will discuss the API and strategy for such a pass.
  • Other: Any other strategy that tunes specific operations and freezes their layouts is also valid. Such a strategy is expected to add the layout rewrite operations itself. It can then use the layout planning pass to minimize the cost of layout rewrites in the graph.

Minimize Layout Rewrite Cost in Graph

The minimize layout rewrite cost problem can be formulated as below:

Given a directed acyclic graph G(V, E), where each $v \in V$ is an operation and each $e \in E$ is a data value flowing through G, a subset of operations is marked as having frozen layouts. That is, the operand and result layouts of these operations cannot be modified. The graph contains layout rewrite operations that satisfy the layout constraints of operations, constants, graph inputs and outputs. The goal of layout planning is to minimize the cost of the layout rewrite operations in the graph.
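The same goal can be written compactly as an optimization problem. The notation here is introduced for illustration, and the document does not fix a cost model:
  • $R(G')$ is the set of layout rewrite operations in a candidate graph $G'$.
  • $c(r)$ is the runtime cost of rewrite $r$.
  • $G' \equiv G$ means $G'$ computes the same result as $G$.
  • $\mathcal{L}_{G}(v)$ is the tuple of operand and result layouts of $v$ in $G$.
  • $F \subseteq V$ is the set of operations with frozen layouts.

$$
\min_{G'} \sum_{r \in R(G')} c(r)
\quad \text{s.t.} \quad
G' \equiv G, \qquad
\mathcal{L}_{G'}(v) = \mathcal{L}_{G}(v) \;\; \forall v \in F
$$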

Layout Representation

Layout rewrites in Relax can be represented in either of the following two ways.

  • PrimFunc Representation. Layout rewrites are represented as TIR functions in the IRModule. This makes these constraints easy to serialize and deserialize without introducing any new structure in the IR. Note that not every TIR function that performs a spatial transformation of a buffer is a candidate that can flow through operations and be optimized away. Only PrimFuncs whose transformation (or its inverse) can be expressed as a compact IndexMap are candidates for flowing and optimization. IndexMap is a compact representation for layout rewrites, and it is also the layout representation used by the tir.schedule.transform_layout API. For example, the following is a layout rewrite (NCHW to NCHW4c) in Relax:

    @T.prim_func
    def layout_transform(arg: T.Buffer[(32, 64, 224, 224), "float32"], out: T.Buffer[(32, 16, 224, 224, 4), "float32"]) -> None:
        for i0, i1, i2, i3, i4 in T.grid(32, 16, 224, 224, 4):
            with T.block("T_layout_trans"):
                ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                T.reads(arg[ax0, ax1 * 4 + ax4, ax2, ax3])
                T.writes(out[ax0, ax1, ax2, ax3, ax4])
                out[ax0, ax1, ax2, ax3, ax4] = arg[ax0, ax1 * 4 + ax4, ax2, ax3]
  • Compact Layout Representation: While the PrimFunc representation is very useful in representing layout rewrites generally, it could be difficult to flow/fuse them in their general form. To flow layout rewrites through an operation, a compact representation of layout rewrite is very useful as it could be applied to the operand/result buffers. The compact representation could (1) make it easier to fuse/fold/cancel layout rewrites, and (2) easily be transformed to tir.schedule primitives to actually modify layouts of TIR buffers. We propose to have the compact form of layout rewrites be represented through Relax operations.

    • Relax operation representation. The relax operation representation of layout rewrite would use the following relax operations.
      • relax.layout_transform(input: Tensor, index_map: lambda): This maps the input to a new iteration space. The index_map defines the mapping function. This is a pure layout transform, i.e., index_map is a bijective function.
      • relax.pad(input: Tensor, pad_width: Tensor, pad_value: scalar): Inserts pad_value at the locations given by pad_width. pad_width is an integer tensor with shape [n, 2], where n is the rank of input. For each dimension d of input, pad_width[d, 0] indicates how many values to add before the contents of input in that dimension, and pad_width[d, 1] indicates how many values to add after the contents of input in that dimension.
      • relax.crop(input: Tensor, start_indices, slice_sizes, cropped_value): Crops the region of the tensor specified by start_indices and slice_sizes. The optional cropped_value is a hint to the compiler about the values stored in the regions of the input tensor that were cropped away. The compiler needs this information if it intends to cancel this relax.crop with a following relax.pad.
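The pad and crop semantics described above can be sketched on nested Python lists. This is an illustration only. The functions are named after the proposed operators but are not the Relax API. pad_width is a list of (before, after) pairs rather than an [n, 2] tensor, and cropped_value is not modeled.

```python
def shape(x):
    """Shape of a nested-list tensor (scalars have shape [])."""
    return [len(x)] + shape(x[0]) if isinstance(x, list) else []


def full(shp, value):
    """Nested list of the given shape filled with value."""
    return [full(shp[1:], value) for _ in range(shp[0])] if shp else value


def pad(x, pad_width, pad_value):
    """pad_width[d] = (before, after): values to insert before/after dim d."""
    if not pad_width:
        return x
    before, after = pad_width[0]
    inner = [pad(e, pad_width[1:], pad_value) for e in x]
    sub = shape(inner[0])  # assumes a non-empty dimension
    return ([full(sub, pad_value) for _ in range(before)] + inner
            + [full(sub, pad_value) for _ in range(after)])


def crop(x, start_indices, slice_sizes):
    """Keep slice_sizes[d] elements starting at start_indices[d] in every dim d."""
    if not start_indices:
        return x
    s, n = start_indices[0], slice_sizes[0]
    return [crop(e, start_indices[1:], slice_sizes[1:]) for e in x[s:s + n]]


x = [[1, 2, 3], [4, 5, 6]]               # shape (2, 3)
padded = pad(x, [(0, 0), (0, 1)], 0)     # round dim 1 up to a multiple of 4
print(padded)                            # [[1, 2, 3, 0], [4, 5, 6, 0]]
print(crop(padded, [0, 0], [2, 3]) == x) # True
```

Padding dimension 1 up to a multiple of 4 is the step a non-bijective rewrite such as NCHW to NCHW4c needs when C is not divisible by 4. Cropping the same region afterwards recovers the original tensor.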

It should be easy to roundtrip between the PrimFunc representation and compact relax operation representation of layout rewrites. There are some valid questions/concerns that we aim to address on the choice of compact representation of layout rewrites:

  1. Since relax.layout_transform only supports bijective transformation, would general PrimFunc layout rewrites (with implicit padding/cropping) have to be broken into smaller primitive TIR blocks in the input IRModule (i.e., Does the layout planner expect prior passes to break general PrimFunc layout rewrites into smaller TIR blocks?)

    Answer: No. The input IRModule to the layout planning pass could use the general PrimFunc representation. We can break such PrimFunc blocks into primitive pad/crop and bijective layout rewrites when converting to the compact form. Previous passes are therefore not required to emit only bijective layout rewrite, pad, or crop TIR blocks. Furthermore, layout rewrite TIR blocks within a PrimFunc could be annotated with their compact representations (perhaps using the IndexMap representation). HoistLayoutRewritePass could then use these annotations instead of recovering the compact form from the TIR block.

  2. Is it necessary to break the layout rewrites into these primitive relax operations (bijective transformation, pad, and crop) for easier cancellation? Can we fuse two layout rewrites in the PrimFunc representation and prove that it is identity?

    Answer: Probably not. Suppose TIR had an analysis that could prove a PrimFunc is the identity for most of our use cases. That would reduce the need for these primitive representations. A compact representation (a relax.layout_transform that supports non-bijective rewrites) would still be needed, but we would not need to break general layout rewrites down into these primitives.
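A brute-force sketch illustrates the kind of identity check raised in question 2. It composes a forward index map with a candidate inverse and tests the result on every index of a small concrete shape. A real TIR analysis would need to reason symbolically over arbitrary shapes. The function names here are hypothetical.

```python
import itertools


def to_nchw4c(n, c, h, w):
    """Forward layout rewrite NCHW -> NCHW4c (a bijection when C % 4 == 0)."""
    return (n, c // 4, h, w, c % 4)


def to_nchw(n, c_o, h, w, c_i):
    """Candidate inverse NCHW4c -> NCHW."""
    return (n, c_o * 4 + c_i, h, w)


def composes_to_identity(f, g, shape):
    """Check g(f(idx)) == idx for every index of a concrete, small shape.

    This is exhaustive testing on one shape, not a symbolic proof.
    """
    domain = itertools.product(*(range(s) for s in shape))
    return all(g(*f(*idx)) == idx for idx in domain)


print(composes_to_identity(to_nchw4c, to_nchw, (1, 8, 2, 2)))  # True
```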

Example

To minimize the cost of layout rewrites in the graph, a pass can flow them across operations until they can be fused into a constant, cancel out, or fuse with other layout rewrites. For instance, the following graph (which we will use as a running example) is a conv-->add-->conv graph. The convolution operations have “frozen layouts” (marked in orange). The layout rewrite operations are marked in green.

graph LR
    X((x)) --NCHW--> Transform_x(to NCHWc) --NCHWc--> Conv(conv)
		Conv --NCHWc--> Transform_r(to NCHW) --NCHW--> Add(add)
    F((f)) --OIHW--> Transform_f(to OIHWio) --OIHWio--> Conv
		Bias((bias)) -- NCHW --> Add -- NCHW --> Transform_o(to NCHWc) --NCHWc--> Conv2(conv)
		Conv2 --NCHWc--> Transform_conv(to NCHW) --NCHW--> output((result))
		subgraph _
			Transform_r
			Bias
			Transform_o
			Add
		end
		classDef orange fill:#f96,stroke:#333,stroke-width:4px;
		classDef green fill:#16b522,color:#fff
		classDef green_highlight fill:#16b522,stroke:#f96,stroke-width:4px,color:#fff
		class Conv orange
		class Conv2 orange
		class Transform_x green
		class Transform_r green
		class Transform_f green
		class Transform_o green_highlight
		class Transform_conv green
graph LR
    X((x)) --NCHW--> Transform_x(to NCHWc) --NCHWc--> Conv(conv)
		Conv --NCHWc--> Transform_r(to NCHW) --NCHW--> Transform_rinv(to NCHWc)--NCHWc--> Add(add)
    F((f)) --OIHW--> Transform_f(to OIHWio) --OIHWio--> Conv
		Bias((bias)) -- NCHW --> Transform_o(to NCHWc) --NCHWc--> Add --NCHWc--> Conv2(conv)
		Conv2 --NCHWc--> Transform_conv(to NCHW) --NCHW--> output((result))
		subgraph _
				Transform_r
				Transform_rinv
				Bias
				Transform_o
			Add
		end
		classDef orange fill:#f96,stroke:#333,stroke-width:4px;
		classDef green fill:#16b522,color:#fff
		classDef green_highlight fill:#16b522,stroke:#f96,stroke-width:4px,color:#fff

		class Conv orange
		class Conv2 orange
		class Transform_x green
		class Transform_r green
		class Transform_f green
		class Transform_o green_highlight
		class Transform_conv green
		class Transform_rinv green_highlight
graph LR
    X((x)) --NCHW--> Transform_x(to NCHWc) --NCHWc--> Conv(conv)
		Conv --NCHWc---> Add(add)
    F((f)) --OIHW--> Transform_f(to OIHWio) --OIHWio--> Conv
		Bias((bias*)) --NCHWc--> Add --NCHWc--> Conv2(conv)
		Conv2 --NCHWc--> Transform_conv(to NCHW) --NCHW--> output((result))
		subgraph _
			Bias
			Add
		end
		classDef orange fill:#f96,stroke:#333,stroke-width:4px;
		classDef green fill:#16b522,color:#fff
		class Conv orange
		class Conv2 orange
		class Transform_x green
		class Transform_r green
		class Transform_f green
		class Transform_o green
		class Transform_conv green
		class Transform_rinv green

The number of layout rewrites can be reduced in two steps. First, flow the “to NCHWc” layout rewrite across the add operation, from result to operands. Then simplify the graph by folding adjacent layout rewrites (in this case, into an identity) and fusing the layout rewrite into a constant.
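Fusing a layout rewrite into a constant amounts to applying the rewrite to the constant's data at compile time. The following sketch does this with a hypothetical helper, fold_into_constant. The running example does not fix the bias shape, so the (C, 1, 1) shape and C = 8 are assumptions. The sketch uses the bias rewrite lambda i, j, k: (i // 4, 0, 0, i % 4), which is the form the broadcast example in this document derives for such a bias.

```python
import itertools


def fold_into_constant(data, shape, index_map):
    """Apply a layout rewrite to constant data at compile time.

    data maps index tuples to values; index_map must be injective.
    """
    out = {}
    for idx in itertools.product(*(range(s) for s in shape)):
        new_idx = index_map(*idx)
        if new_idx in out:
            raise ValueError(f"index map is not injective at {idx}")
        out[new_idx] = data[idx]
    return out


C = 8
bias = {(c, 0, 0): float(c) for c in range(C)}  # a (C, 1, 1) bias constant
folded = fold_into_constant(bias, (C, 1, 1), lambda i, j, k: (i // 4, 0, 0, i % 4))
print(folded[(1, 0, 0, 2)])  # 6.0: channel 6 lands in block 1, lane 2
```

After folding, the rewrite disappears from the graph and the constant is simply stored in the NCHWc-compatible layout (2, 1, 1, 4).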

Folding/Fusing Layout Rewrites

Layout rewrites can be reduced through the following mechanisms:

  • Folding adjacent layout rewrites: Adjacent layout rewrites can be folded into each other. In terms of relax operation representation, this would mean:
    • two relax.layout_transform operations being folded by composing their index maps
    • a relax.pad operation being folded with a relax.crop operation. Folding these two ops is subtle and must obey the following rules, assuming the indices being padded and cropped in the two operations are the same.
      • a relax.crop operation can always be folded into a prior relax.pad operation.
      • a relax.pad operation can only be folded into a prior relax.crop operation if all the cropped values are T.undef or same as the values being padded.
  • Folding a layout rewrite into a constant: A layout rewrite on the output edge of a constant can be folded into the constant itself. For example, the “to NCHWc” layout rewrite was folded into the bias constant in the example above.
  • Folding a layout rewrite into an operation: Layout rewrites can also be folded into an operation, i.e., the layout of one buffer is modified while the rest of the operands/results are left as is. However, it is unclear whether this leads to a more optimized graph. In many cases, it would lead to inefficient access patterns in the loops for operands and results (see appendix). For such cases, we can leave the operation + layout rewrite as is in the graph, to be optimized later by a tuning pass. The tuning pass can then fold the layout rewrite into the operation or materialize it, as it sees fit.
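The folding rules above can be summarized as a small peephole function. In this sketch the op classes are hypothetical stand-ins for relax.layout_transform, relax.pad and relax.crop. Regions are written as (before, after) element counts for both pad and crop, which makes the "same indices" assumption explicit. UNDEF stands in for T.undef.

```python
from dataclasses import dataclass
from typing import Callable, Tuple

UNDEF = object()  # stands in for T.undef: cropped values the program never reads

Region = Tuple[Tuple[int, int], ...]  # (before, after) element counts per dimension


@dataclass(frozen=True)
class LayoutTransform:
    index_map: Callable


@dataclass(frozen=True)
class Pad:
    region: Region      # elements added before/after each dimension
    pad_value: float


@dataclass(frozen=True)
class Crop:
    region: Region                 # elements removed before/after each dimension
    cropped_value: object = None   # None: nothing is known about the cropped values


IDENTITY = "identity"


def fold(first, second):
    """Fold `second` (applied after `first`); return None when the rules forbid it."""
    if isinstance(first, LayoutTransform) and isinstance(second, LayoutTransform):
        f, g = first.index_map, second.index_map
        return LayoutTransform(lambda *idx: g(*f(*idx)))  # compose the index maps
    if isinstance(first, Pad) and isinstance(second, Crop) and first.region == second.region:
        return IDENTITY  # a crop after a matching pad always cancels
    if isinstance(first, Crop) and isinstance(second, Pad) and first.region == second.region:
        cv = first.cropped_value
        if cv is UNDEF or (cv is not None and cv == second.pad_value):
            return IDENTITY  # re-padding restores values nobody relies on, or equal ones
    return None
```

For example, fold(Crop(r), Pad(r, 0.0)) returns None because nothing is known about the cropped values. With cropped_value=UNDEF, or with cropped_value=0.0, the same pair cancels.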

Direction of Flowing Layout Rewrites

To facilitate folding layout rewrites into other layout rewrites and constants, we need to be able to flow layout rewrites across operations. In the example above, the layout rewrite after the add operation was flowed through it (from result to operands). This allowed it to be folded into an inverse layout rewrite and a constant.

The flow of layout rewrites presents a design choice on the direction of flow. We considered the following two options.

  • F0: Operands to Results (Forward). Flow layout rewrite from an operand to results.
    • Pros
      • Operations with a single operand and multiple results (for example, split) would not mess up the access patterns of the output buffers.
    • Cons
      • If an operation has multiple operands, flowing the rewrites through it could mess up the access patterns of the other operands.
      • We cannot fold layout rewrites into constants in this flow, as constants are operations with zero inputs and one output.
  • F1: Results to Operands (Backward). Flow layout rewrites from a result to operands.
    • Pros
      • In the ML domain, many operations fall into the category of multiple operands and a single result, so this approach is unlikely to mess up the access patterns of other buffers.
      • We can expect that in this flow, some of the layout rewrites could be folded into constants.
    • Cons
      • If an operation has multiple results, this direction of flow can mess up the access pattern of other outputs.

Neither choice is strictly better than the other, but the arguments for F1 apply to a larger class of operations. For that reason, we prefer F1.

Flowing Layout Rewrite through an Operation

For generic support of flowing layout rewrites, we propose a flow mechanism at each of two abstraction levels: TIR and graph. An IRModule in the compilation pipeline may contain a mixture of Relax operators and TIR functions. With these two mechanisms, we believe our approach can support any such IRModule.

Flowing Layout Rewrite through an Operation at TIR-level

This section would answer the question, “How to generate layout rewrite on the operand buffers of an operation, when flowing layout rewrite through it?”

Analyzing the TIR block(s) that write to the result buffer can help us identify the layout rewrites for the operands.

Let’s say the TIR PrimFunc has one result (output) and multiple operands (arg0, arg1, …). A layout rewrite lambda N, C, H, W: (N, H, W, C // 4, C % 4) is applied on the output buffer. Our goal is to identify the layout rewrites on all of the operands.

  1. Find the block B that writes to output.
  2. Inspect the T.reads and T.writes signatures on B. For the output buffer, the access would be output[ax0, ax1, ax2, ax3]. Using the mapping N = ax0, C = ax1, H = ax2, W = ax3, apply the layout rewrite to each of the buffers in the T.reads accesses.

If the TIR PrimFunc allocates temporary buffers, we might have to flow the layout rewrites through multiple blocks to reach the operands.

The following examples show how this is done for broad classes of operations.

Elementwise Operations.

@T.prim_func
def relu(input: T.Buffer[(32, 3, 224, 224), "float32"], output: T.Buffer[(32, 3, 224, 224), "float32"]) -> None:
    for i0, i1, i2, i3 in T.grid(32, 3, 224, 224):
        with T.block("compute"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(input[i0_1, i1_1, i2_1, i3_1])
            T.writes(output[i0_1, i1_1, i2_1, i3_1])
            output[i0_1, i1_1, i2_1, i3_1] = T.max(input[i0_1, i1_1, i2_1, i3_1], T.float32(0))

In the above PrimFunc, let’s say the output has a layout rewrite lambda N, C, H, W: (N, C // 4, H, W, C % 4). We want to flow this rewrite from output to input. The output buffer is written in block compute. The block signature records the buffers it reads and writes: here it writes output[i0_1, i1_1, i2_1, i3_1] and reads input[i0_1, i1_1, i2_1, i3_1]. Since the accesses to both buffers are identical, the layout rewrites are identical too. Thus, as layout planning flows the rewrite operation through the relu op, it modifies both the input and output buffers with the same layout rewrite.

Broadcast Operations.

@T.prim_func
def add(input: T.Buffer[(32, 256, 213, 213), "float32"], bias: T.Buffer[(256, 1, 1), "float32"], output: T.Buffer[(32, 256, 213, 213), "float32"]) -> None:
    for i0, i1, i2, i3 in T.grid(32, 256, 213, 213):
        with T.block("T_add"):
            ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(input[ax0, ax1, ax2, ax3], bias[ax1, 0, 0])
            T.writes(output[ax0, ax1, ax2, ax3])
            output[ax0, ax1, ax2, ax3] = input[ax0, ax1, ax2, ax3] + bias[ax1, 0, 0]

Let’s say the output has a layout rewrite operation lambda N, C, H, W: (N, C // 4, H, W, C % 4). Our goal is to flow the layout rewrite operation from result to operands. The output buffer is written in block T_add. The block signature has the access patterns output[ax0, ax1, ax2, ax3], input[ax0, ax1, ax2, ax3] and bias[ax1, 0, 0]. Mapping C to ax1 and applying the layout rewrite gives the following layout rewrites for the input and bias buffers:

  • input: lambda N, C, H, W: (N, C // 4, H, W, C%4)
  • bias: lambda i, j, k: (i // 4, 0, 0, i % 4)

Reduction Operations.

@T.prim_func
def sum(input: T.Buffer[(32, 256, 213, 213), "float32"], output: T.Buffer[(32, 256), "float32"]) -> None:
    for i0, i1, i2, i3 in T.grid(32, 256, 213, 213):
        with T.block("rxplaceholder_red"):
            ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3])
            T.reads(input[ax0, ax1, k2, k3])
            T.writes(output[ax0, ax1])
            with T.init():
                output[ax0, ax1] = T.float32(0)
            output[ax0, ax1] = output[ax0, ax1] + input[ax0, ax1, k2, k3]

Let’s say the output has a layout rewrite operation lambda N, C: (N, C // 4, C % 4). Our goal is to flow the layout rewrite operation from result to operand. The output buffer is written in block rxplaceholder_red. The block signature has the access patterns output[ax0, ax1] and input[ax0, ax1, k2, k3]. Mapping C to ax1 and applying the layout rewrite gives the layout rewrite for the input buffer: lambda i, j, k, l: (i, j // 4, k, l, j % 4).
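The three derivations above follow a mechanical pattern. The sketch below reproduces them in plain Python from the block access patterns alone. It is a simplified heuristic, not the proposed implementation, and it makes three assumptions:
  • Rewrites are limited to per-dimension identity, floor-division and modulo terms.
  • Accesses must be plain block variables or constants.
  • Operand dimensions not indexed by the output (reduction variables, broadcast constants) are placed just before the inner tile dimensions. This placement rule was chosen to match the examples in this document.

```python
from typing import List, Tuple, Union

Term = Tuple[int, str, int]  # (output dim, op in {"id", "div", "mod"}, factor)
Index = Union[str, int]      # block variable name, or an integer constant


def _fmt(expr: str, op: str, factor: int) -> str:
    if op == "id":
        return expr
    return f"{expr} {'//' if op == 'div' else '%'} {factor}"


def derive_operand_rewrite(rewrite: List[Term], out_access: List[str],
                           operand_access: List[Index]) -> str:
    """Derive an operand's layout rewrite from the result's rewrite (heuristic sketch)."""
    names = "ijklmnop"[: len(operand_access)]
    var_pos = {v: p for p, v in enumerate(operand_access) if isinstance(v, str)}
    outer, inner, used = [], [], set()
    for dim, op, factor in rewrite:
        var = out_access[dim]        # e.g. C -> ax1
        if var not in var_pos:       # operand is broadcast along this dim: drop it
            continue
        p = var_pos[var]
        used.add(p)
        (inner if op == "mod" else outer).append(_fmt(names[p], op, factor))
    # Operand dims not indexed by output variables (reduction variables or
    # broadcast constants) keep their order and go before the inner tile dims.
    rest = [str(a) if isinstance(a, int) else names[p]
            for p, a in enumerate(operand_access) if p not in used]
    return f"lambda {', '.join(names)}: ({', '.join(outer + rest + inner)})"


NCHW4c = [(0, "id", 1), (1, "div", 4), (2, "id", 1), (3, "id", 1), (1, "mod", 4)]
out4 = ["ax0", "ax1", "ax2", "ax3"]
print(derive_operand_rewrite(NCHW4c, out4, ["ax1", 0, 0]))
# lambda i, j, k: (i // 4, 0, 0, i % 4)
```

Passing the elementwise, broadcast (bias) and reduction access patterns reproduces the rewrites derived in the text.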

Fused Operations.

Fused operations can add the complication of a series of blocks. We can use dependency analysis on buffers, based on block signatures, to propagate the layout rewrites from result to operands.
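The dependency walk for a fused PrimFunc can be sketched as follows. The sketch is a backward traversal from the result buffer, over block signatures, to the PrimFunc operands. At each visited intermediate buffer, the per-block derivation shown for single blocks would be applied. The fused add + relu body and the helper name are hypothetical. The sketch assumes a chain of blocks and does not reconcile conflicting rewrites for a buffer read by several blocks.

```python
from collections import deque

# Hypothetical fused add + relu: block name -> (buffer it writes, buffers it reads).
blocks = {
    "T_add": ("tmp", ["input", "bias"]),
    "compute": ("output", ["tmp"]),
}


def backward_buffer_order(blocks, result):
    """Visit buffers from the result back to the PrimFunc operands.

    Assumes a chain of blocks: a buffer read by several blocks would need its
    candidate rewrites reconciled, which this sketch does not do.
    """
    writer = {written: name for name, (written, _) in blocks.items()}
    order, seen, queue = [], set(), deque([result])
    while queue:
        buf = queue.popleft()
        if buf in seen:
            continue
        seen.add(buf)
        order.append(buf)
        if buf in writer:  # produced inside the PrimFunc: keep walking backward
            queue.extend(blocks[writer[buf]][1])
    operands = [b for b in order if b not in writer]
    return order, operands


print(backward_buffer_order(blocks, "output"))
# (['output', 'tmp', 'input', 'bias'], ['input', 'bias'])
```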

Flowing Layout Rewrite through an Operation at Graph-level

Oftentimes, layout rewrites need to flow at the graph level before lowering to TIR. BYOC is a good example: (1) BYOC may have certain layout constraints, and (2) BYOC codegen works at the graph level by converting each Relax op to its BYOC equivalent. Relay solves this problem by introducing InferCorrectLayout to provide manual guidance. Unfortunately, it has been a source of many tricky issues (e.g., apache/tvm#10118, apache/tvm#10156, apache/tvm#12007).

To overcome this problem, we propose the following graph-level flow mechanism. It peeks at the PrimFunc implementation while leveraging the powerful TIR-level analysis:

  1. Load the simplest PrimFunc implementation for an operator
  2. Flow the constraints at the TIR level. In other words, transform the PrimFunc to conform to the given layout.
  3. Update the operator based on the transformed PrimFunc

Step 3 might be tricky with the current PrimFunc design, since we lose convenient access to op-level information (e.g., operator name, attributes) at the TIR level. To be clear, a PrimFunc implicitly embeds this information, but it may not be easy to extract. Based on our investigation, it seems possible to extend PrimFunc to maintain the op-level information in an accessible way. The investigation also suggests that we might be able to preserve this information when flowing transpose-like layout rewrites through such PrimFuncs. Note that the BYOC use case only requires transpose-like layout rewrites (e.g., NCHW to NHWC). These are easier to support than layout rewrites with implicit tiling/padding (e.g., NCHW to NCHWc).

Fallback: Layout Rewrite Transformation Callback

Although the above sections cover most scenarios, in some cases the user might want explicit control over how a Relax operator or PrimFunc is modified when flowing layout rewrites through it. For example, the PrimFunc could contain an opaque computation, making it hard to derive the operand layout rewrites from the result layout rewrites. For such cases, we would provide an easy way to register a callback function on an operation. Similar to the FTVMConvertLayout property in Relay, it would allow user-defined alterations to the operation when flowing layout rewrites. When registered, the callback would be used instead of the analyses described in the previous sections.

Advantage of Relax Layout Planning over Relay

The approach described in this doc has multiple advantages over layout planning in Relay.

  • Robustness: Unlike Relay, we do not need to write and maintain code that tells the pass how to flow layout rewrites through each operation. The InferCorrectLayout property currently has many lines of complex logic that is error-prone and a huge burden to maintain. The new layout planning pass needs none of this code, which is a huge win. The strategy described above is much more robust and can handle a broad class of PrimFuncs, even those that do not correspond to operations in the Relax operator system.
  • Support for more general layout transformations: Relay uses string representations of layout transformations, with implicit meaning attached to dimension symbols (H represents height, W represents width, etc.). This results in code that can easily break. Relax layout planning does not use strings for layout representation. It thus avoids these pitfalls of Relay and remains robust to general layout transformations.
  • Easier operator registration: In Relay, users have to manually define InferCorrectLayout for their new operators, which is not necessarily easy to figure out. However, as Relax layout planning does not rely on this, operator registration becomes much easier.
  • Cleaner pipeline, improved debuggability and customizability: In Relay, ConvertLayout and OpStrategy are tightly coupled. The flow starts from user annotation in ConvertLayout, and then during lowering, OpStrategy finds the right implementation using the layout information. OpStrategy is a complicated, inflexible component that lives outside the pass infra, which often makes layout optimization hard to debug and customize. In Relax, on the other hand, the end-to-end flow can live in the pass infra, which makes it easier to debug and customize. You can perform layout planning at any stage in the pipeline before codegen, even on a partially lowered IRModule.

Appendix

Hoist Layout Rewrites Pass

This pass inspects a TIR function and lifts any layout rewrite blocks that occur before or after the computation into separate TIR functions. It can identify layout rewrite blocks within a PrimFunc either by analyzing the blocks or by relying on explicit block annotations.
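As a rough illustration of the annotation-based variant, the sketch below models a PrimFunc body as a flat list of blocks and splits off the leading and trailing blocks that carry the `meta_schedule.layout_rewrite_preproc` attribute. The `Block` class and `hoist_layout_rewrites` function are hypothetical stand-ins rather than TVM data structures, and reusing the same attribute for trailing blocks is an assumption; a real pass would operate on TIR and also rewrite the buffer signatures.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

# Attribute key taken from the example PrimFunc; reusing it for trailing
# (post-compute) rewrite blocks is an assumption made for this sketch.
LAYOUT_REWRITE_ATTR = "meta_schedule.layout_rewrite_preproc"


@dataclass
class Block:
    """Hypothetical stand-in for a TIR block: just a name and its annotations."""
    name: str
    attrs: Dict[str, bool] = field(default_factory=dict)

    def is_layout_rewrite(self) -> bool:
        return self.attrs.get(LAYOUT_REWRITE_ATTR, False)


def hoist_layout_rewrites(
    body: List[Block],
) -> Tuple[List[Block], List[Block], List[Block]]:
    """Split a body into (leading rewrites, compute, trailing rewrites).

    Only rewrites at the very start or end are hoisted; an annotated block
    sandwiched between compute blocks stays where it is.
    """
    start = 0
    while start < len(body) and body[start].is_layout_rewrite():
        start += 1
    end = len(body)
    while end > start and body[end - 1].is_layout_rewrite():
        end -= 1
    return body[:start], body[start:end], body[end:]


if __name__ == "__main__":
    body = [
        Block("layout_rewrite", {LAYOUT_REWRITE_ATTR: True}),
        Block("matmul"),
    ]
    pre, compute, post = hoist_layout_rewrites(body)
    print([b.name for b in pre], [b.name for b in compute], [b.name for b in post])
```

Each block in the leading and trailing groups would then become its own PrimFunc, called through `call_tir` before or after the remaining computation.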

For example, the layout_rewrite block in the matmul PrimFunc below would be lifted out into a separate PrimFunc after this pass.

@R.function
def main(x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32")) -> Tensor((16, 16), "float32"):
    gv0 = R.call_tir(matmul, (x, w), (16, 16), dtype="float32")
    return gv0

@T.prim_func
def matmul(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
    B_ = T.alloc_buffer([16, 4, 4], dtype="float32")
    for i0_o, i1_o in T.grid(16, 16):
        with T.block("layout_rewrite"):
            i0, i1 = T.axis.remap("SS", [i0_o, i1_o])
            # optional attribute annotation to identify layout rewrite blocks
            T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
            B_[i1, i0 // 4, i0 % 4] = B[i0, i1]
    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
        with T.block("matmul"):
            vi = T.axis.spatial(16, i0 * 4 + i1)
            vj = T.axis.spatial(16, j)
            vk = T.axis.reduce(16, k0 * 4 + k1)
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B_[vj, vk // 4, vk % 4]

After the pass, the module becomes:

@R.function
def main(x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32")) -> Tensor((16, 16), "float32"):
    gv = R.call_tir(layout_rewrite, (w,), (16, 4, 4), dtype="float32")
    gv0 = R.call_tir(matmul, (x, gv), (16, 16), dtype="float32")
    return gv0

@T.prim_func
def matmul(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 4, 4), "float32"], C: T.Buffer[(16, 16), "float32"]):
    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
        with T.block("matmul"):
            vi = T.axis.spatial(16, i0 * 4 + i1)
            vj = T.axis.spatial(16, j)
            vk = T.axis.reduce(16, k0 * 4 + k1)
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk // 4, vk % 4]

@T.prim_func
def layout_rewrite(SRC: T.Buffer[(16, 16), "float32"], TGT: T.Buffer[(16, 4, 4), "float32"]):
    for i, j in T.grid(16, 16):
        with T.block("layout_rewrite"):
            i0, i1 = T.axis.remap("SS", [i, j])
            TGT[i1, i0 // 4, i0 % 4] = SRC[i0, i1]
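One way to sanity-check that hoisting preserves semantics is to emulate the hoisted layout_rewrite and matmul PrimFuncs above in plain Python and compare the result against a direct matrix product. This is a hand-written emulation of the loop nests with a simplified loop order (the names `layout_rewrite_ref`, `matmul_ref`, and `matmul_plain` are ours), not TVM-generated code.

```python
from typing import List

Matrix = List[List[float]]
Tiled = List[List[List[float]]]


def layout_rewrite_ref(src: Matrix) -> Tiled:
    """Emulates layout_rewrite: TGT[i1, i0 // 4, i0 % 4] = SRC[i0, i1]."""
    tgt = [[[0.0] * 4 for _ in range(4)] for _ in range(16)]
    for i0 in range(16):
        for i1 in range(16):
            tgt[i1][i0 // 4][i0 % 4] = src[i0][i1]
    return tgt


def matmul_ref(a: Matrix, b_tiled: Tiled) -> Matrix:
    """Emulates the rewritten matmul, which reads B in the (16, 4, 4) layout."""
    c = [[0.0] * 16 for _ in range(16)]
    for vi in range(16):
        for vj in range(16):
            for vk in range(16):
                c[vi][vj] += a[vi][vk] * b_tiled[vj][vk // 4][vk % 4]
    return c


def matmul_plain(a: Matrix, b: Matrix) -> Matrix:
    """Original semantics: C = A @ B with B in its row-major (16, 16) layout."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(16)) for j in range(16)]
        for i in range(16)
    ]


if __name__ == "__main__":
    # Small integer-valued inputs keep float arithmetic exact.
    a = [[float((i * 3 + k) % 7) for k in range(16)] for i in range(16)]
    b = [[float((k * 5 + j) % 11) for j in range(16)] for k in range(16)]
    hoisted = matmul_ref(a, layout_rewrite_ref(b))
    print(hoisted == matmul_plain(a, b))
```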

Thank you, everyone, for the discussion yesterday and for bringing up very important points. I'll summarize them below.

Q1. [@Hzfengsy] Would layout planning interfere with fusion?
For the use cases we have in mind (MetaSchedule and user annotation), we expect fusion to have already happened before layout planning. Specifically, I see the following phase ordering: Fusion --> MS tuning of individual PrimFuncs (freezes layouts) --> HoistLayoutRewritePass --> RelaxLayoutPlanning pass. So in these scenarios we do not expect layout planning to prevent fusion in any way.

Another way to tackle this would be to choose forward or backward flow based on user choice. Technically, the infrastructure for flowing layout rewrites should work regardless of the chosen direction. But I do not have a phase-ordering example where that would be necessary. Can you elaborate on the use case you had in mind where layout planning and fusion would interfere?

Q2. [@Hzfengsy] Can we expose an API at Relax level to easily transform layouts of PrimFuncs and operators?
Yes, perhaps we can provide an API to modify the layout of a PrimFunc similar to how tir.schedule.transform_layout and tir.schedule.transform_block_layout are used to transform buffers and blocks. We can do that as part of this work.

Q3. [@slyubomirsky] Can you share an example of how layout rewrites would flow through operators?
@sunggg has expanded on this in #278.

Q4. [@jinhongyii] How will you handle layout rewrite flow through the pad operation?
The pad operation is itself a layout rewrite, so it follows the same behavior as other layout rewrites. For example, two successive pad operations could be fused into one, or a pad could be canceled out by an adjacent crop operation. I expand on when this would be legal in the original post.
It is possible I misunderstood your question. Can you expand on it?
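A minimal sketch of the two pad rewrites mentioned in this answer, using 2-D nested lists: two successive zero pads compose into one pad with summed widths, and a crop with the same widths undoes a pad. The reverse order (crop, then pad) is an identity only when the cropped border was already zero, which illustrates why cancellation needs a legality check. `pad2d`, `crop2d`, and `fuse_pads` are illustrative helpers, not Relax ops.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Widths = Tuple[int, int, int, int]  # (top, bottom, left, right)


def pad2d(x: Matrix, widths: Widths, value: float = 0.0) -> Matrix:
    """Constant-pad a 2-D tensor on each side."""
    top, bottom, left, right = widths
    cols = len(x[0]) + left + right
    body = [[value] * left + list(row) + [value] * right for row in x]
    return (
        [[value] * cols for _ in range(top)]
        + body
        + [[value] * cols for _ in range(bottom)]
    )


def crop2d(x: Matrix, widths: Widths) -> Matrix:
    """Remove the given number of rows/columns from each side."""
    top, bottom, left, right = widths
    rows = x[top:len(x) - bottom]
    return [row[left:len(row) - right] for row in rows]


def fuse_pads(w1: Widths, w2: Widths) -> Widths:
    """Two pads with the same fill value compose by adding widths per side."""
    return (w1[0] + w2[0], w1[1] + w2[1], w1[2] + w2[2], w1[3] + w2[3])


if __name__ == "__main__":
    x = [[1.0, 2.0], [3.0, 4.0]]
    w1, w2 = (1, 0, 2, 1), (0, 2, 1, 0)
    print(pad2d(pad2d(x, w1), w2) == pad2d(x, fuse_pads(w1, w2)))  # pad + pad fuse
    print(crop2d(pad2d(x, w1), w1) == x)  # crop cancels pad
```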

Given the broad interest, let me clarify how the Relax layout planner would work with BYOC. Although there could be better approaches as BYOC improves in the future, in this post I will use the current BYOC approach, to focus on the main idea of how the layout planner and BYOC interact.

For each external library/codegen, the BYOC interface manages a set of specifications for its supported operators and their constraints. For instance, TensorRT supports conv2d only when data_layout==NCHW && kernel_layout==OIHW.
It is also worth noting that some BYOC backends, such as TensorRT, support converting layout_transform, while others, such as DNNL, do not.

Broadly, offloading to BYOC happens in several steps:

import tvm
from tvm.relay import transform

# `mod` is the Relay IRModule to offload, and `pattern_table()` returns the
# backend's pattern table.
seq = tvm.transform.Sequential([
    # Step 1: Identify valid operators that we can offload.
    # `pattern_table()` stores the set of operators and their constraints.
    # If an operator cannot satisfy the check, it won't be annotated in the next step.
    transform.MergeComposite(pattern_table()),
    # Step 2: Annotate each of the valid operators.
    # This example assumes `tensorrt`.
    transform.AnnotateTarget("tensorrt"),
    # Step 3: Try to merge annotations to minimize the number of subgraphs to offload.
    # This matters for performance, as each invocation of a BYOC runtime module adds overhead.
    transform.MergeCompilerRegions(),
    # Step 4: Partition the graph according to the annotations.
    transform.PartitionGraph(),
])
mod = seq(mod)
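To make the four steps concrete, here is a toy, pure-Python model on a straight-line list of ops (no real graphs or Relay objects). A per-op check plays the role of the pattern-table constraints, using the TensorRT conv2d layout rule quoted above; annotation tags the ops that pass, and merging groups adjacent annotated ops into regions that would each become one partitioned function. `Op`, `CHECKS`, `annotate`, and `merge_regions` are all hypothetical names.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class Op:
    name: str
    attrs: Dict[str, str] = field(default_factory=dict)


# Step 1 analogue: per-op predicates standing in for pattern-table checks.
CHECKS: Dict[str, Callable[[Op], bool]] = {
    "nn.conv2d": lambda op: (
        op.attrs.get("data_layout") == "NCHW"
        and op.attrs.get("kernel_layout") == "OIHW"
    ),
    "nn.relu": lambda op: True,
}


def annotate(ops: List[Op], checks: Dict[str, Callable[[Op], bool]]) -> List[bool]:
    # Step 2 analogue: an op is offloadable only if it has a check and passes it.
    return [op.name in checks and checks[op.name](op) for op in ops]


def merge_regions(ops: List[Op], offload: List[bool]) -> List[List[str]]:
    # Steps 3-4 analogue: group maximal runs of adjacent offloadable ops.
    regions: List[List[str]] = []
    current: List[str] = []
    for op, ok in zip(ops, offload):
        if ok:
            current.append(op.name)
        elif current:
            regions.append(current)
            current = []
    if current:
        regions.append(current)
    return regions


if __name__ == "__main__":
    ops = [
        Op("nn.conv2d", {"data_layout": "NCHW", "kernel_layout": "OIHW"}),
        Op("nn.relu"),
        Op("nn.softmax"),
        Op("nn.conv2d", {"data_layout": "NHWC", "kernel_layout": "HWIO"}),
        Op("nn.relu"),
    ]
    print(merge_regions(ops, annotate(ops, CHECKS)))
```

Here the NHWC convolution fails the layout check and stays on the TVM side, splitting the offloadable ops into two regions; this is why converting layouts before offloading matters.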

To make the best use of BYOC, Relay expects users to convert layouts, preferably before the offloading steps above.

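import tvm
from tvm.relay import transform

seq = tvm.transform.Sequential([
    # Freeze the layouts of layout-critical ops to what the backend supports.
    transform.ConvertLayout(
        {
            "nn.conv1d": ["NCW", "default"],
            "nn.conv2d": ["NCHW", "default"],
            "nn.conv3d": ["NCDHW", "default"],
            "nn.conv2d_transpose": ["NCHW", "default"],
        }
    ),
    # Offloading steps 1-4
    transform.MergeComposite(pattern_table()),
    transform.AnnotateTarget("tensorrt"),
    # ....
])
# ConvertLayout is registered at opt_level 3, so run it under a matching PassContext.
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)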

transform.ConvertLayout essentially freezes the layouts of the specified operators to the user-provided layouts, then flows and cancels layout ops at the operator level using the InferCorrectLayout property. Since the Relax layout planner proposes its own way to flow and cancel layout rewrites, we expect it would be possible to implement the equivalent of transform.ConvertLayout, while staying open to new ideas for tackling this BYOC-layout problem.
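The freeze, flow, and cancel steps can be illustrated on a straight-line chain of ops. The sketch below is a toy, pure-Python model (neither InferCorrectLayout nor the Relax planner): each frozen op is wrapped in a hypothetical `to_X`/`from_X` pair, `from_X` markers are pushed forward through ops assumed to be layout agnostic, and an adjacent `from_X`, `to_X` pair cancels because the two rewrites are inverses. The op names and the `AGNOSTIC` set are assumptions.

```python
from typing import List, Set

AGNOSTIC: Set[str] = {"relu", "sigmoid"}  # assumed layout-agnostic ops


def freeze(ops: List[str], frozen: Set[str], layout: str) -> List[str]:
    """Wrap each frozen op in a to-layout / from-layout conversion pair."""
    out: List[str] = []
    for op in ops:
        if op in frozen:
            out += [f"to_{layout}", op, f"from_{layout}"]
        else:
            out.append(op)
    return out


def flow_and_cancel(ops: List[str]) -> List[str]:
    """Sink from_* markers past agnostic ops; cancel from_X followed by to_X."""
    ops = list(ops)
    changed = True
    while changed:
        changed = False
        for i in range(len(ops) - 1):
            cur, nxt = ops[i], ops[i + 1]
            if cur.startswith("from_") and nxt in AGNOSTIC:
                ops[i], ops[i + 1] = nxt, cur  # flow the inverse rewrite forward
                changed = True
                break
            if cur.startswith("from_") and nxt == "to_" + cur[len("from_"):]:
                del ops[i:i + 2]  # the two rewrites are inverses: drop both
                changed = True
                break
    return ops


if __name__ == "__main__":
    chain = freeze(["conv2d", "relu", "conv2d"], {"conv2d"}, "NCHW")
    print(flow_and_cancel(chain))
```

With `conv2d -> relu -> conv2d` frozen to NCHW, the conversions between the two convolutions cancel, leaving one conversion at each end of the chain; a layout-sensitive op such as `softmax` in place of `relu` blocks the flow and keeps both pairs.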

If you want to offload only part of the graph (even though the whole graph is offloadable), you can customize the annotation by selectively adding or removing annotations. Note that annotated operator nodes are skipped by the internal pipeline, since their optimization and codegen are delegated to the external component. This way, you can offload one part of the graph to BYOC while still applying optimization passes like fusion and MetaSchedule tuning to the rest.

Would you be able to use ConvertLayout to fix the layouts of only specific operators, namely the ones you intend to match in BYOC? Perhaps that could be added as an interface to BYOC (e.g., fixing layouts at each pattern match).

@slyubomirsky Yeah, that seems possible. Actually, that is one of the potential improvements discussed offline with @YuchenJin and @psrivas2 as well.

@psrivas2 I thought that flowing layouts through pad would result in different behavior, but it turns out I was wrong. So Q4 is no longer a problem.

What's the status of this?

@spectrometerHBH we are implementing this. We will be sending out PRs starting next week for review.

Quoting the earlier comment: "Given broad interests, let me add clarification about how relax layout planner would work with BYOC"

I'm interested in tackling this problem. I've worked on two BYOC backends recently (DNNL and CUTLASS); both want layouts that are likely different from those of the input mod (NCHWc or NHWC).

Probably the first step is to add a high-level layout_transform op (UPDATE: just found that the proposal talks about the relax.layout_transform op). We also need to make it executable via the legalizer.

@masahi that would be awesome! @sunggg also proposed a direction here. Would be great to get alignment on the approach.

Quoting: "UPDATE: just found that the proposal talks about relax.layout_transform op. We also need to make it executable via legalizer."

Yes, I have a pending PR locally that adds the layout_transform op. It does not add legalization support though. Let me send it out today.
Note: Relax already has ops like permute_dims and reshape, which could be used to handle layout transformations that need no implicit padding or cropping.
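As a concrete instance of that note, an NCHW to NCHW4c rewrite (the blocked NCHWc family mentioned above) decomposes into a reshape followed by a dimension permutation when C is divisible by the block size; when it is not, the channel axis must be padded first, which is exactly where implicit padding and cropping come in. The sketch uses NumPy's `reshape`/`transpose` as stand-ins for Relax's reshape and permute ops; `nchw_to_nchwc` and `nchw_to_nchwc_padded` are hypothetical helpers.

```python
import numpy as np


def nchw_to_nchwc(x: np.ndarray, block: int) -> np.ndarray:
    """NCHW -> NCHW{block}c via reshape + permute; requires C % block == 0."""
    n, c, h, w = x.shape
    if c % block != 0:
        raise ValueError(f"C={c} is not divisible by {block}; padding would be needed")
    # Split C into (C // block, block), then move the inner block axis last.
    return x.reshape(n, c // block, block, h, w).transpose(0, 1, 3, 4, 2)


def nchw_to_nchwc_padded(x: np.ndarray, block: int) -> np.ndarray:
    """Same rewrite, but zero-pad C up to a multiple of block first."""
    pad_c = (-x.shape[1]) % block
    x = np.pad(x, ((0, 0), (0, pad_c), (0, 0), (0, 0)))
    return nchw_to_nchwc(x, block)


if __name__ == "__main__":
    x = np.arange(2 * 8 * 3 * 3).reshape(2, 8, 3, 3)
    y = nchw_to_nchwc(x, 4)
    # y[n, co, h, w, ci] == x[n, co * 4 + ci, h, w]
    print(y.shape, y[1, 1, 2, 0, 3] == x[1, 7, 2, 0])
```

The padded variant is no longer a pure permutation of elements, so cancelling it against its inverse requires the cropped channels to be known zeros.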