Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement symm algorithms and testing #96

Merged
merged 6 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 44 additions & 19 deletions analytics_tools/graphing/kernels_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def get_loaded_bytes(self):
return (self.M * self.N + self.M * self.N / 4 + self.M) * get_elem_bytes(self.precision)

def get_stored_bytes(self):
if self.TransA == CBLAS_TRANSPOSE.CblasNoTrans:
if self.TransA == CBLAS_TRANSPOSE.CblasNoTrans.value:
return self.M * get_elem_bytes(self.precision)
else:
return ((self.M * self.N) // 4) * get_elem_bytes(self.precision)
Expand Down Expand Up @@ -322,18 +322,18 @@ def __init__(self, bench):

self.precision = run_dict["precision"]

self.M = int(run_dict.get("M", 0))
self.N = int(run_dict.get("N", 0))
self.K = int(run_dict.get("K", 0))
self.Order = int(run_dict.get("Order", 0))
self.Side = int(run_dict.get("Side", 0))
self.Uplo = int(run_dict.get("Uplo", 0))
self.TransA = int(run_dict.get("TransA", 0))
self.TransB = int(run_dict.get("TransB", 0))
self.Trans = int(run_dict.get("Trans", 0))

class gemm(level_3):
def __init__(self, bench):
super().__init__(bench)

run_name = bench["run_name"]
run_dict = run_name_to_dict(run_name)

self.M = int(run_dict["M"])
self.N = int(run_dict["N"])
self.K = int(run_dict["K"])

class gemm(level_3):
def get_size_param(self):
return self.K

Expand All @@ -351,22 +351,47 @@ def get_input_bytes(self):
return (self.M * self.K + self.K * self.N + self.M * self.N) * get_elem_bytes(self.precision)

def get_loaded_bytes(self):
return self.M * self.N * self.K * 2 * get_elem_bytes(self.precision)
return (self.get_flops() + self.M * self.N) * get_elem_bytes(self.precision)

def get_stored_bytes(self):
return self.M * self.N * get_elem_bytes(self.precision)


class syrk(level_3):
def __init__(self, bench):
super().__init__(bench)
class symm(level_3):
def get_size_param(self):
    """Return the dimension used as this kernel's size parameter (``N``)."""
    size_param = self.N
    return size_param

run_name = bench["run_name"]
run_dict = run_name_to_dict(run_name)
def get_cmp_tuple_(self):
    """Return the ``(M, N)`` pair of dimensions for this run."""
    rows, cols = self.M, self.N
    return rows, cols

self.N = int(run_dict["N"])
self.K = int(run_dict["K"])
def get_graph_description(self):
    """Return the graph caption for this run's benchmark type.

    Yields ``"M = N"`` for the ``level_3_eq`` benchmark type and ``None``
    for every other type.
    """
    is_level_3_eq = self.bench_type == BENCH_TYPE.level_3_eq.value
    return "M = N" if is_level_3_eq else None

def get_flops(self):
    """Return the floating-point operation count for one SYMM run.

    The count is ``2 * M * N * M`` when ``Side`` is ``CblasLeft`` and
    ``2 * M * N * N`` otherwise.

    This method is also called by ``get_loaded_bytes``, so it must stay
    free of side effects: the debug ``print`` calls that used to dump
    ``M`` and ``N`` to stdout on every invocation have been removed.
    """
    # The extra factor is the dimension selected by Side: M for the left
    # variant, N for anything else.
    inner_dim = self.M if self.Side == CBLAS_SIDE.CblasLeft.value else self.N
    return 2 * self.M * self.N * inner_dim

def get_input_bytes(self):
    """Return the size in bytes of the SYMM input operands.

    The element count is ``2 * M * N`` plus ``M**2`` when ``Side`` is
    ``CblasLeft`` and ``N**2`` otherwise. It is scaled by the element
    size of ``self.precision`` so the result is in bytes, matching the
    other ``*_bytes`` methods of this class.
    """
    side_dim = self.M if self.Side == CBLAS_SIDE.CblasLeft.value else self.N
    input_elems = 2 * self.M * self.N + side_dim**2
    # The raw element count used to be returned as-is, which under-reported
    # the input size by a factor of the element width.
    return input_elems * get_elem_bytes(self.precision)

def get_loaded_bytes(self):
    """Return the number of bytes loaded, derived from the flop count.

    Computed as ``(get_flops() + M * N)`` elements times the element size
    of ``self.precision``.
    """
    loaded_elems = self.get_flops() + self.M * self.N
    elem_size = get_elem_bytes(self.precision)
    return loaded_elems * elem_size

def get_stored_bytes(self):
    """Return the number of bytes stored: ``M * N`` elements of ``self.precision``."""
    stored_elems = self.M * self.N
    return stored_elems * get_elem_bytes(self.precision)


class syrk(level_3):
def get_size_param(self):
return self.K

Expand Down
10 changes: 8 additions & 2 deletions analytics_tools/graphing/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,16 @@ class BENCH_TYPE(Enum):

# From netlib `cblas.h`
class CBLAS_TRANSPOSE(Enum):
CblasNoTrans = (111,)
CblasTrans = (112,)
CblasNoTrans = 111
CblasTrans = 112
CblasConjTrans = 113


class CBLAS_SIDE(Enum):
    """Integer codes selecting the side of a two-sided BLAS operation."""

    # NOTE(review): values presumably mirror `CBLAS_SIDE` in netlib `cblas.h`,
    # like the neighbouring CBLAS_TRANSPOSE enum - confirm against the header.
    # Callers compare against `.value`, not against the member itself.
    CblasLeft = 141
    CblasRight = 142


__all__ = [
"get_libfree_subkernel_name",
"get_elem_bytes",
Expand All @@ -57,4 +62,5 @@ class CBLAS_TRANSPOSE(Enum):
"level_2_bench_types",
"level_3_bench_types",
"CBLAS_TRANSPOSE",
"CBLAS_SIDE",
]
238 changes: 78 additions & 160 deletions src/level3/symm.py
Original file line number Diff line number Diff line change
@@ -1,167 +1,85 @@
from __future__ import annotations
from os import abort

import sys
import getopt
from exo import proc
from exo.platforms.x86 import *
from exo.platforms.neon import *
from exo import *
from exo.syntax import *
from exo.libs.memories import DRAM_STATIC

from exo.stdlib.scheduling import *

from kernels.gemm_kernels import GEBP_kernel, GEPP_kernel, Microkernel
from format_options import *
from exo.API_cursors import *
from exo.libs.memories import DRAM_STATIC

import exo_blas_config as C
from stdlib import *


class SYMM:
def __init__(
self,
machine: "MachineParameters",
precision: str,
K_blk: int,
M_blk: int,
N_blk: int,
M_r: int,
N_r: int,
do_rename=False,
main=True,
):

self.K_blk = K_blk
self.M_blk = M_blk
self.N_blk = N_blk
self.M_r = M_r
self.N_r = N_r
self.machine = machine
self.precision = precision
self.main = main

self.microkernel = Microkernel(machine, M_r, N_r, K_blk, precision)
self.gebp = GEBP_kernel(self.microkernel, M_blk, N_blk, precision)
self.gepp = GEPP_kernel(self.gebp, precision)

### Base Procedures

@proc
def symm_lower_left_noalpha_nobeta(
M: size, N: size, K: size, A: f32[M, K], B: f32[K, N], C: f32[M, N]
):
# This is a brute force method that just does GEMM. Let's see if this is okay performance wise
for i in seq(0, M):
for j in seq(0, N):
for k in seq(0, K):
C[i, j] += A[i, k] * B[k, j]

symm_lower_left_noalpha_nobeta = self.specialize_symm(
symm_lower_left_noalpha_nobeta, self.precision, ["A", "B", "C"]
)
scheduled_symm = self.schedule_symm_lower_noalpha(
symm_lower_left_noalpha_nobeta
)

self.entry_points = [scheduled_symm]

if do_rename:
for i in range(len(self.entry_points)):
self.entry_points[i] = rename(
self.entry_points[i],
f"{self.entry_points[i].name()}_{N_blk}_{M_blk}_{K_blk}",
)

def schedule_symm_lower_noalpha(self, symm):

symm = divide_loop(symm, "for k in _:_", self.K_blk, ["ko", "ki"], tail="cut")
symm = autofission(symm, symm.find("for ko in _:_ #0").after(), n_lifts=2)
symm = reorder_loops(symm, "j ko")
symm = reorder_loops(symm, "i ko")

symm = divide_loop(
symm, "for j in _:_ #0", self.N_blk, ["jo", "ji"], tail="cut"
)
symm = autofission(symm, symm.find("for jo in _:_ #0").after(), n_lifts=2)
symm = reorder_loops(symm, "i jo")
symm = reorder_loops(symm, "ko jo")

symm = stage_mem(
symm,
"for i in _:_ #0",
f"B[{self.microkernel.K_blk}*ko:{self.microkernel.K_blk}*ko+{self.microkernel.K_blk}, {self.gebp.N_blk}*jo:{self.gebp.N_blk}*jo+{self.gebp.N_blk}]",
"B_strip",
)

symm = replace_all_stmts(symm, self.gepp.gepp_base)
call_c = symm.find(f"gepp_base_{self.gepp.this_id}(_)")
symm = call_eqv(
symm, f"gepp_base_{self.gepp.this_id}(_)", self.gepp.gepp_scheduled
)

symm = call_eqv(
symm,
call_c,
self.gepp.gepp_scheduled,
)
symm = inline(symm, call_c)
symm = inline_window(symm, "C = C[_]")
symm = inline_window(symm, f"A = A[_]")
symm = inline_window(symm, "B = B_strip[_]")
symm = simplify(symm)

while True:
try:
symm = lift_alloc(symm, "B_reg_strip:_")
except:
break
while True:
try:
symm = lift_alloc(symm, "B_strip:_")
except:
break
symm = set_memory(symm, "B_strip:_", DRAM_STATIC)
symm = set_memory(symm, "B_reg_strip:_", DRAM_STATIC)

return simplify(symm)

def specialize_symm(self, symm, precision, args):
prefix = "s" if precision == "f32" else "d"
name = symm.name().replace("exo_", "")
specialized = rename(symm, "exo_" + prefix + name)

for arg in args:
specialized = set_precision(specialized, arg, precision)

if self.main:
specialized = rename(specialized, specialized.name() + "_main")

return specialized


k_blk = [48, 48 * 2, 48 * 4, 48 * 8, 480, 480]
m_blk = [48, 48 * 2, 48 * 4, 48 * 8, 240, 240]
n_blk = [48, 48 * 2, 48 * 4, 48 * 8, 480, 960]
m_reg = 6
n_reg = 16


ssymm_kernels = [
SYMM(C.Machine, "f32", k, m, n, m_reg, n_reg, True, False)
for (k, m, n) in zip(k_blk, m_blk, n_blk)
]

exo_ssymm_lower_left_noalpha_nobeta_48_48_48 = ssymm_kernels[0].entry_points[0]
exo_ssymm_lower_left_noalpha_nobeta_96_96_96 = ssymm_kernels[1].entry_points[0]
exo_ssymm_lower_left_noalpha_nobeta_192_192_192 = ssymm_kernels[2].entry_points[0]
exo_ssymm_lower_left_noalpha_nobeta_384_384_384 = ssymm_kernels[3].entry_points[0]
exo_ssymm_lower_left_noalpha_nobeta_480_240_480 = ssymm_kernels[4].entry_points[0]
exo_ssymm_lower_left_noalpha_nobeta_960_240_480 = ssymm_kernels[5].entry_points[0]

ssymm_kernel_names = []
for s in ssymm_kernels:
ssymm_kernel_names.extend(s.entry_points)

__all__ = [p.name() for p in ssymm_kernel_names]
from codegen_helpers import *
from blaslib import *


# SYMM reference kernel: C += alpha * A * B, with the M x M matrix A applied
# from the left and only its lower triangle (column <= row) ever read.
# NOTE(review): the "rm" in the name presumably means row-major - confirm.
@proc
def symm_rm_ll(M: size, N: size, alpha: R, A: [R][M, M], B: [R][M, N], C: [R][M, N]):
    # Every operand must have stride 1 along dimension 1 (its second index).
    assert stride(A, 1) == 1
    assert stride(B, 1) == 1
    assert stride(C, 1) == 1

    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, M):
                a_val: R
                # Logical A[i, k]. For k <= i the entry is read directly;
                # above the diagonal it is mirrored to A[k, i] so the upper
                # triangle of A is never accessed.
                if k < i + 1:
                    a_val = A[i, k]
                else:
                    a_val = A[k, i]
                C[i, j] += alpha * (a_val * B[k, j])


# SYMM reference kernel: C += alpha * A * B, with the M x M matrix A applied
# from the left and only its upper triangle (row <= column) ever read.
# NOTE(review): the "rm" in the name presumably means row-major - confirm.
@proc
def symm_rm_lu(M: size, N: size, alpha: R, A: [R][M, M], B: [R][M, N], C: [R][M, N]):
    # Every operand must have stride 1 along dimension 1 (its second index).
    assert stride(A, 1) == 1
    assert stride(B, 1) == 1
    assert stride(C, 1) == 1

    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, M):
                a_val: R
                # Logical A[i, k]. For k <= i the entry is mirrored to
                # A[k, i]; otherwise A[i, k] is read directly. Either way the
                # row index never exceeds the column index.
                if k < i + 1:
                    a_val = A[k, i]
                else:
                    a_val = A[i, k]
                C[i, j] += alpha * (a_val * B[k, j])


# SYMM reference kernel: C += alpha * B * A, with the N x N matrix A applied
# from the right and only its lower triangle (column <= row) ever read.
# NOTE(review): the "rm" in the name presumably means row-major - confirm.
@proc
def symm_rm_rl(M: size, N: size, alpha: R, A: [R][N, N], B: [R][M, N], C: [R][M, N]):
    # Every operand must have stride 1 along dimension 1 (its second index).
    assert stride(A, 1) == 1
    assert stride(B, 1) == 1
    assert stride(C, 1) == 1

    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, N):
                a_val: R
                # Logical A[k, j]. For j <= k the entry is read directly;
                # otherwise it is mirrored to A[j, k] so the upper triangle
                # of A is never accessed.
                if j < k + 1:
                    a_val = A[k, j]
                else:
                    a_val = A[j, k]
                C[i, j] += alpha * (B[i, k] * a_val)


# SYMM reference kernel: C += alpha * B * A, with the N x N matrix A applied
# from the right and only its upper triangle (row <= column) ever read.
# NOTE(review): the "rm" in the name presumably means row-major - confirm.
@proc
def symm_rm_ru(M: size, N: size, alpha: R, A: [R][N, N], B: [R][M, N], C: [R][M, N]):
    # Every operand must have stride 1 along dimension 1 (its second index).
    assert stride(A, 1) == 1
    assert stride(B, 1) == 1
    assert stride(C, 1) == 1

    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, N):
                a_val: R
                # Logical A[k, j]. For j <= k the entry is mirrored to
                # A[j, k]; otherwise A[k, j] is read directly. Either way the
                # row index never exceeds the column index.
                if j < k + 1:
                    a_val = A[j, k]
                else:
                    a_val = A[k, j]
                C[i, j] += alpha * (B[i, k] * a_val)


# Hand each of the four SYMM procs to `variants_generator` with
# `identity_schedule`, the loop name "i" and this module's globals.
# NOTE(review): presumably this emits the per-precision entry points into the
# module namespace without applying any optimisation - confirm in `codegen_helpers`.
variants_generator(identity_schedule)(symm_rm_ll, "i", globals=globals())
variants_generator(identity_schedule)(symm_rm_lu, "i", globals=globals())
variants_generator(identity_schedule)(symm_rm_rl, "i", globals=globals())
variants_generator(identity_schedule)(symm_rm_ru, "i", globals=globals())
2 changes: 1 addition & 1 deletion test/codegen/reference/sha256/avx2.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"exo_sbmv": "5fbcbe46cd54adf6c653fa6305281f9b70b7fb61425e6b2d3d981cf08182c207",
"exo_scal": "1b907db2aa7ae945c8206f956d2d228c06ec9e19b8cced98d402d063e5feafd3",
"exo_swap": "908f3bb8eda4065dbaef9ff0c085a9d8b55eb12d893cc53212dfc905f24d290b",
"exo_symm": "f8192433e2e64b600a573ecd2a77a404f79bb0886316c1e483530fff5bfe0cf4",
"exo_symm": "d76302bbf8f0aa7a7740540cd138bb28da2cfe5c5f33e92f8331811e687f9077",
"exo_symv": "f96e0661b0221c69b43d9e476c3f5ba9a3ee41168149232b6a414145dc0d77f5",
"exo_syr": "3e6894a8a9003ede58e06c57b261a1638fd6249b8a80ed356e147862f06f39aa",
"exo_syr2": "9285dc796c9c573cbd974f419563a0c2e3d7507aba3521ee1c89e9e707913332",
Expand Down
Loading
Loading