Skip to content

Commit

Permalink
Deprecate CSE argument from bind_expr (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi authored Dec 19, 2023
1 parent e146459 commit 060c5e5
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
8 changes: 3 additions & 5 deletions src/common/composed_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def stage_expr(proc, expr_cursors, new_name, precision="R", memory=DRAM, n_lifts
expr_cursors = [proc.forward(c) for c in expr_cursors]
enclosing_loop = get_enclosing_loop(expr_cursors[0])
stmt = get_statement(expr_cursors[0])
proc = bind_expr(proc, expr_cursors, new_name, cse=True)
proc = bind_expr(proc, expr_cursors, new_name)
stmt = proc.forward(stmt)
bind_stmt = stmt.prev()
alloc_stmt = bind_stmt.prev()
Expand Down Expand Up @@ -224,9 +224,7 @@ def vectorize_to_loops(proc, loop_cursor, vec_width, memory_type, precision):
"""

if not isinstance(loop_cursor, ForCursor):
raise BLAS_SchedulingError(
"vectorize_to_loops loop_cursor must be a ForCursor"
)
raise BLAS_SchedulingError("vectorize_to_loops loop_cursor must be a ForCursor")

loop_cursor = proc.forward(loop_cursor)

Expand Down Expand Up @@ -899,7 +897,7 @@ def ordered_stage_expr(proc, expr_cursors, new_buff_name, precision, n_lifts=1):
expr_cursors = [proc.forward(c) for c in expr_cursors]
original_stmt = get_statement(expr_cursors[0])

proc = bind_expr(proc, expr_cursors, new_buff_name, cse=True)
proc = bind_expr(proc, expr_cursors, new_buff_name)
original_stmt = proc.forward(original_stmt)
assign_cursor = original_stmt.prev()
alloc_cursor = assign_cursor.prev()
Expand Down
6 changes: 3 additions & 3 deletions src/level1/rot.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def rot_template(n: size, x: [R][n], y: [R][n], c: R, s: R):
def schedule_rot_stride_1(rot, params):
rot = generate_stride_1_proc(rot, params.precision)
loop_cursor = rot.find_loop("i")
rot = bind_expr(rot, rot.find("y[_]", many=True), "yReg", cse=True)
rot = bind_expr(rot, rot.find("y[_]", many=True), "yReg")
rot = set_precision(rot, "yReg", params.precision)
rot = bind_expr(rot, rot.find("s_", many=True), "sReg", cse=True)
rot = bind_expr(rot, rot.find("s_", many=True), "sReg")
rot = set_precision(rot, "sReg", params.precision)
rot = bind_expr(rot, rot.find("c_", many=True), "cReg", cse=True)
rot = bind_expr(rot, rot.find("c_", many=True), "cReg")
rot = set_precision(rot, "cReg", params.precision)
rot = blas_vectorize(rot, loop_cursor, params)
loop_cursor = rot.forward(loop_cursor)
Expand Down
2 changes: 1 addition & 1 deletion src/level1/rotm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def schedule_rotm_stride_1(rotm, params):
rotm = generate_stride_1_proc(rotm, params.precision)

loop_cursor = rotm.find_loop("i")
rotm = bind_expr(rotm, rotm.find("y[_]", many=True), "yReg", cse=True)
rotm = bind_expr(rotm, rotm.find("y[_]", many=True), "yReg")
rotm = set_precision(rotm, "yReg", params.precision)
rotm = blas_vectorize(rotm, loop_cursor, params)
loop_cursor = rotm.forward(loop_cursor)
Expand Down

0 comments on commit 060c5e5

Please sign in to comment.