diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 0d088526694d..c1c2d4644e4a 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -588,12 +588,47 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim } } + bool all_singlepoints_outside = true; + + // Check all partitions to see if they are single points and outside `for_interval` + for (const auto& partition : finder.partitions) { + const auto& intset = partition.second; + // Only proceed if the interval set is a single point + if (intset.IsSinglePoint()) { + auto single_point = intset.PointValue(); + // Check if the single point is outside the `for_interval` + bool is_inside = analyzer_.CanProve(single_point >= for_interval.min()) && + analyzer_.CanProve(single_point <= for_interval.max()); + if (is_inside) { + // If any single point is inside, this is an error condition + LOG(ERROR) << "unexpected case happened."; + all_singlepoints_outside = false; + break; + } + } else { + // If there is any intset that is not a single point, follow default logic + // For now, we set all_singlepoints_outside to false to indicate default logic was used + all_singlepoints_outside = false; + break; + } + } + + if (all_singlepoints_outside) { + // If all single points are outside `for_interval`, return a nothing interval and false + return {IntSet::Nothing(), ExpressionSet(), false}; + } + // we couldn't find an interval in which the conditions are // provably true or false. Therefore, we can't partition the loop // based on those conds return {{}, {}, std::nullopt}; }(); + if (middle_interval.IsNothing() && opt_cond_value == false) { + // Return loop directly as it can be simplified. + return stmt; + } + if (!opt_cond_value.has_value()) { if (has_partition_hint_ && unroll_loop_with_partition_hint_no_interval_ && analyzer_.CanProve(max - min > 0)) { diff --git a/tests/python/tir-transform/test_tir_transform_loop_partition.py b/tests/python/tir-transform/test_tir_transform_loop_partition.py index aa11ae5a5f7b..2b3f73e24f88 100644 --- a/tests/python/tir-transform/test_tir_transform_loop_partition.py +++ b/tests/python/tir-transform/test_tir_transform_loop_partition.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm import tvm.testing from tvm import te @@ -834,5 +835,232 @@ def after( assert tvm.ir.structural_equal(mod["main"], after.with_attr("global_symbol", "main")) +@T.prim_func +def concat_func_single_point( + placeholder: T.Buffer((28, 64), "int8"), + placeholder_1: T.Buffer((28, 1), "int8"), + placeholder_2: T.Buffer((28, 63), "int8"), + T_concat: T.Buffer((28, 128), "int8"), +) -> None: + for i0 in range(28): + for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}): + if i1 > 63: + T_concat[i0, i1] = placeholder[i0, i1 - 64] + elif i1 == 63: + T_concat[i0, i1] = placeholder_1[i0, i1 - 63] + else: + T_concat[i0, i1] = placeholder_2[i0, i1] + + +@T.prim_func +def expected_partitioned_concat_single_point( + placeholder: T.Buffer((28, 64), "int8"), + placeholder_1: T.Buffer((28, 1), "int8"), + placeholder_2: T.Buffer((28, 63), "int8"), + T_concat: T.Buffer((28, 128), "int8"), +): + for i0 in range(28): + T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data) + for i1 in range(63): + placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data) + T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 63 + i1] + placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data) + T_concat_1[i0 * 128 + 63] = placeholder_1_1[i0] + for i1 in range(64): + placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data) + T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1] + + +@T.prim_func +def concat_func_start_point_equality( + placeholder: T.Buffer((28, 64), "int8"), + placeholder_1: T.Buffer((28, 1), "int8"), + placeholder_2: T.Buffer((28, 63), "int8"), + T_concat: T.Buffer((28, 128), "int8"), +) -> None: + for i0 in range(28): + for i1 in range(128, annotations={"pragma_loop_partition_hint": 1}): + if i1 == 0: + # Special case for i1 == 0 + T_concat[i0, i1] = placeholder_1[i0, 0] + elif i1 < 64: + # Normal case for i1 in [1, 63] + T_concat[i0, i1] = placeholder_2[i0, i1] + else: + # Case for i1 in [64, 127] + T_concat[i0, i1] = placeholder[i0, i1 - 64] + + +@T.prim_func +def concat_func_start_point_equality_expected( + placeholder: T.Buffer((28, 64), "int8"), + placeholder_1: T.Buffer((28, 1), "int8"), + placeholder_2: T.Buffer((28, 63), "int8"), + T_concat: T.Buffer((28, 128), "int8"), +): + for i0 in range(28): + T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data) + placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data) + T_concat_1[i0 * 128] = placeholder_1_1[i0] + for i1 in range(63): + placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data) + T_concat_1[i0 * 128 + i1 + 1] = placeholder_2_1[i0 * 63 + i1 + 1] + for i1 in range(64): + placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data) + T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1] + + +@T.prim_func +def concat_func_end_point_equality( + placeholder: T.Buffer((28, 64), "int8"), + placeholder_1: T.Buffer((28, 1), "int8"), + placeholder_2: T.Buffer((28, 63), "int8"), + T_concat: T.Buffer((28, 128), "int8"), +) -> None: + for i0 in range(28): + for i1 in range(128, annotations={"pragma_loop_partition_hint": 1}): + if i1 == 127: + # Explicit equality check for the end point i1 == 127 + T_concat[i0, i1] = placeholder_1[i0, 0] + elif i1 >= 64: + # Case for i1 in [64, 126] + T_concat[i0, i1] = placeholder[i0, i1 - 64] + else: + # Case for i1 in [0, 63] + T_concat[i0, i1] = placeholder_2[i0, i1] + + +@T.prim_func +def concat_func_end_point_equality_expected( + placeholder: T.Buffer((28, 64), "int8"), + placeholder_1: T.Buffer((28, 1), "int8"), + placeholder_2: T.Buffer((28, 63), "int8"), + T_concat: T.Buffer((28, 128), "int8"), +): + for i0 in range(28): + T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data) + for i1 in range(64): + placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data) + T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 63 + i1] + for i1 in range(63): + placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data) + T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1] + placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data) + T_concat_1[i0 * 128 + 127] = placeholder_1_1[i0] + + +@T.prim_func +def concat_func_edge_equalities( + placeholder: T.Buffer((28, 64), "int8"), + placeholder_1: T.Buffer((28, 1), "int8"), + placeholder_2: T.Buffer((28, 1), "int8"), + T_concat: T.Buffer((28, 66), "int8"), +) -> None: + for i0 in range(28): + for i1 in range( + 66, annotations={"pragma_loop_partition_hint": 1} + ): # Loop from 0 to 65 inclusive + if i1 == 0: + # Handle equality at the start of the range: i1 == 0 + T_concat[i0, i1] = placeholder_2[i0, 0] + elif i1 == 65: + # Handle equality at the end of the range: i1 == 65 + T_concat[i0, i1] = placeholder_1[i0, 0] + else: + # Copying from placeholder (from 0 to 63) + T_concat[i0, i1] = placeholder[i0, i1 - 1] + + +@T.prim_func +def concat_func_edge_equalities_expected( + placeholder: T.Buffer((28, 64), "int8"), + placeholder_1: T.Buffer((28, 1), "int8"), + placeholder_2: T.Buffer((28, 1), "int8"), + T_concat: T.Buffer((28, 66), "int8"), +): + for i0 in range(28): + T_concat_1 = T.Buffer((1848,), "int8", data=T_concat.data) + placeholder_2_1 = T.Buffer((28,), "int8", data=placeholder_2.data) + T_concat_1[i0 * 66] = placeholder_2_1[i0] + for i1 in range(64): + placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data) + T_concat_1[i0 * 66 + i1 + 1] = placeholder_3[i0 * 64 + i1] + placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data) + T_concat_1[i0 * 66 + 65] = placeholder_1_1[i0] + + +@T.prim_func +def concat_five_buffers_with_equalities( + buffer_a: T.Buffer((28, 1), "int8"), # Used for i1 == 0 + buffer_b: T.Buffer((28, 63), "int8"), # Fills i1 from 1 to 63 + buffer_c: T.Buffer((28, 1), "int8"), # Used for i1 == 64 + buffer_d: T.Buffer((28, 63), "int8"), # Fills i1 from 65 to 128 + buffer_e: T.Buffer((28, 1), "int8"), # Used for i1 == 129 + T_concat: T.Buffer((28, 129), "int8"), +) -> None: + for i0 in range(28): + for i1 in range(130, annotations={"pragma_loop_partition_hint": 1}): + if i1 == 0: + T_concat[i0, i1] = buffer_a[i0, 0] + elif i1 == 64: + T_concat[i0, i1] = buffer_c[i0, 0] + elif i1 == 129: + T_concat[i0, i1] = buffer_e[i0, 0] + elif i1 < 64: + T_concat[i0, i1] = buffer_b[i0, i1 - 1] + else: # i1 > 64 and i1 < 128 + T_concat[i0, i1] = buffer_d[i0, i1 - 65] + + +@T.prim_func +def concat_five_buffers_with_equalities_expected( + buffer_a: T.Buffer((28, 1), "int8"), # Used for i1 == 0 + buffer_b: T.Buffer((28, 63), "int8"), # Fills i1 from 1 to 63 + buffer_c: T.Buffer((28, 1), "int8"), # Used for i1 == 64 + buffer_d: T.Buffer((28, 63), "int8"), # Fills i1 from 65 to 128 + buffer_e: T.Buffer((28, 1), "int8"), # Used for i1 == 129 + T_concat: T.Buffer((28, 129), "int8"), +): + for i0 in range(28): + T_concat_1 = T.Buffer((3612,), "int8", data=T_concat.data) + buffer_a_1 = T.Buffer((28,), "int8", data=buffer_a.data) + T_concat_1[i0 * 129] = buffer_a_1[i0] + for i1 in range(63): + buffer_b_1 = T.Buffer((1764,), "int8", data=buffer_b.data) + T_concat_1[i0 * 129 + i1 + 1] = buffer_b_1[i0 * 63 + i1] + buffer_c_1 = T.Buffer((28,), "int8", data=buffer_c.data) + T_concat_1[i0 * 129 + 64] = buffer_c_1[i0] + for i1 in range(64): + buffer_d_1 = T.Buffer((1764,), "int8", data=buffer_d.data) + T_concat_1[i0 * 129 + i1 + 65] = buffer_d_1[i0 * 63 + i1] + buffer_e_1 = T.Buffer((28,), "int8", data=buffer_e.data) + T_concat_1[i0 * 129 + 129] = buffer_e_1[i0] + + +@pytest.mark.parametrize( + "origin,expected", + [ + (concat_func_single_point, expected_partitioned_concat_single_point), + (concat_func_start_point_equality, concat_func_start_point_equality_expected), + (concat_func_end_point_equality, concat_func_end_point_equality_expected), + (concat_func_edge_equalities, concat_func_edge_equalities_expected), + (concat_five_buffers_with_equalities, concat_five_buffers_with_equalities_expected), + ], +) +def test_single_point_partition(origin, expected): + origin = origin.with_attr({"global_symbol": "main"}) + expected = expected.with_attr({"global_symbol": "main"}) + mod = partition_from_scheduled_tir( + origin, + { + "tir.LoopPartition": { + "partition_const_loop": True, + "unroll_loop_with_partition_hint_no_interval": True, + } + }, + ) + assert tvm.ir.structural_equal(mod["main"], expected) + + if __name__ == "__main__": tvm.testing.main()