Skip to content

Commit

Permalink
Rewrite optimize_level_2 (#59)
Browse files: browse the repository at this point in the history
  • Loading branch information
SamirDroubi authored Jan 29, 2024
1 parent aa1110d commit 149e7af
Show file tree
Hide file tree
Showing 6 changed files with 557 additions and 376 deletions.
52 changes: 9 additions & 43 deletions src/common/blaslib.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -48,50 +48,16 @@ def optimize_level_1(proc, loop, params):
return proc


def optimize_level_2(proc, params, reuse):
proc = generate_stride_1_proc(proc, params.precision)

# Taking a subspace of the 2D iteration dimension
proc, _ = auto_divide_loop(
proc, proc.find_loop("i"), params.rows_interleave_factor, tail="cut"
)

# Determine the tail strategy
vectorize_tail = params.mem_type in {AVX2}
tail = "guard" if vectorize_tail else "cut"

proc, _ = auto_divide_loop(proc, proc.find_loop("j"), params.vec_width, tail=tail)
proc = parallelize_all_reductions(
proc, proc.find_loop("jo"), memory=params.mem_type
)
proc = unroll_and_jam_parent(
proc, proc.find_loop("jo"), params.rows_interleave_factor, (True, False, True)
)

# Data reuse across rows
proc = simplify(auto_stage_mem(proc, proc.find(reuse), "shared", n_lifts=2))
proc = set_memory(proc, "shared", params.mem_type) # Simply to avoid a vector copy

# Generate SIMD
proc = scalar_to_simd(
proc,
proc.find_loop("ii").body()[0],
params.vec_width,
params.mem_type,
params.precision,
def optimize_level_2(proc, outer_loop, params, reuse):
rows_factor = params.interleave_factor
inner_loop = get_inner_loop(proc, outer_loop)
proc, (outer_loop_o, outer_loop_i, _) = auto_divide_loop(
proc, outer_loop, rows_factor, tail="cut"
)

# Interleave multiple rows dots
proc = interleave_loop(proc, proc.find_loop("ii"))

# Separate the tail case
if vectorize_tail:
loop = proc.find_loop("jo")
proc = cut_loop(proc, loop, FormattedExprStr("_ - 1", loop.hi()))
proc = dce(proc, loop)

# Instruction Selection
proc = replace_all_stmts(proc, params.instructions)
proc = unroll_and_jam(proc, outer_loop_i, rows_factor)
proc = unroll_buffers(proc, outer_loop_o)
proc = stage_mem(proc, proc.forward(inner_loop).body(), reuse, "tmp")
proc = optimize_level_1(proc, inner_loop, params)
return simplify(proc)


Expand Down
7 changes: 3 additions & 4 deletions src/common/composed_schedules.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -712,15 +712,16 @@ def parallelize_all_reductions(proc, loop, factor=None, memory=DRAM, unroll=Fals

def rewrite(proc, s):
s = proc.forward(s)
reduc_loop = proc.forward(loop)
nth_loop = 0
for parent in get_parents(proc, s):
if is_loop(proc, parent):
nth_loop += 1
if parent == loop:
if parent == reduc_loop:
break
return parallelize_reduction(proc, s, factor, memory, nth_loop, unroll)

return make_pass(attempt(rewrite))(proc, loop)
return make_pass(attempt(rewrite))(proc, loop.body())


def unroll_and_jam(proc, loop, factor, unroll=(True, True, True)):
Expand Down Expand Up @@ -890,9 +891,7 @@ def vectorize_predicate_tail(
tail="cut_and_predicate",
rc=False,
):

proc = parallelize_all_reductions(proc, loop, factor=vec_width, memory=mem_type)

allocs = filter(lambda s: isinstance(s, AllocCursor), nlr_stmts(proc, loop))
proc = apply(set_memory)(proc, allocs, mem_type)

Expand Down
10 changes: 5 additions & 5 deletions src/level2/gemv.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -54,8 +54,8 @@ def gemv_row_major_Trans(
### EXO_LOC SCHEDULE START ###

template_sched_list = [
(optimize_level_2, gemv_row_major_NonTrans, "x[_]"),
(optimize_level_2, gemv_row_major_Trans, "y[_] += _"),
(optimize_level_2, gemv_row_major_NonTrans, "x[j]"),
(optimize_level_2, gemv_row_major_Trans, "y[j]"),
]

for precision in ("f32", "f64"):
Expand All @@ -64,12 +64,12 @@ def gemv_row_major_Trans(
params = Level_2_Params(
precision=precision,
rows_interleave_factor=4,
cols_interleave_factor=1,
accumulators_count=1,
)
export_exo_proc(globals(), proc_stride_any)
proc_stride_1 = generate_stride_1_proc(template, precision)
proc_stride_1 = sched(
template,
proc_stride_1,
proc_stride_1.find_loop("i"),
params,
reuse,
)
Expand Down
Loading

0 comments on commit 149e7af

Please sign in to comment.