Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

syrk schedule (missing calls to gemm after tiling) #94

Merged
merged 9 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions src/common/codegen_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ def generate_stride_1_proc(proc):
proc = rename(proc, proc.name() + "_stride_1")
for arg in proc.args():
if arg.is_tensor():
proc = proc.add_assertion(
f"stride({arg.name()}, {len(arg.shape()) - 1}) == 1"
)
proc = proc.add_assertion(f"stride({arg.name()}, {len(arg.shape()) - 1}) == 1")
return proc


Expand Down Expand Up @@ -110,15 +108,14 @@ def export_perf_features(kernel_name, perf_features):
json.dump(perf_features, f, sort_keys=True, indent=4, separators=(",", ": "))


def variants_generator(blas_op, opt_precisions=("f32", "f64"), targets=(AVX2, Neon)):
def variants_generator(blas_op, opt_precisions=("f32", "f64"), targets=(AVX2, Neon), stage_scalars=True):
def generate(proc, loop_name, *args, globals=None, **kwargs):
perf_features = {}
for precision in ("f32", "f64"):
proc_variant = specialize_precision(proc, precision)

proc_variant = stage_scalar_args(proc_variant)

stride_any = generate_stride_any_proc(proc_variant)
stride_any = stage_scalar_args(stride_any)
stride_any = bind_builtins_args(stride_any, stride_any.body(), precision)
export_exo_proc(globals, stride_any)

Expand All @@ -127,9 +124,8 @@ def generate(proc, loop_name, *args, globals=None, **kwargs):
algorithm = get_perf_features(stride_1)

if precision in opt_precisions and C.Machine.mem_type in targets:
stride_1 = blas_op(
stride_1, loop, precision, C.Machine, *args, **kwargs
)
stride_1 = blas_op(stride_1, loop, precision, C.Machine, *args, **kwargs)
stride_1 = stage_scalar_args(stride_1)
stride_1 = bind_builtins_args(stride_1, stride_1.body(), precision)
scheduled = get_perf_features(stride_1)

Expand Down
28 changes: 11 additions & 17 deletions src/common/perf_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def get_stmt_ops(cursor, syms):
iter_sym = sp.Symbol(cursor.name(), integer=integers, positive=integers)
new_syms[cursor.name()] = iter_sym
stmts_ops = [get_stmt_ops(s, new_syms) for s in cursor.body()]
sums = [
sp.summation(s_ops, (iter_sym, lo, hi - one)) for s_ops in stmts_ops
]
sums = [sp.summation(s_ops, (iter_sym, lo, hi - one)) for s_ops in stmts_ops]
return sum(sums)
else:
return zero
Expand All @@ -74,14 +72,14 @@ def get_stmt_ops(cursor, syms):
for arg in proc.args():
if arg.type().is_indexable():
is_size = arg.type() == ExoType.Size
syms[arg.name()] = sp.Symbol(
arg.name(), integer=integers, positive=is_size and integers
)

ops = get_stmt_ops(proc.body(), syms)
ops = sp.simplify(ops)
ops = sp.factor(ops)
return ops
syms[arg.name()] = sp.Symbol(arg.name(), integer=integers, positive=is_size and integers)
try:
ops = get_stmt_ops(proc.body(), syms)
ops = sp.simplify(ops)
ops = sp.factor(ops)
return ops
except:
return zero


def count_flops(proc, upper=False):
Expand All @@ -91,9 +89,7 @@ def count_flops(proc, upper=False):
def get_expr_ops(proc, expr):
expr = proc.forward(expr)
if isinstance(expr, (BuiltInFunctionCursor, BinaryOpCursor, UnaryMinusCursor)):
children_ops = sum(
get_expr_ops(proc, c) for c in get_numeric_children(proc, expr)
)
children_ops = sum(get_expr_ops(proc, c) for c in get_numeric_children(proc, expr))
return one + children_ops
return zero

Expand Down Expand Up @@ -140,9 +136,7 @@ def get_bytes_traffic(p, c):
def get_expr_loads(proc, expr):
expr = proc.forward(expr)
if isinstance(expr, (BuiltInFunctionCursor, BinaryOpCursor, UnaryMinusCursor)):
return sum(
get_expr_loads(proc, c) for c in get_numeric_children(proc, expr)
)
return sum(get_expr_loads(proc, c) for c in get_numeric_children(proc, expr))
elif isinstance(expr, ReadCursor):
return get_bytes_traffic(proc, expr)
return zero
Expand Down
29 changes: 20 additions & 9 deletions src/common/stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,7 @@ def get_depth(proc, scope):
def push_scope_in(proc, scope, depth, size=None):
scope = proc.forward(scope)
allocs = list(filter_cursors(is_alloc)(proc, scope.body()))
proc = apply(attempt(parallelize_and_lift_alloc))(proc, allocs)

proc = apply(parallelize_and_lift_alloc)(proc, allocs)
if const_allocs and size:
proc = apply(lambda p, a: resize_dim(p, a, 0, size, 0))(proc, allocs)

Expand Down Expand Up @@ -670,8 +669,9 @@ def get_window(cursor):
for loop in my_loops:
lo_rng = eval_rng(loop.lo(), env)
hi_rng = eval_rng(loop.hi(), env)
env[loop.name()] = (lo_rng[0], hi_rng[1])
window = tuple(eval_rng(idx, env) for idx in cursor.idx())
env[loop.name()] = (lo_rng[0], f"{hi_rng[1]} - 1")
extend_hi = lambda rng: (rng[0], f"{rng[1]} + 1") if rng[0] != rng[1] else rng
window = tuple(extend_hi(eval_rng(idx, env)) for idx in cursor.idx())
return window

def window_to_str(window):
Expand Down Expand Up @@ -1014,11 +1014,18 @@ def bound_loop_by_if(proc, loop):
if not isinstance(if_c.orelse(), InvalidCursor):
raise BLAS_SchedulingError(err)

if not isinstance(if_c.cond().lhs(), ReadCursor) or if_c.cond().lhs().name() != loop.name() or if_c.cond().op() != "<":
cond = if_c.cond()
cond_lhs = cond.lhs()

if is_read(proc, cond_lhs, loop.name()) and cond.op() == "<":
cut_point = FormattedExprStr("_ + _", loop.lo(), cond.rhs())
elif is_add(proc, cond_lhs) and is_read(proc, cond_lhs.lhs(), loop.name()) and cond.op() == "<":
cut_point = FormattedExprStr("_ + (_ - _)", loop.lo(), cond.rhs(), cond_lhs.rhs())
else:
raise BLAS_SchedulingError(err)

if_c = loop.body()[0]
proc = cut_loop(proc, loop, FormattedExprStr("_ + _", loop.lo(), if_c.cond().rhs()))
proc = cut_loop(proc, loop, cut_point)
loop1 = proc.forward(loop)
loop2 = loop1.next()
proc = eliminate_dead_code(proc, loop1.body()[0])
Expand Down Expand Up @@ -1333,9 +1340,7 @@ def add_loop_nest(loop):
for src_dim, _ in reversed(list(enumerate(alloc.shape()))):
for dst_dim, size in reversed(divisions[src_dim][1:]):
proc = divide_dim(proc, alloc, src_dim, size)
proc = apply(divide_loop_)(
proc, loop_nest[src_dim], size, tail="cut"
) # TODO: This should be enabled but it slows down compilation
proc = apply(divide_loop_)(proc, loop_nest[src_dim], size, tail="cut")
perm.append(dst_dim)
perm.append(divisions[src_dim][0][0])

Expand All @@ -1358,3 +1363,9 @@ def add_loop_nest(loop):
load = load.as_block().expand(0, diff)

return proc, pack_mem_cursors(alloc, load, block, store)


def inline_calls(proc, block=InvalidCursor(), subproc=None):
    """Inline the call statements selected from `block` and simplify the result.

    Call cursors are collected by running `filter_cursors(is_call)` over
    `nlr_stmts(proc, block)`; each selected call is then rewritten with
    `inline_proc_and_wins`, and the resulting procedure is passed through
    `simplify`.

    Args:
        proc: The procedure to rewrite.
        block: Cursor handed to `nlr_stmts` to choose the statements that are
            scanned. Defaults to an `InvalidCursor()` instance created once,
            when this function is defined.
            NOTE(review): presumably an invalid cursor means "scan the whole
            procedure" -- confirm against `nlr_stmts`.
        subproc: Extra argument forwarded unchanged to the call filter.
            NOTE(review): presumably restricts inlining to calls of this
            sub-procedure -- confirm against `is_call`.

    Returns:
        The simplified procedure with the selected calls inlined.
    """
    select_calls = filter_cursors(is_call)
    call_cursors = select_calls(proc, nlr_stmts(proc, block), subproc)
    inline_all = apply(inline_proc_and_wins)
    return simplify(inline_all(proc, call_cursors))
3 changes: 1 addition & 2 deletions src/level3/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ def schedule(main_gemm, i_loop, precision, machine, m_r, n_r_fac, M_tile, N_tile
gemm_tiled = apply(repeate_n(reorder_loops))(gemm_tiled, gemm_tiled.find_loop("ki", many=True), n=2)
gemm_tiled = replace_all_stmts(gemm_tiled, [gemm_macro])

macro_calls = filter_cursors(is_call)(gemm_tiled, nlr_stmts(gemm_tiled))
gemm_tiled = simplify(apply(inline_proc_and_wins)(gemm_tiled, macro_calls))
gemm_tiled = inline_calls(gemm_tiled, subproc=gemm_macro[1])

gemm_tiled = apply(hoist_from_loop)(gemm_tiled, gemm_tiled.find_loop("jo", many=True))
gemm_tiled = squash_buffers(gemm_tiled, gemm_tiled.find("packed_A : _", many=True))
Expand Down
Loading
Loading