Skip to content

Commit

Permalink
Interleave restructure (#58)
Browse files Browse the repository at this point in the history
* Package all the logic into one operation
* Add various tail strategies: recursive division and binary
specialization.
  • Loading branch information
SamirDroubi authored Jan 28, 2024
1 parent 83155ca commit aa1110d
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 19 deletions.
14 changes: 3 additions & 11 deletions src/common/blaslib.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,12 @@ def optimize_level_1(proc, loop, params):
if interleave_factor == 1:
return simplify(proc)

# Tile to exploit ILP
proc, (loop, inner_loop, _) = auto_divide_loop(
proc, loop, interleave_factor, tail="cut"
proc = interleave_loop(
proc, loop, interleave_factor, par_reduce=True, memory=mem_type
)

proc = parallelize_all_reductions(proc, loop, memory=mem_type, unroll=True)

# Intereleave to increase ILP
inner_loop = proc.forward(loop).body()[0]
proc = interleave_loop(proc, inner_loop)

# Instructions Selection
proc = replace_all_stmts(proc, instructions)
proc = cleanup(proc)
proc = replace_all_stmts(proc, instructions)
return proc


Expand Down
172 changes: 164 additions & 8 deletions src/common/composed_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def vectorize_stmt(proc, stmt, depth=1):
return proc


def interleave_loop(proc, loop, divide=None, tail="cut"):
def interleave_loop(proc, loop, factor=None, par_reduce=False, memory=DRAM, tail="cut"):
"""
for i in seq(0, c):
S1
Expand All @@ -439,15 +439,53 @@ def interleave_loop(proc, loop, divide=None, tail="cut"):

loop = proc.forward(loop)

if divide is not None:
proc, (_, loop, _) = auto_divide_loop(proc, loop, divide, tail=tail)
def rewrite(proc, loop, factor=None, par_reduce=False, memory=DRAM, tail="cut"):
loop = proc.forward(loop)
if factor is not None:
proc, (outer, loop, _) = auto_divide_loop(proc, loop, factor, tail=tail)
if par_reduce:
proc = parallelize_all_reductions(
proc, outer, memory=memory, unroll=True
)
loop = proc.forward(outer).body()[0]
else:
if par_reduce:
proc = parallelize_all_reductions(
proc, loop, memory=memory, unroll=True
)

allocs = filter(lambda s: isinstance(s, AllocCursor), loop.body())
proc = apply(parallelize_and_lift_alloc)(proc, allocs)
allocs = filter(lambda s: isinstance(s, AllocCursor), loop.body())
proc = apply(parallelize_and_lift_alloc)(proc, allocs)

stmts = list(proc.forward(loop).body())
proc = apply(fission)(proc, [s.after() for s in stmts[:-1]])
proc = apply(unroll_loop)(proc, [proc.forward(s).parent() for s in stmts])
stmts = list(proc.forward(loop).body())
proc = apply(fission)(proc, [s.after() for s in stmts[:-1]])
proc = apply(unroll_loop)(proc, [proc.forward(s).parent() for s in stmts])
return proc

if tail in {"cut", "cut_and_guard"}:
proc = rewrite(proc, loop, factor, par_reduce, memory, tail)
elif tail == "recursive":
if factor is None:
raise BLAS_SchedulingError(
"Cannot specify recursive tail strategy and factor=None"
)
proc, (_, inners, _) = divide_loop_recursive(
proc, loop, factor, tail="cut", rc=True
)
proc = apply(rewrite)(proc, inners, par_reduce=par_reduce, memory=memory)
elif tail == "specialize":
if factor is None:
raise BLAS_SchedulingError(
"Cannot specify recursive tail strategy and factor=None"
)
proc = rewrite(proc, loop, factor, par_reduce, memory, tail="cut")
tail_loop = proc.forward(loop).next()
proc, (stmts,) = binary_specialize(
proc, tail_loop, tail_loop.hi(), [i for i in range(factor)], rc=True
)
proc = apply(rewrite)(proc, stmts, par_reduce=par_reduce, memory=memory)
else:
raise BLAS_SchedulingError(f"Unknown tail strategy: {tail}")
return proc


Expand Down Expand Up @@ -1313,3 +1351,121 @@ def reorder_stmt_backwards(proc, stmt):
stmt = proc.forward(stmt)
block = stmt.as_block().expand(-1, 0)
return reorder_stmts(proc, block)


@dataclass
class divide_loop_recursive_cursors:
outer_loops: list
inner_loops: list
tail_loop: ForCursor

def __iter__(self):
yield self.outer_loops
yield self.inner_loops
yield self.tail_loop


def divide_loop_recursive(proc, loop, factor, tail="cut", rc=False):
if tail not in {"cut", "cut_and_guard"}:
raise BLAS_SchedulingError("tail strategy must be cut or cut_and_guard")
outer_loops = []
inner_loops = []
tail_loop = loop
while factor > 1:
proc, (outer, inner, tail_loop) = auto_divide_loop(
proc, tail_loop, factor, tail=tail
)
outer_loops.append(outer)
inner_loops.append(inner)
factor = factor // 2
if not rc:
return proc
outer_loops = [proc.forward(c) for c in outer_loops]
inner_loops = [proc.forward(c) for c in inner_loops]
return proc, divide_loop_recursive_cursors(outer_loops, inner_loops, tail_loop)


@dataclass
class specialize_cursors:
if_stmt: Cursor

def __iter__(self):
yield self.if_stmt


def specialize_(proc, stmt, cond, rc=False):
stmt = proc.forward(stmt)
parent = stmt.parent()
index = get_index_in_body(proc, stmt)
proc = specialize(proc, stmt, cond)
if not rc:
return proc
is_else = False
if (
isinstance(parent, IfCursor)
and index < len(parent.orelse())
and parent.orelse()[index] == stmt
):
is_else = True
if not isinstance(parent, InvalidCursor):
parent = proc.forward(parent)
else:
parent = proc

if_stmt = parent.body()[index] if not is_else else parent.orelse()[index]
return proc, specialize_cursors(if_stmt)


@dataclass
class binary_specialize_cursors:
stmts: Cursor

def __iter__(self):
yield self.stmts


def binary_specialize(proc, stmt, expr, values, rc=False):
stmt = proc.forward(stmt)
if isinstance(expr, ExprCursor):
expr = proc.forward(expr)
expr = expr_to_string(expr)
get_cond = lambda op, v: f"{expr} {op} {v}"

if len(values) == 1:
raise BLAS_SchedulingError("Cannot specialize given one value!")
values = sorted(values)
stmt = proc.forward(stmt)

stmts = []

def rewrite(proc, stmt, values):
if len(values) == 1:
# This should be redundant if the user provided correct inputs!
# So, it is really a check that the inputs the user provided cover the full range.
proc, (if_stmt,) = specialize_(
proc, stmt, get_cond("==", values[0]), rc=True
)
proc = simplify(proc)
proc = eliminate_dead_code(proc, if_stmt)
stmts.append(if_stmt.body()[0])
stmts.append(if_stmt.orelse()[0])
return proc
md = len(values) // 2
proc, (if_stmt,) = specialize_(proc, stmt, get_cond("<", values[md]), rc=True)
proc = rewrite(proc, if_stmt.body()[0], values[:md])
proc = rewrite(proc, if_stmt.orelse()[0], values[md:])
return proc

proc = rewrite(proc, stmt, values)
if not rc:
return proc

filtered_stmts = []
for s in stmts:
try:
stmt = proc.forward(s)
if not isinstance(stmt, PassCursor):
filtered_stmts.append(stmt)
except InvalidCursorError:
pass
return proc, binary_specialize_cursors(filtered_stmts)

0 comments on commit aa1110d

Please sign in to comment.