TE Schedule Primitives
tiling
# NOTE(review): assumes `tvm`, `te` are imported and `m`, `n` are symbolic
# shape vars (e.g. te.var) defined earlier — not shown in this excerpt.
# define the input
A = te.placeholder((m, n), name="A")
# define the output and computation using a lambda expression (identity copy)
B = te.compute((m, n), lambda i, j: A[i, j], name="B")
# create a schedule (including a series of transformations)
s = te.create_schedule(B.op)
# the transformation: tile the two axes with factors 10 and 5, yielding the
# loop nest (i_outer, j_outer, i_inner, j_inner) — see the lowered IR below
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# lower to TensorIR and print the human-readable script
print(tvm.lower(s, [A, B], simple_mode=True))
The IR module should be:
# Lowered TIR produced by tvm.lower for the tiled schedule.
# NOTE(review): the original indentation of this dump was flattened when
# pasted into these notes; the tokens below are verbatim compiler output.
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle):
T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
# m and n are symbolic (runtime) dimensions
m, n = T.int32(), T.int32()
A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
# outer loop extents are ceil-divisions: (m + 9) // 10 tiles of 10, (n + 4) // 5 tiles of 5
for i_outer, j_outer, i_inner in T.grid((m + 9) // 10, (n + 4) // 5, 10):
# T.likely guards mask out the tail iterations when 10 / 5 do not divide m / n
if T.likely(i_outer * 10 + i_inner < m):
for j_inner in range(5):
if T.likely(j_outer * 5 + j_inner < n):
# common-subexpression temporaries: the reconstructed global j and i indices
cse_var_2: T.int32 = j_outer * 5 + j_inner
cse_var_1: T.int32 = i_outer * 10 + i_inner
# flat 1-D views over the same underlying data as B_1 / A_1
B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]] = A_2[cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1]]
The difference between A_1 and A_2 is:
- A_1 is two-dimensional while A_2 is one-dimensional.
- They both point to the same underlying data.
- When accessing A_1, we write A_1[i, j], and the underlying address computation (i * stride0 + j * stride1) is performed implicitly by the language.
- When accessing A_2, we write A_2[i * stride0 + j * stride1], so the address computation is spelled out explicitly.
fuse
# NOTE(review): assumes `tvm`, `te`, `m`, `n` are defined earlier — not shown here.
A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")
s = te.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)
fused = s[B].fuse(xi, yi)
# the fused loop has extent 10 * 5 = 50 in the lowered IR below
print(tvm.lower(s, [A, B], simple_mode=True))
The TIR is:
# Lowered TIR produced by tvm.lower for the tiled + fused schedule.
# NOTE(review): the original indentation of this dump was flattened when
# pasted into these notes; the tokens below are verbatim compiler output.
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle):
T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
# m and n are symbolic (runtime) dimensions
m, n = T.int32(), T.int32()
A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
# the fused inner loop has extent 10 * 5 = 50
for i_outer, j_outer, i_inner_j_inner_fused in T.grid((m + 9) // 10, (n + 4) // 5, 50):
if T.likely(i_outer * 10 + i_inner_j_inner_fused // 5 < m):
if T.likely(j_outer * 5 + i_inner_j_inner_fused % 5 < n):
# i_inner and j_inner are recovered from the fused index via // 5 and % 5
cse_var_2: T.int32 = j_outer * 5 + i_inner_j_inner_fused % 5
cse_var_1: T.int32 = i_outer * 10 + i_inner_j_inner_fused // 5
# flat 1-D views over the same underlying data as B_1 / A_1
B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]] = A_2[cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1]]
- This is a fusion of two loops, not an operator fusion.
- After fusing, the original i_inner and j_inner indices no longer exist as loop variables; they are recovered manually from the fused index (via `// 5` and `% 5`).