Skip to content

Commit

Permalink
syrk schedule (missing calls to gemm after tiling) (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi authored Mar 26, 2024
1 parent 491ca3f commit 0cd3f23
Show file tree
Hide file tree
Showing 19 changed files with 452 additions and 1,853 deletions.
14 changes: 5 additions & 9 deletions src/common/codegen_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ def generate_stride_1_proc(proc):
proc = rename(proc, proc.name() + "_stride_1")
for arg in proc.args():
if arg.is_tensor():
proc = proc.add_assertion(
f"stride({arg.name()}, {len(arg.shape()) - 1}) == 1"
)
proc = proc.add_assertion(f"stride({arg.name()}, {len(arg.shape()) - 1}) == 1")
return proc


Expand Down Expand Up @@ -110,15 +108,14 @@ def export_perf_features(kernel_name, perf_features):
json.dump(perf_features, f, sort_keys=True, indent=4, separators=(",", ": "))


def variants_generator(blas_op, opt_precisions=("f32", "f64"), targets=(AVX2, Neon)):
def variants_generator(blas_op, opt_precisions=("f32", "f64"), targets=(AVX2, Neon), stage_scalars=True):
def generate(proc, loop_name, *args, globals=None, **kwargs):
perf_features = {}
for precision in ("f32", "f64"):
proc_variant = specialize_precision(proc, precision)

proc_variant = stage_scalar_args(proc_variant)

stride_any = generate_stride_any_proc(proc_variant)
stride_any = stage_scalar_args(stride_any)
stride_any = bind_builtins_args(stride_any, stride_any.body(), precision)
export_exo_proc(globals, stride_any)

Expand All @@ -127,9 +124,8 @@ def generate(proc, loop_name, *args, globals=None, **kwargs):
algorithm = get_perf_features(stride_1)

if precision in opt_precisions and C.Machine.mem_type in targets:
stride_1 = blas_op(
stride_1, loop, precision, C.Machine, *args, **kwargs
)
stride_1 = blas_op(stride_1, loop, precision, C.Machine, *args, **kwargs)
stride_1 = stage_scalar_args(stride_1)
stride_1 = bind_builtins_args(stride_1, stride_1.body(), precision)
scheduled = get_perf_features(stride_1)

Expand Down
28 changes: 11 additions & 17 deletions src/common/perf_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def get_stmt_ops(cursor, syms):
iter_sym = sp.Symbol(cursor.name(), integer=integers, positive=integers)
new_syms[cursor.name()] = iter_sym
stmts_ops = [get_stmt_ops(s, new_syms) for s in cursor.body()]
sums = [
sp.summation(s_ops, (iter_sym, lo, hi - one)) for s_ops in stmts_ops
]
sums = [sp.summation(s_ops, (iter_sym, lo, hi - one)) for s_ops in stmts_ops]
return sum(sums)
else:
return zero
Expand All @@ -74,14 +72,14 @@ def get_stmt_ops(cursor, syms):
for arg in proc.args():
if arg.type().is_indexable():
is_size = arg.type() == ExoType.Size
syms[arg.name()] = sp.Symbol(
arg.name(), integer=integers, positive=is_size and integers
)

ops = get_stmt_ops(proc.body(), syms)
ops = sp.simplify(ops)
ops = sp.factor(ops)
return ops
syms[arg.name()] = sp.Symbol(arg.name(), integer=integers, positive=is_size and integers)
try:
ops = get_stmt_ops(proc.body(), syms)
ops = sp.simplify(ops)
ops = sp.factor(ops)
return ops
except:
return zero


def count_flops(proc, upper=False):
Expand All @@ -91,9 +89,7 @@ def count_flops(proc, upper=False):
def get_expr_ops(proc, expr):
expr = proc.forward(expr)
if isinstance(expr, (BuiltInFunctionCursor, BinaryOpCursor, UnaryMinusCursor)):
children_ops = sum(
get_expr_ops(proc, c) for c in get_numeric_children(proc, expr)
)
children_ops = sum(get_expr_ops(proc, c) for c in get_numeric_children(proc, expr))
return one + children_ops
return zero

Expand Down Expand Up @@ -140,9 +136,7 @@ def get_bytes_traffic(p, c):
def get_expr_loads(proc, expr):
expr = proc.forward(expr)
if isinstance(expr, (BuiltInFunctionCursor, BinaryOpCursor, UnaryMinusCursor)):
return sum(
get_expr_loads(proc, c) for c in get_numeric_children(proc, expr)
)
return sum(get_expr_loads(proc, c) for c in get_numeric_children(proc, expr))
elif isinstance(expr, ReadCursor):
return get_bytes_traffic(proc, expr)
return zero
Expand Down
29 changes: 20 additions & 9 deletions src/common/stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,7 @@ def get_depth(proc, scope):
def push_scope_in(proc, scope, depth, size=None):
scope = proc.forward(scope)
allocs = list(filter_cursors(is_alloc)(proc, scope.body()))
proc = apply(attempt(parallelize_and_lift_alloc))(proc, allocs)

proc = apply(parallelize_and_lift_alloc)(proc, allocs)
if const_allocs and size:
proc = apply(lambda p, a: resize_dim(p, a, 0, size, 0))(proc, allocs)

Expand Down Expand Up @@ -670,8 +669,9 @@ def get_window(cursor):
for loop in my_loops:
lo_rng = eval_rng(loop.lo(), env)
hi_rng = eval_rng(loop.hi(), env)
env[loop.name()] = (lo_rng[0], hi_rng[1])
window = tuple(eval_rng(idx, env) for idx in cursor.idx())
env[loop.name()] = (lo_rng[0], f"{hi_rng[1]} - 1")
extend_hi = lambda rng: (rng[0], f"{rng[1]} + 1") if rng[0] != rng[1] else rng
window = tuple(extend_hi(eval_rng(idx, env)) for idx in cursor.idx())
return window

def window_to_str(window):
Expand Down Expand Up @@ -1014,11 +1014,18 @@ def bound_loop_by_if(proc, loop):
if not isinstance(if_c.orelse(), InvalidCursor):
raise BLAS_SchedulingError(err)

if not isinstance(if_c.cond().lhs(), ReadCursor) or if_c.cond().lhs().name() != loop.name() or if_c.cond().op() != "<":
cond = if_c.cond()
cond_lhs = cond.lhs()

if is_read(proc, cond_lhs, loop.name()) and cond.op() == "<":
cut_point = FormattedExprStr("_ + _", loop.lo(), cond.rhs())
elif is_add(proc, cond_lhs) and is_read(proc, cond_lhs.lhs(), loop.name()) and cond.op() == "<":
cut_point = FormattedExprStr("_ + (_ - _)", loop.lo(), cond.rhs(), cond_lhs.rhs())
else:
raise BLAS_SchedulingError(err)

if_c = loop.body()[0]
proc = cut_loop(proc, loop, FormattedExprStr("_ + _", loop.lo(), if_c.cond().rhs()))
proc = cut_loop(proc, loop, cut_point)
loop1 = proc.forward(loop)
loop2 = loop1.next()
proc = eliminate_dead_code(proc, loop1.body()[0])
Expand Down Expand Up @@ -1333,9 +1340,7 @@ def add_loop_nest(loop):
for src_dim, _ in reversed(list(enumerate(alloc.shape()))):
for dst_dim, size in reversed(divisions[src_dim][1:]):
proc = divide_dim(proc, alloc, src_dim, size)
proc = apply(divide_loop_)(
proc, loop_nest[src_dim], size, tail="cut"
) # TODO: This should be enabled but it slows down compilation
proc = apply(divide_loop_)(proc, loop_nest[src_dim], size, tail="cut")
perm.append(dst_dim)
perm.append(divisions[src_dim][0][0])

Expand All @@ -1358,3 +1363,9 @@ def add_loop_nest(loop):
load = load.as_block().expand(0, diff)

return proc, pack_mem_cursors(alloc, load, block, store)


def inline_calls(proc, block=InvalidCursor(), subproc=None):
calls = filter_cursors(is_call)(proc, nlr_stmts(proc, block), subproc)
proc = simplify(apply(inline_proc_and_wins)(proc, calls))
return proc
3 changes: 1 addition & 2 deletions src/level3/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ def schedule(main_gemm, i_loop, precision, machine, m_r, n_r_fac, M_tile, N_tile
gemm_tiled = apply(repeate_n(reorder_loops))(gemm_tiled, gemm_tiled.find_loop("ki", many=True), n=2)
gemm_tiled = replace_all_stmts(gemm_tiled, [gemm_macro])

macro_calls = filter_cursors(is_call)(gemm_tiled, nlr_stmts(gemm_tiled))
gemm_tiled = simplify(apply(inline_proc_and_wins)(gemm_tiled, macro_calls))
gemm_tiled = inline_calls(gemm_tiled, subproc=gemm_macro[1])

gemm_tiled = apply(hoist_from_loop)(gemm_tiled, gemm_tiled.find_loop("jo", many=True))
gemm_tiled = squash_buffers(gemm_tiled, gemm_tiled.find("packed_A : _", many=True))
Expand Down
Loading

0 comments on commit 0cd3f23

Please sign in to comment.