refactor optimize GEMM on CPU tutorial #8825

Merged Aug 31, 2021 · 8 commits
Changes from 4 commits
121 changes: 67 additions & 54 deletions tutorials/optimize/opt_gemm.py
@@ -101,7 +101,7 @@
k = te.reduce_axis((0, K), "k")
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
-C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
+C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name="C")

# Default schedule
s = te.create_schedule(C.op)
@@ -130,15 +130,16 @@
# fill 32 * 32 * sizeof(float) which is 4KB in the cache whose total size is 32KB (L1 data cache)

bn = 32
+kfactor = 4
s = te.create_schedule(C.op)

# Blocking by loop tiling
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(k,) = s[C].op.reduce_axis
-ko, ki = s[C].split(k, factor=4)
+ko, ki = s[C].split(k, factor=kfactor)

# Hoist reduction domain outside the blocking loop
-s[C].reorder(xo, yo, ko, ki, xi, yi)
+s[C].reorder(mo, no, ko, ki, mi, ni)

func = tvm.build(s, [A, B, C], target=target, name="mmult")
assert func
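###################################################################################################
# For intuition, the tiled, split and reordered schedule above roughly corresponds to the plain
# NumPy loop nest below. This is only a sketch: the helper and array names are illustrative and it
# assumes the matrix sizes are multiples of bn and kfactor.

import numpy as np


def _blocked_matmul_sketch(a, b, block=bn, kblock=kfactor):
    # Illustrative helper, not part of the schedule: a blocked matmul in plain Python.
    rows, depth = a.shape
    _, cols = b.shape
    out = np.zeros((rows, cols), dtype=a.dtype)
    for mo in range(rows // block):
        for no in range(cols // block):
            for ko in range(depth // kblock):
                for ki in range(kblock):
                    for mi in range(block):
                        for ni in range(block):
                            row = mo * block + mi
                            col = no * block + ni
                            red = ko * kblock + ki
                            out[row, col] += a[row, red] * b[red, col]
    return out


_a = np.random.rand(64, 64).astype("float32")
_b = np.random.rand(64, 64).astype("float32")
np.testing.assert_allclose(_blocked_matmul_sketch(_a, _b), _a @ _b, rtol=1e-4)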
@@ -162,19 +163,20 @@
# -------------
# Another important trick is vectorization. When the memory access pattern is uniform,
# the compiler can detect this pattern and pass the continuous memory to the vector processor. In TVM,
-# we can use `vectorize` interface to hint the compiler this pattern, so that we can accelerate it vastly.
+# we can use the `vectorize` interface to hint to the compiler about this pattern, so that we can
+# accelerate it vastly.
#
# In this tutorial, we chose to vectorize the inner loop row data since it is cache friendly.
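#
# Concretely, vectorizing the inner (ni) loop turns bn scalar updates into a single operation over
# a contiguous bn-wide slice. A tiny NumPy sketch of the idea follows; the underscore names are
# illustrative only and are not part of the schedule.

import numpy as np

_a_val = np.float32(2.0)  # plays the role of a single A[m, k] element
_b_row = np.random.rand(bn).astype("float32")  # a contiguous bn-wide row slice of B
_c_row = np.zeros(bn, dtype="float32")

# Scalar form of the innermost loop.
for _ni in range(bn):
    _c_row[_ni] += _a_val * _b_row[_ni]

# Vector form: one bn-wide operation over the contiguous slice, which is what `vectorize`
# hints the compiler to generate.
np.testing.assert_allclose(_c_row, _a_val * _b_row, rtol=1e-6)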

s = te.create_schedule(C.op)
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(k,) = s[C].op.reduce_axis
-ko, ki = s[C].split(k, factor=4)
+ko, ki = s[C].split(k, factor=kfactor)

-s[C].reorder(xo, yo, ko, ki, xi, yi)
+s[C].reorder(mo, no, ko, ki, mi, ni)

# Vectorization
-s[C].vectorize(yi)
+s[C].vectorize(ni)

func = tvm.build(s, [A, B, C], target=target, name="mmult")
assert func
@@ -194,20 +196,19 @@
###################################################################################################
# Loop Permutation
# ----------------
-# If we look at the above IR, we can see the inner loop row data is vectorized and
-# B is transformed into PackedB. The traversal of PackedB is sequential now.
-# So we will look at the access pattern of A. In current schedule, A is accessed column by column
-# which is not cache friendly. If we change the nested loop order of ki and inner axes xi,
+# If we look at the above IR, we can see the inner loop row data is vectorized for both B and C.
+# Next we will look at the access pattern of A. In the current schedule, A is accessed column by column,
+# which is not cache friendly. If we change the nested loop order of ki and inner axes mi,
# the access pattern for A matrix is more cache friendly.
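#
# A small sketch makes the difference concrete: below are the flat, row-major offsets of A touched
# by the two innermost loop orders within one block. The underscore names are illustrative only;
# for example, with K = 1024 the first order jumps by a whole row of A on every step, while the
# second walks contiguous elements of a row.

_before = [mi * K + ki for ki in range(kfactor) for mi in range(bn)]  # ki outer, mi inner
_after = [mi * K + ki for mi in range(bn) for ki in range(kfactor)]  # mi outer, ki inner
print(_before[:4])  # e.g. [0, 1024, 2048, 3072]: a stride of K between consecutive accesses
print(_after[:4])  # [0, 1, 2, 3]: contiguous along a row of A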

s = te.create_schedule(C.op)
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(k,) = s[C].op.reduce_axis
-ko, ki = s[C].split(k, factor=4)
+ko, ki = s[C].split(k, factor=kfactor)

# re-ordering
-s[C].reorder(xo, yo, ko, xi, ki, yi)
-s[C].vectorize(yi)
+s[C].reorder(mo, no, ko, mi, ki, ni)
+s[C].vectorize(ni)

func = tvm.build(s, [A, B, C], target=target, name="mmult")
assert func
@@ -227,43 +228,49 @@
###################################################################################################
# Array Packing
# -------------
-# Another important trick is array packing. This trick is to reorder the storage dimension of the
-# array to convert the continuous access pattern on certain dimension to a sequential pattern after
+# Another important trick is array packing. This trick is to reorder the storage dimension of an
+# array to convert the continuous access pattern on its dimensions to a sequential pattern after
# flattening.
#
# .. image:: https://github.com/dmlc/web-data/raw/main/tvm/tutorial/array-packing.png
# :align: center
#
+# NOTE: The figure above is meant for illustration purposes only. Please ignore dimension
+# information ([16][16] and [16/4][16][4]) which does not apply to this tutorial.


###################################################################################################
-# Just as it is shown in the figure above, after blocking the computations, we can observe the array
-# access pattern of B (after flattening), which is regular but discontinuous. We expect that after
-# some transformation we can get continuous access pattern. We can reorder a [16][16] array to
-# a [16/4][16][4] array, so that the access pattern of B will be sequential when grabing
-# the corresponding value from the packed array.
#
+# We can use array packing to address the access pattern for B. Observe the array access pattern of
+# B after flattening, which is not sequential as we iterate over the K dimension. We can reorder B
+# with dimensions [K][N] so that it has dimensions [N/bn][K][bn], where bn is the blocking factor
+# and also the vector size for B in the inner loop. This reorder splits N into two dimensions:
+# bigN (N/bn) and littleN (bn). The new dimensions [N/bn][K][bn] match the indexing of B
+# from outer to inner loops (no, ko, ki, ni), resulting in a sequential access pattern for B after
+# flattening.
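#
# As a small NumPy sketch of this layout change (with illustrative underscore names and tiny
# sizes, assuming N and K are multiples of bn), packing B this way gives
# packedB[n // bn, k, n % bn] == B[k, n], and walking the packed array in (bigN, k, littleN)
# order touches memory sequentially after flattening.

import numpy as np

_K, _N, _bn = 8, 8, 4  # tiny illustrative sizes, not the tutorial's K, N, bn
_B = np.arange(_K * _N, dtype="float32").reshape(_K, _N)
_packedB = _B.reshape(_K, _N // _bn, _bn).transpose(1, 0, 2).copy()  # shape (N/bn, K, bn)

assert all(_packedB[n // _bn, k, n % _bn] == _B[k, n] for k in range(_K) for n in range(_N))

# The flat offsets visited in (bigN, k, littleN) order increase monotonically,
# i.e. the access pattern of the packed array is sequential after flattening.
_offsets = [
    np.ravel_multi_index((bigN, k, littleN), _packedB.shape)
    for bigN in range(_N // _bn)
    for k in range(_K)
    for littleN in range(_bn)
]
assert _offsets == sorted(_offsets)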


# We have to re-write the algorithm slightly.
-packedB = te.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name="packedB")
+packedB = te.compute(
+    (N / bn, K, bn), lambda bigN, k, littleN: B[k, bigN * bn + littleN], name="packedB"
+)
C = te.compute(
    (M, N),
-    lambda x, y: te.sum(A[x, k] * packedB[y // bn, k, tvm.tir.indexmod(y, bn)], axis=k),
+    lambda m, n: te.sum(A[m, k] * packedB[n // bn, k, tvm.tir.indexmod(n, bn)], axis=k),
    name="C",
)

s = te.create_schedule(C.op)

-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
(k,) = s[C].op.reduce_axis
-ko, ki = s[C].split(k, factor=4)
+ko, ki = s[C].split(k, factor=kfactor)

-s[C].reorder(xo, yo, ko, xi, ki, yi)
-s[C].vectorize(yi)
+s[C].reorder(mo, no, ko, mi, ki, ni)
+s[C].vectorize(ni)

-x, y, z = s[packedB].op.axis
-s[packedB].vectorize(z)
-s[packedB].parallel(x)
+bigN, k, littleN = s[packedB].op.axis
+s[packedB].vectorize(littleN)
+s[packedB].parallel(bigN)

func = tvm.build(s, [A, B, C], target=target, name="mmult")
assert func
@@ -293,23 +300,26 @@
# Allocate write cache
CC = s.cache_write(C, "global")

-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)

-# Write cache is computed at yo
-s[CC].compute_at(s[C], yo)
+# Write cache is computed at no
+s[CC].compute_at(s[C], no)

# New inner axes
-xc, yc = s[CC].op.axis
+mc, nc = s[CC].op.axis

(k,) = s[CC].op.reduce_axis
-ko, ki = s[CC].split(k, factor=4)
-s[CC].reorder(ko, xc, ki, yc)
+ko, ki = s[CC].split(k, factor=kfactor)
+s[CC].reorder(ko, mc, ki, nc)
+s[CC].vectorize(nc)

+# unroll kfactor loops
+# this is a separate optimization not discussed in this tutorial
s[CC].unroll(ki)
-s[CC].vectorize(yc)

-x, y, z = s[packedB].op.axis
-s[packedB].vectorize(z)
-s[packedB].parallel(x)
+bigN, k, littleN = s[packedB].op.axis
+s[packedB].vectorize(littleN)
+s[packedB].parallel(bigN)

func = tvm.build(s, [A, B, C], target=target, name="mmult")
assert func
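###################################################################################################
# In effect, cache_write plus compute_at(no) gives every (mo, no) block of C its own small
# accumulation buffer that is written back once per block. A rough NumPy sketch of that idea
# follows; the helper below is illustrative only and assumes the matrix sizes are multiples of bn.

import numpy as np


def _blocked_matmul_write_cache(a, b, block=bn):
    # Illustrative helper, not part of the schedule.
    rows, depth = a.shape
    _, cols = b.shape
    out = np.empty((rows, cols), dtype=a.dtype)
    for mo in range(rows // block):
        for no in range(cols // block):
            # Accumulate into a small, cache-friendly block buffer (the role of CC).
            cc = np.zeros((block, block), dtype=a.dtype)
            for kk in range(depth):
                a_col = a[mo * block : (mo + 1) * block, kk]
                b_row = b[kk, no * block : (no + 1) * block]
                cc += np.outer(a_col, b_row)
            # Write the finished block back to C in one go.
            out[mo * block : (mo + 1) * block, no * block : (no + 1) * block] = cc
    return out


_a = np.random.rand(64, 64).astype("float32")
_b = np.random.rand(64, 64).astype("float32")
np.testing.assert_allclose(_blocked_matmul_write_cache(_a, _b), _a @ _b, rtol=1e-4)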
@@ -335,24 +345,27 @@

CC = s.cache_write(C, "global")

-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)

-s[CC].compute_at(s[C], yo)
+s[CC].compute_at(s[C], no)

-xc, yc = s[CC].op.axis
+mc, nc = s[CC].op.axis

(k,) = s[CC].op.reduce_axis
-ko, ki = s[CC].split(k, factor=4)
-s[CC].reorder(ko, xc, ki, yc)
+ko, ki = s[CC].split(k, factor=kfactor)
+s[CC].reorder(ko, mc, ki, nc)
+s[CC].vectorize(nc)

+# unroll kfactor loops
+# this is a separate optimization not discussed in this tutorial
s[CC].unroll(ki)
-s[CC].vectorize(yc)

# parallel
-s[C].parallel(xo)
+s[C].parallel(mo)

-x, y, z = s[packedB].op.axis
-s[packedB].vectorize(z)
-s[packedB].parallel(x)
+bigN, k, littleN = s[packedB].op.axis
+s[packedB].vectorize(littleN)
+s[packedB].parallel(bigN)

func = tvm.build(s, [A, B, C], target=target, name="mmult")
assert func