Auto memory planning
jinhongyii opened this issue
Background
We want to introduce automatic data movement, that is, automatically generating the schedule that copies a region of data from one buffer to another. This involves automatically deciding the layouts of the intermediate buffers and automatically binding threads.
We first focus on the layouts of the intermediate buffers and ignore thread binding entirely.
Process
A very simple approach is to keep a list of candidate layouts and sample one from it. The list could be:
f(i, j) = (i, j)
f(i, j) = (j, i)
f(i, j) = pad m elements every n lines
f(i, j) = (i, i ^ j)
...
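The candidate list can be pictured as a table of index maps. The sketch below is only an illustration: it assumes each layout is a Python function from a logical element (i, j) to a flat offset in shared memory, and the names `CANDIDATES`, `sample_layout` and `is_injective` are hypothetical, not TVM APIs. The XOR entry assumes the row length is a power of two.

```python
import random
from typing import Callable, Dict, Tuple

# A layout maps logical element (i, j) of an n_rows x n_cols tile to a
# flat offset in the physical shared-memory buffer.
Layout = Callable[[int, int], int]


def identity(n_rows: int, n_cols: int) -> Layout:
    # f(i, j) = (i, j): plain row-major storage
    return lambda i, j: i * n_cols + j


def transpose(n_rows: int, n_cols: int) -> Layout:
    # f(i, j) = (j, i): column-major storage
    return lambda i, j: j * n_rows + i


def padded(n_rows: int, n_cols: int, m: int = 4, n: int = 2) -> Layout:
    # pad m elements after every n lines (rows)
    return lambda i, j: i * n_cols + j + (i // n) * m


def xor_swizzle(n_rows: int, n_cols: int) -> Layout:
    # f(i, j) = (i, i ^ j); assumes n_cols is a power of two and folds i
    # into [0, n_cols) so the swizzled column stays inside its row
    return lambda i, j: i * n_cols + ((i % n_cols) ^ j)


CANDIDATES: Dict[str, Callable[..., Layout]] = {
    "identity": identity,
    "transpose": transpose,
    "padded": padded,
    "xor_swizzle": xor_swizzle,
}


def sample_layout(rng: random.Random, n_rows: int, n_cols: int) -> Tuple[str, Layout]:
    # uniform sampling over the candidate list
    name = rng.choice(sorted(CANDIDATES))
    return name, CANDIDATES[name](n_rows, n_cols)


def is_injective(f: Layout, n_rows: int, n_cols: int) -> bool:
    # a usable layout never maps two logical elements to the same slot
    offsets = [f(i, j) for i in range(n_rows) for j in range(n_cols)]
    return len(set(offsets)) == len(offsets)


rng = random.Random(0)
name, f = sample_layout(rng, 16, 16)
```

Checking injectivity is a cheap sanity filter before a sampled layout is used for an intermediate buffer.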
Suppose the warp memory uses a fixed identity layout, f(i, j) = (i, j). Then the process can be described as follows:
The original IR is:
with tir.block([1024, 1024, tir.reduce_axis(0, 1024)]) as [vi, vj, vk]:
    C[vi, vj] += A[vi, vk] * B[vk, vj]
After early tensorization:
for i, j, k in tir.grid(64, 64, 64):
    with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [vi, vj, vk]:
        for ii, jj, kk in tir.grid(16, 16, 16):
            with tir.block([16, 16, tir.reduce_axis(0, 16)]) as [vii, vjj, vkk]:
                tir.bind(...)
                C[vii, vjj] += A[vii, vkk] * B[vkk, vjj]
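Early tensorization only regroups the 1024 x 1024 x 1024 iteration space into a 64 x 64 x 64 grid of 16 x 16 x 16 blocks (1024 = 64 * 16); the result is unchanged. A NumPy sketch of that equivalence, scaled down so it runs quickly; `matmul_tiled` is an illustrative name and `a @ b` merely stands in for the 16 x 16 x 16 tensor intrinsic.

```python
import numpy as np


def matmul_tiled(A, B, tile=16):
    # Outer (i, j, k) loops walk over tiles; the inner tile x tile x tile
    # product is the unit that early tensorization hands to the intrinsic.
    n, kdim = A.shape
    _, m = B.shape
    assert n % tile == 0 and m % tile == 0 and kdim % tile == 0
    C = np.zeros((n, m), dtype=np.result_type(A, B))
    for i in range(n // tile):
        for j in range(m // tile):
            for k in range(kdim // tile):
                a = A[i * tile:(i + 1) * tile, k * tile:(k + 1) * tile]
                b = B[k * tile:(k + 1) * tile, j * tile:(j + 1) * tile]
                # stand-in for one tensor-intrinsic call accumulating into C
                C[i * tile:(i + 1) * tile, j * tile:(j + 1) * tile] += a @ b
    return C


rng = np.random.default_rng(0)
A = rng.integers(-3, 4, size=(64, 64))
B = rng.integers(-3, 4, size=(64, 64))
C = matmul_tiled(A, B, tile=16)
```

Integer inputs keep the comparison with the untiled product exact.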
After multi-level tiling (without cache_read/cache_write):
for i0, j0, ... in tir.grid(...):  # SSSRRSRS
    with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [vi, vj, vk]:
        tir.bind(...)
        for ii, jj, kk in tir.grid(16, 16, 16):
            with tir.block([16, 16, tir.reduce_axis(0, 16)]) as [vii, vjj, vkk]:
                tir.bind(...)
                C[vii, vjj] += A[vii, vkk] * B[vkk, vjj]
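The `# SSSRRSRS` annotation fixes the order of the tiled loops. Assuming the convention that the listings below follow (each `S` emits one level of every spatial axis, each `R` one level of every reduction axis), the loop names can be derived mechanically. `loop_order` is a hypothetical helper for illustration only.

```python
from typing import List, Sequence


def loop_order(structure: str, spatial: Sequence[str], reduce: Sequence[str]) -> List[str]:
    """Expand a tiling structure string such as "SSSRRSRS" into loop names.

    Each 'S' emits one loop level for every spatial axis and each 'R' emits
    one level for every reduction axis; the suffix counts how many times
    that axis has been tiled so far.
    """
    level = {axis: 0 for axis in (*spatial, *reduce)}
    order: List[str] = []
    for ch in structure:
        if ch == "S":
            axes = spatial
        elif ch == "R":
            axes = reduce
        else:
            raise ValueError(f"unknown tiling level {ch!r}")
        for axis in axes:
            order.append(f"{axis}{level[axis]}")
            level[axis] += 1
    return order


print(loop_order("SSSRRSRS", ["i", "j"], ["k"]))
# ['i0', 'j0', 'i1', 'j1', 'i2', 'j2', 'k0', 'k1', 'i3', 'j3', 'k2', 'i4', 'j4']
```

So i and j are each split into five factors and k into three, which matches the loops i0 ... j4 and k0 ... k2 used in the rewritten IR.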
Rewrite the computation part, obtain the warp load/store tensor intrinsics, and generate intermediate buffers whose layouts are constrained by the warp-load tensor intrinsic:
for i0, j0, i1, j1, i2, j2 in tir.grid(...):  # SSS
    for k0:  # R
        A -> A_shared (sampled layout f)
        B -> B_shared (sampled layout g)
        for k1, i3, j3, k2:  # RSR
            A_shared[i, i ^ j] -> A_warp
            B_shared -> B_warp
            for i4, j4:  # S
                wmma
    wmma -> C
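Whichever layout f is sampled, A_warp receives the original tile as long as A_shared is written and read through the same f. A NumPy sketch of that round trip, assuming flat shared memory and the XOR layout (i, i ^ j) from the candidate list; the function names are hypothetical.

```python
import numpy as np


def xor_swizzle(i, j, n_cols):
    # sampled layout f(i, j) = (i, i ^ j) as a flat offset; assumes n_cols is
    # a power of two so the swizzled column stays inside the row
    return i * n_cols + ((i % n_cols) ^ j)


def copy_to_shared(tile, layout):
    # A -> A_shared: write each logical element to its physical slot
    n_rows, n_cols = tile.shape
    shared = np.zeros(n_rows * n_cols, dtype=tile.dtype)
    for i in range(n_rows):
        for j in range(n_cols):
            shared[layout(i, j, n_cols)] = tile[i, j]
    return shared


def load_to_warp(shared, layout, n_rows, n_cols):
    # A_shared[f(i, j)] -> A_warp[i, j]: the reader applies the same f
    warp = np.empty((n_rows, n_cols), dtype=shared.dtype)
    for i in range(n_rows):
        for j in range(n_cols):
            warp[i, j] = shared[layout(i, j, n_cols)]
    return warp


tile = np.arange(256).reshape(16, 16)
shared = copy_to_shared(tile, xor_swizzle)
warp = load_to_warp(shared, xor_swizzle, 16, 16)
```

The physical contents of `shared` differ from the row-major tile, but the warp-side view is identical.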
Tensorize the warp load/store:
for i0, j0, i1, j1, i2, j2 in tir.grid(...):  # SSS
    for k0:  # R
        A -> A_shared (sampled layout f)
        B -> B_shared (sampled layout g)
        for k1, i3, j3, k2:  # RSR
            wmma_load_sync
            for i4, j4:  # S
                wmma
    wmma_store_sync
Note that this algorithm has some potential problems:
Problem 1. How can we ensure that the sampled layout fits the pre-registered tensor intrinsic (warp load)?
For example, suppose the layout pads 4 elements every two rows and the shared memory buffer A has shape 128 * 64:
f(i, j) = i * 64 + j + (i / 2) * 4, where / denotes integer division
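Evaluating f at the start of consecutive rows shows why this layout has no single stride: the distance between row starts alternates between 64 and 68 elements. A small check; `padded_offset` is a hypothetical helper.

```python
def padded_offset(i: int, j: int, n_cols: int = 64, pad: int = 4, every: int = 2) -> int:
    # f(i, j) = i * 64 + j + (i / 2) * 4, with / as integer division
    return i * n_cols + j + (i // every) * pad


# distance between the first elements of consecutive rows
strides = [padded_offset(i + 1, 0) - padded_offset(i, 0) for i in range(6)]
print(strides)  # [64, 68, 64, 68, 64, 68]
```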
If the pre-registered tensor intrinsic is wmma_load_sync, this layout cannot be applied, because the function takes a single constant stride as an argument:
// here `ldm` is the stride, in elements, between consecutive rows (row-major) or columns (column-major)
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);
However, consider a pre-registered tensor intrinsic based on the PTX ldmatrix instruction, which loads the operands of mma:
ldmatrix.sync.aligned.shape.num{.trans}{.ss}.type r, [p];
.shape = {.m8n8};
.num = {.x1, .x2, .x4};
.ss = {.shared};
.type = {.b16};
Here each thread passes in the address of the start of one row, so there is no constant-stride constraint and the layout can be applied.
Therefore, for every tensor intrinsic we need either a description of the constraints it imposes on layouts, or an exact specification of the layouts it accepts.
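Emulating both addressing schemes on the padded buffer makes the difference concrete. This is a NumPy sketch, not real CUDA or PTX: `load_with_stride` mimics the single-`ldm` addressing of load_matrix_sync, and `load_with_row_pointers` mimics ldmatrix, where each of the 8 rows of an .m8n8 tile is fetched from its own address. The helper names are hypothetical, and the alignment rules of the real instructions are ignored.

```python
import numpy as np


def padded_offset(i, j, n_cols=64, pad=4, every=2):
    # layout from Problem 1: `pad` elements of padding after every `every` rows
    return i * n_cols + j + (i // every) * pad


def load_with_stride(mem, base, ldm, rows=8, cols=8):
    # load_matrix_sync-style addressing: row r starts at base + r * ldm
    return np.stack([mem[base + r * ldm: base + r * ldm + cols] for r in range(rows)])


def load_with_row_pointers(mem, row_ptrs, cols=8):
    # ldmatrix-style addressing: every row comes with its own start address
    return np.stack([mem[p: p + cols] for p in row_ptrs])


# fill a padded 128 x 64 buffer; padding slots hold -1
n_rows, n_cols = 128, 64
mem = np.full(padded_offset(n_rows - 1, n_cols - 1) + 1, -1, dtype=np.int64)
logical = np.arange(n_rows * n_cols).reshape(n_rows, n_cols)
for i in range(n_rows):
    for j in range(n_cols):
        mem[padded_offset(i, j)] = logical[i, j]

# the 8 x 8 tile starting at logical element (0, 0)
expected = logical[0:8, 0:8]
by_rows = load_with_row_pointers(mem, [padded_offset(r, 0) for r in range(8)])
by_stride = load_with_stride(mem, padded_offset(0, 0), ldm=64)
```

Row 1 forces ldm = 64 while row 2 would need 2 * ldm = 132, so no single stride reproduces the tile, whereas per-row pointers always do.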
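One possible shape for such a description, purely as a sketch: attach to each intrinsic a few predicates over the layout function and test a sampled layout before tensorizing. The class `IntrinLayoutConstraint`, its fields and the two instances are hypothetical, not an existing TVM API, and they model only the two constraints discussed above.

```python
from dataclasses import dataclass
from typing import Callable

Layout = Callable[[int, int], int]


@dataclass(frozen=True)
class IntrinLayoutConstraint:
    # hypothetical per-intrinsic record of what a layout must satisfy
    name: str
    needs_constant_row_stride: bool  # e.g. load_matrix_sync's single `ldm`
    needs_contiguous_rows: bool      # each row of the tile stored as one run

    def accepts(self, f: Layout, rows: int, cols: int) -> bool:
        if self.needs_contiguous_rows:
            for i in range(rows):
                if any(f(i, j + 1) - f(i, j) != 1 for j in range(cols - 1)):
                    return False
        if self.needs_constant_row_stride:
            starts = [f(i, 0) for i in range(rows)]
            if len({b - a for a, b in zip(starts, starts[1:])}) > 1:
                return False
        return True


WMMA_LOAD = IntrinLayoutConstraint("wmma_load_sync", True, True)
LDMATRIX = IntrinLayoutConstraint("ldmatrix", False, True)


def padded(i, j):
    # layout from Problem 1
    return i * 64 + j + (i // 2) * 4


def row_major(i, j):
    return i * 64 + j
```

With this model the padded layout is rejected for wmma_load_sync and accepted for ldmatrix, while plain row-major storage passes both.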
Problem 2. How should tensor intrinsics be described once layouts are introduced?
Previously we dealt with a simple data movement: copying a 16 * 16 tile from A_shared[i, j] to a 16 * 16 tile A_warp[i, j].
import tvm
from tvm import tir
from tvm.script import ty

@tvm.script.tir
def wmma_load_a_desc(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (16, 16), "float16", align=128, offset_factor=256,
                         scope="shared")
    C = tir.match_buffer(c, (16, 16), "float16", align=128, offset_factor=256,
                         scope="wmma.matrix_a")
    with tir.block([16, 16], "root") as [vi, vj]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        for i, j in tir.grid(16, 16):
            with tir.block([16, 16], "load") as [vii, vjj]:
                tir.bind(vii, vi + i)
                tir.bind(vjj, vj + j)
                C[vii, vjj] = A[vii, vjj]
@tvm.script.tir
def wmma_load_a_impl(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (16, 16), "float16", align=128, offset_factor=256, scope="shared")
    C = tir.match_buffer(c, (16, 16), "float16", align=128, offset_factor=256, scope="wmma.matrix_a")
    with tir.block([16, 16], "root") as [vi, vj]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        tir.reads(A[0: 16, 0: 16])
        tir.writes(C[0: 16, 0: 16])
        tir.evaluate(tir.tvm_load_matrix_sync(
            C.data, 16, 16, 16, C.elem_offset // 256, A.access_ptr("r"), 16, "row_major",
            dtype="handle"))
However, if we use several different layouts, the descriptions differ: A_shared[f(i, j)] -> A_warp[i, j]. For each distinct f we need a separate implementation to perform the tensorize rewrite. Take the layout from Problem 1 as an example.
We would expect the description and implementation to look like the following, but this cannot actually be expressed:
@tvm.script.tir
def mma_load_a_desc(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (256,), "float16", align=128, offset_factor=256,
                         scope="shared")
    C = tir.match_buffer(c, (16, 16), "float16", align=128, offset_factor=256,
                         scope="mma.matrix_a")
    with tir.block([16, 16], "root") as [vi, vj]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        for i, j in tir.grid(16, 16):
            with tir.block([16, 16], "load") as [vii, vjj]:
                tir.bind(vii, vi + i)
                tir.bind(vjj, vj + j)
                # the layout can't be expressed in a 2-D way
                # this example is wrong because `stride` can't be inferred
                C[vii, vjj] = A[vii * stride + vjj + vii / 2 * 4]
@tvm.script.tir
def mma_load_a_impl(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (256,), "float16", align=128, offset_factor=256, scope="shared")
    C = tir.match_buffer(c, (16, 16), "float16", align=128, offset_factor=256, scope="mma.matrix_a")
    with tir.block([16, 16], "root") as [vi, vj]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        tir.reads(A[0: 256])
        tir.writes(C[0: 16, 0: 16])
        # this is merely an example; ldmatrix is currently not supported
        tir.evaluate(tir.ldmatrix(C.data, 16, 16, 16, C.elem_offset // 256, A.access_ptr("r"), 2, 4, 64))
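Why `stride` cannot be inferred can be seen by computing the offsets a 16 x 16 tile occupies, relative to its first element. They depend on the parent buffer's row length and, through the (i / 2) * 4 term, on the parity of the tile's starting row; neither appears in a fixed-shape description. A sketch with hypothetical helper names:

```python
def padded_offset(i, j, row_len, pad=4, every=2):
    # Problem 1 layout with the parent buffer's row length as a parameter
    return i * row_len + j + (i // every) * pad


def tile_offsets(r0, c0, row_len, size=16):
    # offsets touched by the size x size tile whose logical origin is (r0, c0),
    # measured from the tile's first element
    base = padded_offset(r0, c0, row_len)
    return [[padded_offset(r0 + i, c0 + j, row_len) - base for j in range(size)]
            for i in range(size)]


a = tile_offsets(0, 0, row_len=64)
b = tile_offsets(0, 0, row_len=128)
print(a[1][0], b[1][0])  # 64 128
print(a[2][0], b[2][0])  # 132 260
```

With a 64-element parent row, one 16 x 16 tile already spans a[15][15] + 1 = 1004 elements, far more than the 256 assumed by the buffer shape above.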
Another real problem: with m layouts and n tensor intrinsics, we would need m * n description-implementation pairs.
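The growth is simply the Cartesian product of the two lists. A tiny illustration; the layout and intrinsic names are placeholders, not registered TVM intrinsics.

```python
from itertools import product

layouts = ["identity", "transpose", "padded", "xor_swizzle"]
intrins = ["wmma_load_a", "wmma_load_b", "ldmatrix_x4"]

# one hand-written description/implementation pair per (layout, intrinsic)
pairs = list(product(layouts, intrins))
print(len(pairs))  # 12
```

Every new layout adds n pairs and every new intrinsic adds m pairs.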
I will post some intermediate TIR later to clarify the transformations.