Skip to content

Commit

Permalink
Add a cute tool to plot blocked, dotOperand, and mfma layout (#407)
Browse files Browse the repository at this point in the history
* Add commands to plot blocked, dotOperand, and mfma layout

* Add commands to plot LDS layout and wmma instruction layout
  • Loading branch information
zhanglx13 authored Nov 29, 2023
1 parent d9219e0 commit 670ae80
Show file tree
Hide file tree
Showing 2 changed files with 1,224 additions and 0 deletions.
350 changes: 350 additions & 0 deletions scripts/amd/plot_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,350 @@
import argparse
import sys
import yaml
import os
import glob
import subprocess


def draw_preamble_cmd():
    """Return the LaTeX preamble written at the top of the generated file.

    The preamble selects the ``standalone`` document class, loads the
    TikZ libraries listed below, and defines an expandable ``\\bitwiseXor``
    command in expl3 syntax.  The text is a constant; nothing is
    interpolated into it.
    """
    # Raw string: every backslash below is a literal LaTeX backslash.
    preamble = r'''\documentclass[tikz, border=1mm, dvipsnames]{standalone}
\usepackage{ifthen}
\usepackage{tikz}
\usetikzlibrary{arrows.meta,arrows}
\usetikzlibrary{intersections}
\usetikzlibrary{calc, quotes}
\usetikzlibrary{patterns}
\usepackage{xparse}
\ExplSyntaxOn
\NewExpandableDocumentCommand{\bitwiseXor}{mm}
{
\recuenco_bitwise_xor:nn { #1 } { #2 }
}
\cs_new:Nn \recuenco_bitwise_xor:nn
{
\int_from_bin:e
{
\__recuenco_bitwise_xor:ee { \int_to_bin:n { #1 } } { \int_to_bin:n { #2 } }
}
}
\cs_generate_variant:Nn \int_from_bin:n { e }
\cs_new:Nn \__recuenco_bitwise_xor:nn
{
\__recuenco_bitwise_xor_binary:ee
{
\prg_replicate:nn
{
\int_max:nn { \tl_count:n { #1 } } { \tl_count:n { #2 } } - \tl_count:n { #1 }
}
{ 0 }
#1
}
{
\prg_replicate:nn
{
\int_max:nn { \tl_count:n { #1 } } { \tl_count:n { #2 } } - \tl_count:n { #2 }
}
{ 0 }
#2
}
}
\cs_generate_variant:Nn \__recuenco_bitwise_xor:nn { ee }
\cs_new:Nn \__recuenco_bitwise_xor_binary:nn
{
\__recuenco_bitwise_xor_binary:w #1;#2;
}
\cs_generate_variant:Nn \__recuenco_bitwise_xor_binary:nn { ee }
\cs_new:Npn \__recuenco_bitwise_xor_binary:w #1#2;#3#4;
{
\int_abs:n { #1-#3 }
\tl_if_empty:nF { #2 } { \__recuenco_bitwise_xor_binary:w #2;#4; }
}
\ExplSyntaxOff'''
    return preamble


def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack):
    """Return the LaTeX document body for the ``dot`` plot mode.

    The returned text opens a ``tikzpicture``, invokes ``\\drawDot`` for
    the whole operation and then ``\\drawMFMAInstr`` for the zoomed-in
    instruction view placed to its right.  Both macros are expected to be
    defined by the TikZ code that precedes this text in the generated file
    (they are not defined in this function).

    Args:
        M, N, K: tensor dimensions interpolated into the ``\\drawDot`` call;
            ``N`` and ``kpack`` also set the horizontal offset of the
            zoomed-in view.
        mfmaNonKDim: instruction dimension passed to both macros.
        warpsPerCTA: indexable with at least two entries; entries 0 and 1
            are passed to ``\\drawDot``.
        trans: value emitted as ``\\mfmaTrans``; the generated LaTeX swaps
            the operand colours when it is not 0.
        kpack: vector length passed to both macros.

    Returns:
        The LaTeX source from ``\\begin{document}`` to ``\\end{document}``
        (no trailing newline).
    """
    # Every LaTeX backslash is written as ``\\``.  The original spelled one
    # occurrence as ``\elem``, which only worked because Python keeps
    # unknown escapes verbatim (with a warning that is slated to become an
    # error).
    return f'''\\begin{{document}}
\\begin{{tikzpicture}}
\\def\\scale{{1}}
\\def\\elem{{0.04}}
\\coordinate (C TL) at (0,0);
\\def\\opColorAL{{magenta}}
\\def\\opColorAR{{cyan}}
\\def\\opColorBL{{Maroon}}
\\def\\opColorBR{{BlueGreen}}
\\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}{{{kpack}}}
\\coordinate (C TL) at ($(C TL)+({N}*\\elem+32*\\elem, 0)$);
\\def\\mfmaTrans{{{trans}}}
\\ifthenelse{{\\mfmaTrans=0}}{{
\\def\\opColorAL{{magenta}}
\\def\\opColorAR{{cyan}}
\\def\\opColorBL{{Maroon}}
\\def\\opColorBR{{BlueGreen}}
}}{{
\\def\\opColorBL{{magenta}}
\\def\\opColorBR{{cyan}}
\\def\\opColorAL{{Maroon}}
\\def\\opColorAR{{BlueGreen}}
}}
%% Draw zoomed in view of mfma
\\def\\elem{{.16}}
\\pgfmathsetmacro{{\\gap}}{{\\elem*5}}
\\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}}
\\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+2*{kpack}*\\elem, 0)$);
\\drawMFMAInstr{{{mfmaNonKDim}}}{{{kpack}}}{{\\mfmaTrans}}
\\end{{tikzpicture}}
\\end{{document}}'''


def draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp, warpsPerCTA,
                            order):
    """Return the LaTeX document body for the ``blocked`` plot mode.

    The body consists of a single ``\\drawBlockedTensor`` call wrapped in a
    ``tikzpicture``.  The macro receives, in this order: ``M``, ``K``,
    ``sizePerThread[0]``, ``sizePerThread[1]``, ``threadsPerWarp[0]``,
    ``warpsPerCTA[0]``, ``warpsPerCTA[1]`` and ``order[0]``.

    Returns:
        The LaTeX source from ``\\begin{document}`` to ``\\end{document}``
        (no trailing newline).
    """
    macro_args = (
        M,
        K,
        sizePerThread[0],
        sizePerThread[1],
        threadsPerWarp[0],
        warpsPerCTA[0],
        warpsPerCTA[1],
        order[0],
    )
    # Each argument becomes one brace group: {arg}
    draw_call = "\\drawBlockedTensor" + "".join(
        f"{{{value}}}" for value in macro_args)

    tex_lines = [
        r"\begin{document}",
        r"\begin{tikzpicture}",
        r"\def\scale{1}",
        r"\def\elem{0.06}",
        r"\coordinate (TL) at (0,0);",
        draw_call,
        r"\end{tikzpicture}",
        r"\end{document}",
    ]
    return "\n".join(tex_lines)


def draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess, sizePerThread,
                        threadsPerWarp):
    """Return the LaTeX document body for the ``lds`` plot mode.

    The layout and access names are translated to the integer codes the
    generated LaTeX uses:

    * ``ldsLayout``: ``'swizzle'`` -> 1, ``'padding'`` -> 2, else 0
    * ``ldsAccess``: ``'read'`` -> 1, ``'write'`` -> 2, else 0

    The remaining arguments are emitted as ``\\def`` macros ahead of the
    ``\\drawTensorLayoutGlobalMem`` and ``\\drawLDSLayoutTritonSwizzling``
    calls.

    Returns:
        The LaTeX source from ``\\begin{document}`` to ``\\end{document}``
        (no trailing newline).
    """
    swizzle_code = (1 if ldsLayout == 'swizzle' else
                    2 if ldsLayout == 'padding' else 0)
    access_code = (1 if ldsAccess == 'read' else
                   2 if ldsAccess == 'write' else 0)

    # (macro name, value) pairs, in the order they appear in the output.
    macro_defs = [
        ("scale", "1"),
        ("M", M),
        ("K", K),
        ("vec", kpack),
        ("hasSwizzle", swizzle_code),
        ("accessMode", access_code),
        ("sizePerThreadK", sizePerThread[1]),
        ("sizePerThreadM", sizePerThread[0]),
        ("threadsPerWarpK", threadsPerWarp[1]),
        ("elem", "0.18"),
    ]

    tex_lines = [r"\begin{document}", r"\begin{tikzpicture}"]
    tex_lines.extend(f"\\def\\{macro}{{{value}}}"
                     for macro, value in macro_defs)
    tex_lines.extend([
        r"\coordinate (TL) at (0,0);",
        r"\drawTensorLayoutGlobalMem",
        r"\coordinate (TL) at ($(TL)+(0, -24*\elem-10*\elem)$);",
        r"\drawLDSLayoutTritonSwizzling{\hasSwizzle}{\accessMode}",
        r"\end{tikzpicture}",
        r"\end{document}",
    ])
    return "\n".join(tex_lines)


def draw_wmma_instr_cmd(waveSize):
    """Return the LaTeX document body for the ``wmma`` plot mode.

    ``waveSize`` selects the first argument of ``\\drawWMMAInstr``: 0 when
    it equals 32, otherwise 1.  The second argument is always 1.

    Returns:
        The LaTeX source from ``\\begin{document}`` to ``\\end{document}``
        (no trailing newline).
    """
    if waveSize == 32:
        mode_code = 0
    else:
        mode_code = 1

    tex_lines = (
        r"\begin{document}",
        r"\begin{tikzpicture}",
        r"\def\scale{1}",
        r"\coordinate (C TL) at (0,0);",
        r"\def\elem{0.25}",
        f"\\drawWMMAInstr{{{mode_code}}}{{1}}",
        r"\end{tikzpicture}",
        r"\end{document}",
    )
    return "\n".join(tex_lines)


def run_bash_command(commandstring):
    """Run ``commandstring`` through ``/bin/bash`` and return its stdout.

    Returns:
        The captured standard output split into lines, as a list of
        ``bytes`` objects.

    Raises:
        subprocess.CalledProcessError: if the command exits with a
            non-zero status.
    """
    # check_output captures stdout and raises on a non-zero exit status.
    captured = subprocess.check_output(commandstring,
                                       shell=True,
                                       executable='/bin/bash')
    return captured.splitlines()


def parse_args(argv=None):
    """Parse the command-line options of the layout plotting script.

    Args:
        argv: optional list of argument strings to parse.  When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what
            the script entry point relies on.

    Returns:
        The ``argparse.Namespace`` with attributes ``shape``, ``plot``,
        ``nonKDim``, ``sizePerThread``, ``threadsPerWarp``, ``warpsPerCTA``,
        ``order``, ``kpack``, ``lds_layout``, ``lds_access``, ``wave_size``,
        ``o``, ``mfmaTrans`` and ``keep``.
    """
    parser = argparse.ArgumentParser(
        prog="Draw triton layouts",
        allow_abbrev=False,
    )
    ## tensor shapes
    # nargs=3 means the three values are separate tokens ("-shape 32 128 64"),
    # so the help text must not advertise a comma-separated form.
    parser.add_argument("-shape",
                        type=int,
                        nargs=3,
                        default=(32, 128, 64),
                        help='Tensor shape given as three integers: M N K')
    parser.add_argument("-plot",
                        type=str,
                        default="blocked",
                        choices=['blocked', 'dot', 'wmma', 'lds'],
                        help='choose plot mode')
    parser.add_argument(
        "-nonKDim",
        type=int,
        default=32,
        choices=[32],
        help='mfma instruction dim, only 32 is supported for now')
    ## blocked layout parameters
    parser.add_argument("-sizePerThread", type=int, nargs=2, default=(1, 4))
    parser.add_argument("-threadsPerWarp", type=int, nargs=2, default=(16, 4))
    parser.add_argument("-warpsPerCTA", type=int, nargs=2, default=(1, 4))
    parser.add_argument("-order", type=int, nargs=2, default=(1, 0))
    ## LDS access parameters
    parser.add_argument("-kpack",
                        type=int,
                        default=4,
                        choices=[4, 8],
                        help='vector length during LDS load, same as vec')
    parser.add_argument("-lds_layout",
                        type=str,
                        default="none",
                        choices=['swizzle', 'padding', 'none'],
                        help='choose the LDS data layout')
    parser.add_argument("-lds_access",
                        type=str,
                        default="none",
                        choices=['read', 'write', 'none'],
                        help='choose LDS access mode')
    ## wmma instruction layout parameter
    parser.add_argument("-wave_size",
                        type=int,
                        default=32,
                        choices=[32, 64],
                        help='choose the wmma instruction mode')

    parser.add_argument("-o",
                        type=str,
                        default="myplot",
                        help='output pdf file name (without suffix)')
    parser.add_argument("-mfmaTrans",
                        action='store_true',
                        default=False,
                        help='If set, then use mfma.trans layout')
    parser.add_argument("--keep",
                        action='store_true',
                        default=False,
                        help='If set, keep the generated .tex file')

    return parser.parse_args(argv)


def main():
    """Generate the requested layout plot as a PDF.

    Steps:
      1. Parse the command line and print a summary of the chosen mode.
      2. For the ``blocked`` and ``dot`` modes, check that the tensor shape
         is a non-zero multiple of the CTA tile shape.
      3. Write ``myplot.tex`` in the current directory: preamble, then the
         contents of ``tikzplot.tex`` (read from the current directory),
         then the document body for the selected mode.
      4. Run ``pdflatex`` with the ``-o`` value as job name and remove the
         ``.aux``/``.log`` files, ``myplot.tex`` (unless ``--keep``) and
         the ``./auto`` directory.

    Raises:
        AssertionError: if a tensor dimension is incompatible with the
            CTA shape.
        FileNotFoundError: if ``tikzplot.tex`` is not in the current
            directory.
        subprocess.CalledProcessError: if ``pdflatex`` fails; the
            intermediate files are left in place for inspection.
    """
    # Local import: only needed here, to quote the job name for the shell.
    import shlex

    args = parse_args()

    M, N, K = args.shape[0], args.shape[1], args.shape[2]
    plot_mode = args.plot
    mfmaNonKDim = args.nonKDim
    kpack = args.kpack
    trans = 1 if args.mfmaTrans else 0
    ofilename = args.o
    keepSrc = args.keep

    ldsLayout = args.lds_layout
    ldsAccess = args.lds_access

    waveSize = args.wave_size

    sizePerThread = args.sizePerThread
    threadsPerWarp = args.threadsPerWarp
    warpsPerCTA = args.warpsPerCTA
    order = args.order

    CTAShape = []
    if plot_mode == 'blocked':
        print(f"Plotting tensor M={M},K={K} with blocked layout:")
        print(f"sizePerThread={sizePerThread}", end=" ")
        print(f"threadsPerWarp={threadsPerWarp}", end=" ")
        print(f"warpsPerCTA={warpsPerCTA}", end=" ")
        print(f"order={order}", end=" ")
        CTAShape.append(sizePerThread[0] * threadsPerWarp[0] * warpsPerCTA[0])
        CTAShape.append(sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1])

    if plot_mode == 'dot':
        mfma_inst_str = "mfma_32x32x8f16" if mfmaNonKDim == 32 else "mfma_16x16x16f16"
        mfma_trans_str = ".trans" if trans else ""
        print(f"Plotting dot operation with shapes M={M},N={N},K={K}")
        print("MFMA: " + mfma_inst_str + mfma_trans_str, end=" ")
        print(f"warpsPerCTA={warpsPerCTA}", end=" ")
        CTAShape.append(32 * warpsPerCTA[0])
        CTAShape.append(32 * warpsPerCTA[1])

    # The checks below raise explicitly instead of using ``assert`` so that
    # they still run under ``python -O``; the exception type is unchanged.
    if plot_mode == 'blocked' or plot_mode == 'dot':
        print(f"CTAShape={CTAShape}")
        if not (M != 0 and CTAShape[0] <= M and M % CTAShape[0] == 0):
            raise AssertionError("bad tensor dimension M")

    if plot_mode == 'blocked':
        if not (K != 0 and CTAShape[1] <= K and K % CTAShape[1] == 0):
            raise AssertionError("bad tensor dimension K")

    if plot_mode == 'dot':
        if not (N != 0 and CTAShape[1] <= N and N % CTAShape[1] == 0):
            raise AssertionError("bad tensor dimension N")
        if not (K != 0 and K % (2 * kpack) == 0):
            raise AssertionError("bad tensor dimension K")

    if plot_mode == 'lds':
        print(f"Plotting LDS access for tensor M={M},K={K} with vec={kpack}")
        if ldsAccess == 'write':
            print(
                f"sizePerThread={sizePerThread}, threadsPerWarp={threadsPerWarp}"
            )

    # Read the macro file before creating myplot.tex so that a missing
    # tikzplot.tex does not leave an empty myplot.tex behind.
    with open("tikzplot.tex") as file:
        tikz_code = file.read()

    # Build only the document body that is actually going to be written.
    if plot_mode == 'blocked':
        document_str = draw_blocked_layout_cmd(M, K, sizePerThread,
                                               threadsPerWarp, warpsPerCTA,
                                               order)
    elif plot_mode == 'dot':
        document_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA,
                                           trans, kpack)
    elif plot_mode == 'lds':
        document_str = draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess,
                                           sizePerThread, threadsPerWarp)
    elif plot_mode == 'wmma':
        document_str = draw_wmma_instr_cmd(waveSize)
    else:
        document_str = ""

    with open("myplot.tex", 'w') as f_plot:
        f_plot.write(draw_preamble_cmd() + "\n")
        f_plot.write(tikz_code)
        f_plot.write(document_str)

    # The file is closed (and therefore flushed) before pdflatex reads it.
    # The job name comes from the command line, so quote it for the shell.
    try:
        run_bash_command(
            f"pdflatex -jobname {shlex.quote(ofilename)} myplot.tex")
    except subprocess.CalledProcessError as err:
        # run_bash_command captures stdout; forward it so the failure can
        # be diagnosed, then let the error propagate as before.
        if err.stdout:
            sys.stderr.write(err.stdout.decode(errors="replace"))
        raise
    print(f"plot saved in {ofilename}.pdf")

    ## Remove aux files
    os.remove(f"{ofilename}.aux")
    os.remove(f"{ofilename}.log")
    if not keepSrc:
        os.remove("myplot.tex")
        run_bash_command("rm -rf ./auto")


# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
Loading

0 comments on commit 670ae80

Please sign in to comment.