forked from triton-lang/triton
-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a cute tool to plot blocked, dotOperand, and mfma layout (#407)
* Add commands to plot blocked, dotOperand, and mfma layouts
* Add commands to plot the LDS layout and the wmma instruction layout
- Loading branch information
Showing
2 changed files
with
1,224 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,350 @@ | ||
import argparse
import glob
import os
import shlex
import subprocess
import sys

import yaml
|
||
|
||
def draw_preamble_cmd():
    """Return the LaTeX preamble placed at the top of the generated .tex file.

    The preamble selects the ``standalone`` document class with TikZ, loads
    the TikZ libraries listed below, and defines an expandable ``bitwiseXor``
    command (two mandatory arguments) written in expl3 syntax.  The XOR is
    computed by converting both operands to binary, left-padding the shorter
    one with zeros, and emitting ``|a - b|`` digit by digit before converting
    the result back from binary.

    Returns:
        str: The preamble text, without a trailing newline.
    """
    # NOTE: this is a plain (non-f) string, so ``#1`` and single braces are
    # literal; only the backslashes are escaped.
    return '''\\documentclass[tikz, border=1mm, dvipsnames]{standalone}
\\usepackage{ifthen}
\\usepackage{tikz}
\\usetikzlibrary{arrows.meta,arrows}
\\usetikzlibrary{intersections}
\\usetikzlibrary{calc, quotes}
\\usetikzlibrary{patterns}
\\usepackage{xparse}
\\ExplSyntaxOn
\\NewExpandableDocumentCommand{\\bitwiseXor}{mm}
{
\\recuenco_bitwise_xor:nn { #1 } { #2 }
}
\\cs_new:Nn \\recuenco_bitwise_xor:nn
{
\\int_from_bin:e
{
\\__recuenco_bitwise_xor:ee { \\int_to_bin:n { #1 } } { \\int_to_bin:n { #2 } }
}
}
\\cs_generate_variant:Nn \\int_from_bin:n { e }
\\cs_new:Nn \\__recuenco_bitwise_xor:nn
{
\\__recuenco_bitwise_xor_binary:ee
{
\\prg_replicate:nn
{
\\int_max:nn { \\tl_count:n { #1 } } { \\tl_count:n { #2 } } - \\tl_count:n { #1 }
}
{ 0 }
#1
}
{
\\prg_replicate:nn
{
\\int_max:nn { \\tl_count:n { #1 } } { \\tl_count:n { #2 } } - \\tl_count:n { #2 }
}
{ 0 }
#2
}
}
\\cs_generate_variant:Nn \\__recuenco_bitwise_xor:nn { ee }
\\cs_new:Nn \\__recuenco_bitwise_xor_binary:nn
{
\\__recuenco_bitwise_xor_binary:w #1;#2;
}
\\cs_generate_variant:Nn \\__recuenco_bitwise_xor_binary:nn { ee }
\\cs_new:Npn \\__recuenco_bitwise_xor_binary:w #1#2;#3#4;
{
\\int_abs:n { #1-#3 }
\\tl_if_empty:nF { #2 } { \\__recuenco_bitwise_xor_binary:w #2;#4; }
}
\\ExplSyntaxOff'''
|
||
|
||
def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack):
    """Return the LaTeX document body that plots a dot (mfma) layout.

    The body first invokes the ``drawDot`` macro for the whole tensor, then
    moves the ``C TL`` coordinate to the right and invokes ``drawMFMAInstr``
    for a zoomed-in view of a single mfma instruction.  The operand colours
    of A and B are swapped when ``trans`` is non-zero.  Both macros are not
    defined here; presumably they come from the TikZ code that the caller
    writes ahead of this body.

    Args:
        M, N, K: Tensor dimensions substituted into the macro call.
        mfmaNonKDim: Non-K dimension of the mfma instruction.
        warpsPerCTA: Indexable with two entries; ``[0]`` and ``[1]`` are used.
        trans: 0 for the regular layout, anything else for the mfma.trans
            colouring.
        kpack: Vector length; also widens the horizontal offset of the
            zoomed-in view.

    Returns:
        str: Text from ``\\begin{document}`` to ``\\end{document}``.
    """
    # Every LaTeX backslash is doubled: a lone ``\e`` is an invalid Python
    # escape sequence and triggers a SyntaxWarning on recent interpreters.
    return f'''\\begin{{document}}
\\begin{{tikzpicture}}
\\def\\scale{{1}}
\\def\\elem{{0.04}}
\\coordinate (C TL) at (0,0);
\\def\\opColorAL{{magenta}}
\\def\\opColorAR{{cyan}}
\\def\\opColorBL{{Maroon}}
\\def\\opColorBR{{BlueGreen}}
\\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}{{{kpack}}}
\\coordinate (C TL) at ($(C TL)+({N}*\\elem+32*\\elem, 0)$);
\\def\\mfmaTrans{{{trans}}}
\\ifthenelse{{\\mfmaTrans=0}}{{
\\def\\opColorAL{{magenta}}
\\def\\opColorAR{{cyan}}
\\def\\opColorBL{{Maroon}}
\\def\\opColorBR{{BlueGreen}}
}}{{
\\def\\opColorBL{{magenta}}
\\def\\opColorBR{{cyan}}
\\def\\opColorAL{{Maroon}}
\\def\\opColorAR{{BlueGreen}}
}}
%% Draw zoomed in view of mfma
\\def\\elem{{.16}}
\\pgfmathsetmacro{{\\gap}}{{\\elem*5}}
\\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}}
\\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+2*{kpack}*\\elem, 0)$);
\\drawMFMAInstr{{{mfmaNonKDim}}}{{{kpack}}}{{\\mfmaTrans}}
\\end{{tikzpicture}}
\\end{{document}}'''
|
||
|
||
def draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp, warpsPerCTA,
                            order):
    """Return the LaTeX document body that plots a blocked layout.

    The body consists of a single ``drawBlockedTensor`` macro call.  Only
    ``threadsPerWarp[0]`` and ``order[0]`` are forwarded to the macro; the
    second entries of those two arguments are not used here.

    Args:
        M, K: Tensor dimensions.
        sizePerThread: Indexable with two entries; both are forwarded.
        threadsPerWarp: Indexable; only entry 0 is forwarded.
        warpsPerCTA: Indexable with two entries; both are forwarded.
        order: Indexable; only entry 0 is forwarded.

    Returns:
        str: Text from ``\\begin{document}`` to ``\\end{document}``.
    """
    return f'''\\begin{{document}}
\\begin{{tikzpicture}}
\\def\\scale{{1}}
\\def\\elem{{0.06}}
\\coordinate (TL) at (0,0);
\\drawBlockedTensor{{{M}}}{{{K}}}{{{sizePerThread[0]}}}{{{sizePerThread[1]}}}{{{threadsPerWarp[0]}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{order[0]}}}
\\end{{tikzpicture}}
\\end{{document}}'''
|
||
|
||
def draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess, sizePerThread,
                        threadsPerWarp):
    """Return the LaTeX document body that plots an LDS access pattern.

    The string arguments are translated into the integer codes that the
    TikZ macros receive:

    * ``ldsLayout``: ``'swizzle'`` -> 1, ``'padding'`` -> 2, otherwise 0.
    * ``ldsAccess``: ``'read'`` -> 1, ``'write'`` -> 2, otherwise 0.

    Args:
        M, K: Tensor dimensions.
        kpack: Vector length, exported to LaTeX as the ``vec`` macro.
        ldsLayout: Name of the LDS data layout.
        ldsAccess: Name of the LDS access mode.
        sizePerThread: Indexable with two entries (M, K order).
        threadsPerWarp: Indexable; only entry 1 is used.

    Returns:
        str: Text from ``\\begin{document}`` to ``\\end{document}``.
    """
    # Unknown names fall back to 0, matching the 'none' choice.
    hasSwizzle = {'swizzle': 1, 'padding': 2}.get(ldsLayout, 0)
    accessMode = {'read': 1, 'write': 2}.get(ldsAccess, 0)

    return f'''\\begin{{document}}
\\begin{{tikzpicture}}
\\def\\scale{{1}}
\\def\\M{{{M}}}
\\def\\K{{{K}}}
\\def\\vec{{{kpack}}}
\\def\\hasSwizzle{{{hasSwizzle}}}
\\def\\accessMode{{{accessMode}}}
\\def\\sizePerThreadK{{{sizePerThread[1]}}}
\\def\\sizePerThreadM{{{sizePerThread[0]}}}
\\def\\threadsPerWarpK{{{threadsPerWarp[1]}}}
\\def\\elem{{0.18}}
\\coordinate (TL) at (0,0);
\\drawTensorLayoutGlobalMem
\\coordinate (TL) at ($(TL)+(0, -24*\\elem-10*\\elem)$);
\\drawLDSLayoutTritonSwizzling{{\\hasSwizzle}}{{\\accessMode}}
\\end{{tikzpicture}}
\\end{{document}}'''
|
||
|
||
def draw_wmma_instr_cmd(waveSize):
    """Return the LaTeX document body that plots a wmma instruction layout.

    Args:
        waveSize: Wave size; 32 selects mode 0 of the ``drawWMMAInstr``
            macro, any other value selects mode 1.

    Returns:
        str: Text from ``\\begin{document}`` to ``\\end{document}``.
    """
    wmma_mode = 0 if waveSize == 32 else 1
    return f'''\\begin{{document}}
\\begin{{tikzpicture}}
\\def\\scale{{1}}
\\coordinate (C TL) at (0,0);
\\def\\elem{{0.25}}
\\drawWMMAInstr{{{wmma_mode}}}{{1}}
\\end{{tikzpicture}}
\\end{{document}}'''
|
||
|
||
def run_bash_command(commandstring):
    """Run ``commandstring`` through /bin/bash and return its stdout lines.

    The command is interpreted by the shell, so callers are responsible for
    quoting any untrusted or whitespace-containing fragments.

    Args:
        commandstring: Shell command line to execute.

    Returns:
        list[bytes]: Captured standard output split into lines (undecoded).

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.  The captured stdout is echoed to stderr first, because
            it would otherwise be lost together with the tool's diagnostics.
    """
    try:
        proc = subprocess.run(commandstring,
                              shell=True,
                              check=True,
                              executable='/bin/bash',
                              stdout=subprocess.PIPE)
    except subprocess.CalledProcessError as err:
        if err.stdout:
            sys.stderr.write(err.stdout.decode(errors="replace"))
        raise
    return proc.stdout.splitlines()
|
||
|
||
def parse_args(argv=None):
    """Parse the command-line options of the layout plotting tool.

    Args:
        argv: Optional list of argument strings.  When ``None`` (the
            default) argparse reads the process command line, which is the
            behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace: With attributes ``shape``, ``plot``, ``nonKDim``,
        ``sizePerThread``, ``threadsPerWarp``, ``warpsPerCTA``, ``order``,
        ``kpack``, ``lds_layout``, ``lds_access``, ``wave_size``, ``o``,
        ``mfmaTrans`` and ``keep``.  Multi-value options are lists when given
        on the command line and tuples when left at their defaults.
    """
    parser = argparse.ArgumentParser(
        prog="Draw triton layouts",
        allow_abbrev=False,
    )
    ## tensor shapes
    parser.add_argument("-shape",
                        type=int,
                        nargs=3,
                        default=(32, 128, 64),
                        help='Tensor shape given as three integers: M N K')
    parser.add_argument("-plot",
                        type=str,
                        default="blocked",
                        choices=['blocked', 'dot', 'wmma', 'lds'],
                        help='choose plot mode')
    parser.add_argument(
        "-nonKDim",
        type=int,
        default=32,
        choices=[32],
        help='mfma instruction dim, only 32 is supported for now')
    ## blocked layout parameters
    parser.add_argument("-sizePerThread", type=int, nargs=2, default=(1, 4))
    parser.add_argument("-threadsPerWarp", type=int, nargs=2, default=(16, 4))
    parser.add_argument("-warpsPerCTA", type=int, nargs=2, default=(1, 4))
    parser.add_argument("-order", type=int, nargs=2, default=(1, 0))
    ## LDS access parameters
    parser.add_argument("-kpack",
                        type=int,
                        default=4,
                        choices=[4, 8],
                        help='vector length during LDS load, same as vec')
    parser.add_argument("-lds_layout",
                        type=str,
                        default="none",
                        choices=['swizzle', 'padding', 'none'],
                        help='choose the LDS data layout')
    parser.add_argument("-lds_access",
                        type=str,
                        default="none",
                        choices=['read', 'write', 'none'],
                        help='choose LDS access mode')
    ## wmma instruction layout parameter
    parser.add_argument("-wave_size",
                        type=int,
                        default=32,
                        choices=[32, 64],
                        help='choose the wmma instruction mode')

    parser.add_argument("-o",
                        type=str,
                        default="myplot",
                        help='output pdf file name (without suffix)')
    parser.add_argument("-mfmaTrans",
                        action='store_true',
                        default=False,
                        help='If set, then use mfma.trans layout')
    parser.add_argument("--keep",
                        action='store_true',
                        default=False,
                        help='If set, keep the generated .tex file')

    return parser.parse_args(argv)
|
||
|
||
def _require(condition, message):
    """Raise AssertionError(message) unless ``condition`` is true.

    Unlike an ``assert`` statement this check is not removed by ``python -O``.
    """
    if not condition:
        raise AssertionError(message)


def _tiles_evenly(dim, tile):
    """Return True when a positive ``tile`` fits into ``dim`` a whole number of times."""
    return tile > 0 and dim != 0 and tile <= dim and dim % tile == 0


def main():
    """Generate the requested layout plot as a PDF.

    Steps:
      1. Parse the command line and print a summary of the chosen layout.
      2. Validate the tensor shape against the CTA shape ('blocked' and
         'dot' modes only).
      3. Write ``myplot.tex`` = preamble + contents of ``tikzplot.tex`` +
         the document body for the selected plot mode.  ``tikzplot.tex`` is
         read from the current working directory.
      4. Run pdflatex with the output name given by ``-o`` and remove the
         .aux/.log files; unless ``--keep`` is set, also remove
         ``myplot.tex`` and the ``./auto`` directory.

    Raises:
        AssertionError: If a tensor dimension is incompatible with the layout.
        subprocess.CalledProcessError: If pdflatex fails.
    """
    args = parse_args()

    M, N, K = args.shape
    plot_mode = args.plot
    mfmaNonKDim = args.nonKDim
    kpack = args.kpack
    trans = 1 if args.mfmaTrans else 0
    ofilename = args.o
    keepSrc = args.keep

    ldsLayout = args.lds_layout
    ldsAccess = args.lds_access

    waveSize = args.wave_size

    sizePerThread = args.sizePerThread
    threadsPerWarp = args.threadsPerWarp
    warpsPerCTA = args.warpsPerCTA
    order = args.order

    CTAShape = []
    if plot_mode == 'blocked':
        print(f"Plotting tensor M={M},K={K} with blocked layout:")
        print(f"sizePerThread={sizePerThread}", end=" ")
        print(f"threadsPerWarp={threadsPerWarp}", end=" ")
        print(f"warpsPerCTA={warpsPerCTA}", end=" ")
        print(f"order={order}", end=" ")
        CTAShape.append(sizePerThread[0] * threadsPerWarp[0] * warpsPerCTA[0])
        CTAShape.append(sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1])

    if plot_mode == 'dot':
        mfma_inst_str = "mfma_32x32x8f16" if mfmaNonKDim == 32 else "mfma_16x16x16f16"
        mfma_trans_str = ".trans" if trans else ""
        print(f"Plotting dot operation with shapes M={M},N={N},K={K}")
        print("MFMA: " + mfma_inst_str + mfma_trans_str, end=" ")
        print(f"warpsPerCTA={warpsPerCTA}", end=" ")
        # Derived from the instruction size instead of a literal 32; the two
        # are equal for the only value -nonKDim currently accepts.
        CTAShape.append(mfmaNonKDim * warpsPerCTA[0])
        CTAShape.append(mfmaNonKDim * warpsPerCTA[1])

    if plot_mode == 'blocked' or plot_mode == 'dot':
        print(f"CTAShape={CTAShape}")
        _require(_tiles_evenly(M, CTAShape[0]), "bad tensor dimension M")

    if plot_mode == 'blocked':
        _require(_tiles_evenly(K, CTAShape[1]), "bad tensor dimension K")

    if plot_mode == 'dot':
        _require(_tiles_evenly(N, CTAShape[1]), "bad tensor dimension N")
        _require(K != 0 and K % (2 * kpack) == 0, "bad tensor dimension K")

    if plot_mode == 'lds':
        print(f"Plotting LDS access for tensor M={M},K={K} with vec={kpack}")
        if ldsAccess == 'write':
            print(
                f"sizePerThread={sizePerThread}, threadsPerWarp={threadsPerWarp}"
            )

    # Read the macro definitions before creating the output file, so that a
    # missing tikzplot.tex does not leave an empty myplot.tex behind.
    with open("tikzplot.tex") as file:
        tikz_code = file.read()

    if plot_mode == 'blocked':
        document_str = draw_blocked_layout_cmd(M, K, sizePerThread,
                                               threadsPerWarp, warpsPerCTA,
                                               order)
    elif plot_mode == 'dot':
        document_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA,
                                           trans, kpack)
    elif plot_mode == 'lds':
        document_str = draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess,
                                           sizePerThread, threadsPerWarp)
    elif plot_mode == 'wmma':
        document_str = draw_wmma_instr_cmd(waveSize)
    else:
        # Not reachable through parse_args(), which restricts -plot choices.
        document_str = ""

    with open("myplot.tex", 'w') as f_plot:
        f_plot.write(draw_preamble_cmd() + "\n")
        f_plot.write(tikz_code)
        f_plot.write(document_str)

    # pdflatex runs only after myplot.tex has been closed (and flushed).
    # nonstopmode/halt-on-error make a LaTeX error terminate the run instead
    # of waiting for keyboard input while its output is being captured;
    # shlex.quote keeps output names with spaces or shell metacharacters intact.
    run_bash_command("pdflatex -interaction=nonstopmode -halt-on-error "
                     f"-jobname {shlex.quote(ofilename)} myplot.tex")
    print(f"plot saved in {ofilename}.pdf")

    ## Remove aux files
    os.remove(f"{ofilename}.aux")
    os.remove(f"{ofilename}.log")
    if not keepSrc:
        os.remove("myplot.tex")
        run_bash_command("rm -rf ./auto")
|
||
|
||
# Script entry point: main() returns None, so a successful run exits with 0.
if __name__ == '__main__':
    sys.exit(main())
Oops, something went wrong.