diff --git a/src/common/composed_schedules.py b/src/common/composed_schedules.py index c9ba2ef4..d8854b99 100644 --- a/src/common/composed_schedules.py +++ b/src/common/composed_schedules.py @@ -116,7 +116,7 @@ def stage_expr(proc, expr_cursors, new_name, precision="R", memory=DRAM, n_lifts expr_cursors = [proc.forward(c) for c in expr_cursors] enclosing_loop = get_enclosing_loop(expr_cursors[0]) stmt = get_statement(expr_cursors[0]) - proc = bind_expr(proc, expr_cursors, new_name, cse=True) + proc = bind_expr(proc, expr_cursors, new_name) stmt = proc.forward(stmt) bind_stmt = stmt.prev() alloc_stmt = bind_stmt.prev() @@ -224,9 +224,7 @@ def vectorize_to_loops(proc, loop_cursor, vec_width, memory_type, precision): """ if not isinstance(loop_cursor, ForCursor): - raise BLAS_SchedulingError( - "vectorize_to_loops loop_cursor must be a ForCursor" - ) + raise BLAS_SchedulingError("vectorize_to_loops loop_cursor must be a ForCursor") loop_cursor = proc.forward(loop_cursor) @@ -899,7 +897,7 @@ def ordered_stage_expr(proc, expr_cursors, new_buff_name, precision, n_lifts=1): expr_cursors = [proc.forward(c) for c in expr_cursors] original_stmt = get_statement(expr_cursors[0]) - proc = bind_expr(proc, expr_cursors, new_buff_name, cse=True) + proc = bind_expr(proc, expr_cursors, new_buff_name) original_stmt = proc.forward(original_stmt) assign_cursor = original_stmt.prev() alloc_cursor = assign_cursor.prev() diff --git a/src/level1/rot.py b/src/level1/rot.py index 9d477ae0..84fa0842 100644 --- a/src/level1/rot.py +++ b/src/level1/rot.py @@ -37,11 +37,11 @@ def rot_template(n: size, x: [R][n], y: [R][n], c: R, s: R): def schedule_rot_stride_1(rot, params): rot = generate_stride_1_proc(rot, params.precision) loop_cursor = rot.find_loop("i") - rot = bind_expr(rot, rot.find("y[_]", many=True), "yReg", cse=True) + rot = bind_expr(rot, rot.find("y[_]", many=True), "yReg") rot = set_precision(rot, "yReg", params.precision) - rot = bind_expr(rot, rot.find("s_", many=True), "sReg", cse=True) + rot = bind_expr(rot, rot.find("s_", many=True), "sReg") rot = set_precision(rot, "sReg", params.precision) - rot = bind_expr(rot, rot.find("c_", many=True), "cReg", cse=True) + rot = bind_expr(rot, rot.find("c_", many=True), "cReg") rot = set_precision(rot, "cReg", params.precision) rot = blas_vectorize(rot, loop_cursor, params) loop_cursor = rot.forward(loop_cursor) diff --git a/src/level1/rotm.py b/src/level1/rotm.py index a9591cd6..070c2dae 100644 --- a/src/level1/rotm.py +++ b/src/level1/rotm.py @@ -57,7 +57,7 @@ def schedule_rotm_stride_1(rotm, params): rotm = generate_stride_1_proc(rotm, params.precision) loop_cursor = rotm.find_loop("i") - rotm = bind_expr(rotm, rotm.find("y[_]", many=True), "yReg", cse=True) + rotm = bind_expr(rotm, rotm.find("y[_]", many=True), "yReg") rotm = set_precision(rotm, "yReg", params.precision) rotm = blas_vectorize(rotm, loop_cursor, params) loop_cursor = rotm.forward(loop_cursor)