From 400baf2b32d460497c3c65ebb50666536783d49e Mon Sep 17 00:00:00 2001
From: Adam Straw
Date: Tue, 31 Aug 2021 08:57:30 -0700
Subject: [PATCH] refactor optimize GEMM on CPU tutorial (#8825)

* refactor optimize GEMM on CPU tutorial

* fix lint errors

* fix more lint errors

* fix typo

* fix problem with redefinition of `k`
  add TODO and comments around loop unrolling
  clarify note on the array packing figure

* reword general description of array packing

* grab kaxis from compute definition

* remove duplicate comments on unrolling
---
 tutorials/optimize/opt_gemm.py | 133 ++++++++++++++++++---------------
 1 file changed, 72 insertions(+), 61 deletions(-)

diff --git a/tutorials/optimize/opt_gemm.py b/tutorials/optimize/opt_gemm.py
index 7af772784cd6..5d698c612ee8 100644
--- a/tutorials/optimize/opt_gemm.py
+++ b/tutorials/optimize/opt_gemm.py
@@ -101,7 +101,7 @@
 k = te.reduce_axis((0, K), "k")
 A = te.placeholder((M, K), name="A")
 B = te.placeholder((K, N), name="B")
-C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
+C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name="C")
 
 # Default schedule
 s = te.create_schedule(C.op)
@@ -130,15 +130,16 @@
 # fill 32 * 32 * sizeof(float) which is 4KB in the cache whose total size is 32KB (L1 data cache)
 bn = 32
+kfactor = 4
 s = te.create_schedule(C.op)
 
 # Blocking by loop tiling
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
-(k,) = s[C].op.reduce_axis
-ko, ki = s[C].split(k, factor=4)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+(kaxis,) = s[C].op.reduce_axis
+ko, ki = s[C].split(kaxis, factor=kfactor)
 
 # Hoist reduction domain outside the blocking loop
-s[C].reorder(xo, yo, ko, ki, xi, yi)
+s[C].reorder(mo, no, ko, ki, mi, ni)
 
 func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
@@ -162,19 +163,20 @@
 # -------------
 # Another important trick is vectorization. When the memory access pattern is uniform,
 # the compiler can detect this pattern and pass the continuous memory to vector processor. In TVM,
-# we can use `vectorize` interface to hint the compiler this pattern, so that we can accelerate it vastly.
+# we can use `vectorize` interface to hint the compiler this pattern, so that we can accelerate it
+# vastly.
 #
 # In this tutorial, we chose to vectorize the inner loop row data since it is cache friendly.
 
 s = te.create_schedule(C.op)
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
-(k,) = s[C].op.reduce_axis
-ko, ki = s[C].split(k, factor=4)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+(kaxis,) = s[C].op.reduce_axis
+ko, ki = s[C].split(kaxis, factor=kfactor)
 
-s[C].reorder(xo, yo, ko, ki, xi, yi)
+s[C].reorder(mo, no, ko, ki, mi, ni)
 
 # Vectorization
-s[C].vectorize(yi)
+s[C].vectorize(ni)
 
 func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
@@ -194,20 +196,19 @@
 ###################################################################################################
 # Loop Permutation
 # ----------------
-# If we look at the above IR, we can see the inner loop row data is vectorized and
-# B is transformed into PackedB. The traversal of PackedB is sequential now.
-# So we will look at the access pattern of A. In current schedule, A is accessed column by column
-# which is not cache friendly. If we change the nested loop order of ki and inner axes xi,
+# If we look at the above IR, we can see the inner loop row data is vectorized for both B and C.
+# Next we will look at the access pattern of A. In current schedule, A is accessed column by column
+# which is not cache friendly. If we change the nested loop order of ki and inner axes mi,
 # the access pattern for A matrix is more cache friendly.
 
 s = te.create_schedule(C.op)
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
-(k,) = s[C].op.reduce_axis
-ko, ki = s[C].split(k, factor=4)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+(kaxis,) = s[C].op.reduce_axis
+ko, ki = s[C].split(kaxis, factor=kfactor)
 
 # re-ordering
-s[C].reorder(xo, yo, ko, xi, ki, yi)
-s[C].vectorize(yi)
+s[C].reorder(mo, no, ko, mi, ki, ni)
+s[C].vectorize(ni)
 
 func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
@@ -227,43 +228,48 @@
 
 ###################################################################################################
 # Array Packing
 # -------------
-# Another important trick is array packing. This trick is to reorder the storage dimension of the
-# array to convert the continuous access pattern on certain dimension to a sequential pattern after
-# flattening.
+# Another important trick is array packing. The trick is to reorder the storage of a multi-
+# dimensional array so that it is accessed sequentially after it is flattened and stored in one-
+# dimensional memory.
 #
 # .. image:: https://github.com/dmlc/web-data/raw/main/tvm/tutorial/array-packing.png
 #    :align: center
 #
+# NOTE: This figure is a general illustration of how array packing works.
 
 ###################################################################################################
-# Just as it is shown in the figure above, after blocking the computations, we can observe the array
-# access pattern of B (after flattening), which is regular but discontinuous. We expect that after
-# some transformation we can get continuous access pattern. We can reorder a [16][16] array to
-# a [16/4][16][4] array, so that the access pattern of B will be sequential when grabing
-# the corresponding value from the packed array.
-#
+# We can use array packing to address the access pattern for B. Observe the array access pattern of
+# B after flattening, which is not sequential as we iterate over the K dimension. We can reorder B
+# with dimensions [K][N] so that it has dimensions [N/bn][K][bn] where bn is the blocking factor and
+# also the vector size for B in the inner loop. This reorder splits N into two dimensions ---
+# bigN (N/bn) and littleN (bn) --- and the new dimensions [N/bn][K][bn] match the indexing of B
+# from outer to inner loops (no, ko, ki, ni) resulting in a sequential access pattern for B after
+# flattening.
 
 # We have to re-write the algorithm slightly.
 
-packedB = te.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name="packedB")
+packedB = te.compute(
+    (N / bn, K, bn), lambda bigN, k, littleN: B[k, bigN * bn + littleN], name="packedB"
+)
 C = te.compute(
     (M, N),
-    lambda x, y: te.sum(A[x, k] * packedB[y // bn, k, tvm.tir.indexmod(y, bn)], axis=k),
+    lambda m, n: te.sum(A[m, k] * packedB[n // bn, k, tvm.tir.indexmod(n, bn)], axis=k),
     name="C",
 )
 
 s = te.create_schedule(C.op)
 
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
-(k,) = s[C].op.reduce_axis
-ko, ki = s[C].split(k, factor=4)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+(kaxis,) = s[C].op.reduce_axis
+ko, ki = s[C].split(kaxis, factor=kfactor)
 
-s[C].reorder(xo, yo, ko, xi, ki, yi)
-s[C].vectorize(yi)
+s[C].reorder(mo, no, ko, mi, ki, ni)
+s[C].vectorize(ni)
 
-x, y, z = s[packedB].op.axis
-s[packedB].vectorize(z)
-s[packedB].parallel(x)
+bigN, _, littleN = s[packedB].op.axis
+s[packedB].vectorize(littleN)
+s[packedB].parallel(bigN)
 
 func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
@@ -293,23 +299,28 @@
 
 # Allocate write cache
 CC = s.cache_write(C, "global")
 
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
 
-# Write cache is computed at yo
-s[CC].compute_at(s[C], yo)
+# Write cache is computed at no
+s[CC].compute_at(s[C], no)
 
 # New inner axes
-xc, yc = s[CC].op.axis
+mc, nc = s[CC].op.axis
+
+(kaxis,) = s[CC].op.reduce_axis
+ko, ki = s[CC].split(kaxis, factor=kfactor)
+s[CC].reorder(ko, mc, ki, nc)
+s[CC].vectorize(nc)
 
-(k,) = s[CC].op.reduce_axis
-ko, ki = s[CC].split(k, factor=4)
-s[CC].reorder(ko, xc, ki, yc)
+# TODO: Add separate optimization step to discuss loop unrolling
+# unrolling is a loop optimization strategy which can reduce branch
+# prediction failures and increase the chance of concurrent execution
+# unroll kfactor loops
 s[CC].unroll(ki)
-s[CC].vectorize(yc)
 
-x, y, z = s[packedB].op.axis
-s[packedB].vectorize(z)
-s[packedB].parallel(x)
+bigN, _, littleN = s[packedB].op.axis
+s[packedB].vectorize(littleN)
+s[packedB].parallel(bigN)
 
 func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
@@ -335,24 +346,24 @@
 
 CC = s.cache_write(C, "global")
 
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
 
-s[CC].compute_at(s[C], yo)
+s[CC].compute_at(s[C], no)
 
-xc, yc = s[CC].op.axis
+mc, nc = s[CC].op.axis
 
-(k,) = s[CC].op.reduce_axis
-ko, ki = s[CC].split(k, factor=4)
-s[CC].reorder(ko, xc, ki, yc)
+(kaxis,) = s[CC].op.reduce_axis
+ko, ki = s[CC].split(kaxis, factor=kfactor)
+s[CC].reorder(ko, mc, ki, nc)
+s[CC].vectorize(nc)
 s[CC].unroll(ki)
-s[CC].vectorize(yc)
 
 # parallel
-s[C].parallel(xo)
+s[C].parallel(mo)
 
-x, y, z = s[packedB].op.axis
-s[packedB].vectorize(z)
-s[packedB].parallel(x)
+bigN, _, littleN = s[packedB].op.axis
+s[packedB].vectorize(littleN)
+s[packedB].parallel(bigN)
 
 func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
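
For readers who want to try the renamed schedule outside the tutorial file, here is a minimal, self-contained sketch of the blocking and vectorization steps from the first hunks above. The sizes M = K = N = 1024, the plain "llvm" target, and printing the lowered IR are illustrative assumptions, not part of the patch.

import tvm
from tvm import te

# Assumed sizes and tiling factors; the tutorial may use different values.
M = K = N = 1024
bn = 32
kfactor = 4

k = te.reduce_axis((0, K), "k")
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name="C")

s = te.create_schedule(C.op)

# Blocking by loop tiling, then splitting the reduction axis, as in the patch.
mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(kaxis,) = s[C].op.reduce_axis
ko, ki = s[C].split(kaxis, factor=kfactor)

# Hoist the reduction outside the blocking loops and vectorize the innermost axis.
s[C].reorder(mo, no, ko, ki, mi, ni)
s[C].vectorize(ni)

# Inspect the lowered IR to confirm the loop nest produced by the schedule.
print(tvm.lower(s, [A, B, C], simple_mode=True))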
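
The reworded array-packing note describes storing B with dimensions [N/bn][K][bn] instead of [K][N] so that the innermost loop walks contiguous memory. A small NumPy sketch of that layout transform (not part of the patch; toy sizes, and it assumes N is a multiple of bn):

import numpy as np

K, N, bn = 8, 8, 4  # toy sizes; the tutorial uses much larger matrices
B = np.arange(K * N, dtype="float32").reshape(K, N)

# Pack B[K][N] into packedB[N // bn][K][bn]: split N into (bigN, littleN)
# and make littleN the fastest-varying dimension.
packedB = np.empty((N // bn, K, bn), dtype=B.dtype)
for bigN in range(N // bn):
    for k in range(K):
        for littleN in range(bn):
            packedB[bigN, k, littleN] = B[k, bigN * bn + littleN]

# Same element, two layouts: B[k, n] == packedB[n // bn, k, n % bn], which
# mirrors packedB[n // bn, k, tvm.tir.indexmod(n, bn)] in the patched compute.
k, n = 3, 6
assert B[k, n] == packedB[n // bn, k, n % bn]

For a fixed outer block index (bigN), the inner (k, littleN) traversal is now contiguous after flattening, instead of striding through B with stride N.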
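
A further sketch shows how the fully optimized schedule at the end of the patch (array packing, write cache, loop unrolling, parallelization) could be checked against NumPy and timed. It reuses the names introduced by the patch; the sizes, the generic "llvm" target, and the time_evaluator settings are assumptions, and measured times will vary by machine.

import numpy as np
import tvm
from tvm import te

M = K = N = 1024  # assumed sizes
bn, kfactor = 32, 4
target = "llvm"  # the tutorial may use a more specific target string
dev = tvm.cpu(0)

k = te.reduce_axis((0, K), "k")
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
packedB = te.compute(
    (N // bn, K, bn), lambda bigN, k, littleN: B[k, bigN * bn + littleN], name="packedB"
)
C = te.compute(
    (M, N),
    lambda m, n: te.sum(A[m, k] * packedB[n // bn, k, tvm.tir.indexmod(n, bn)], axis=k),
    name="C",
)

# Final schedule: tiling, write cache, reduction split, vectorize, unroll, parallel.
s = te.create_schedule(C.op)
CC = s.cache_write(C, "global")
mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
s[CC].compute_at(s[C], no)
mc, nc = s[CC].op.axis
(kaxis,) = s[CC].op.reduce_axis
ko, ki = s[CC].split(kaxis, factor=kfactor)
s[CC].reorder(ko, mc, ki, nc)
s[CC].vectorize(nc)
s[CC].unroll(ki)
s[C].parallel(mo)
bigN, _, littleN = s[packedB].op.axis
s[packedB].vectorize(littleN)
s[packedB].parallel(bigN)

func = tvm.build(s, [A, B, C], target=target, name="mmult")

# Run once, compare with NumPy, then time the kernel.
a_np = np.random.rand(M, K).astype("float32")
b_np = np.random.rand(K, N).astype("float32")
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
func(a, b, c)
np.testing.assert_allclose(c.numpy(), a_np @ b_np, rtol=1e-5)

evaluator = func.time_evaluator(func.entry_name, dev, number=10)
print("optimized: %f s" % evaluator(a, b, c).mean)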