diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index fe18074995..6909e936cc 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -32,6 +32,7 @@ from tvm.meta_schedule.space_generator import PostOrderApply from tvm.meta_schedule.schedule_rule import PyScheduleRule from tvm.meta_schedule.utils import _get_hex_address +from tvm.tir.schedule import trace from tvm.tir.schedule.trace import Trace @@ -274,6 +275,26 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: sch.compute_inline(block) return [sch] + def correct_trace(a, b, c, d): + return "\n".join( + [ + 'b0 = sch.get_block(name="A", func_name="main")', + 'b1 = sch.get_block(name="B", func_name="main")', + 'b2 = sch.get_block(name="C", func_name="main")', + "sch.compute_inline(block=b1)", + 'b3 = sch.get_block(name="A", func_name="main")', + 'b4 = sch.get_block(name="C", func_name="main")', + "l5, l6 = sch.get_loops(block=b4)", + "l7, l8 = sch.split(loop=l5, factors=" + str(a) + ")", + "l9, l10 = sch.split(loop=l6, factors=" + str(b) + ")", + "sch.reorder(l7, l9, l8, l10)", + "l11, l12 = sch.get_loops(block=b3)", + "l13, l14 = sch.split(loop=l11, factors=" + str(c) + ")", + "l15, l16 = sch.split(loop=l12, factors=" + str(d) + ")", + "sch.reorder(l13, l15, l14, l16)", + ] + ) + mod = TrinityMatmul context = TuneContext( mod=mod, @@ -291,6 +312,12 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: match="ScheduleError: An error occurred in the schedule primitive 'get-block'.", ): sch.get_block("B", "main") + assert ( + str(sch.trace) == correct_trace([16, 64], [64, 16], [2, 512], [2, 512]) + or str(sch.trace) == correct_trace([2, 512], [2, 512], [2, 512], [2, 512]) + or str(sch.trace) == correct_trace([16, 64], [64, 16], [16, 64], [64, 16]) + or str(sch.trace) == correct_trace([2, 512], [2, 512], [16, 64], [64, 16]) + ) if __name__ == "__main__":