Skip to content

Commit

Permalink
[M3c][Meta Schedule] Add Trace Correctness Test for PostOrderApply (#492
Browse files Browse the repository at this point in the history
)

* Add schedule rule c++ side.

* Add postproc c++ side.

* Add schedule rule python side.

* Add mutator python side.

* Add Postproc python side.

* Put src to different files in python.

* Try to fix bprint.

* Get str function work.

* Add PostOrderApply c++ side.

* Remove IRModule from SpaceGenerator's generate function signature.

* Refactor PostOrderApply.

* Fix PostOrderApply Design.

* Change StmtSRef to BlockRV.

* Add IRModule back to SpaceGenerator.

* Add trace  correctness test for post order apply.
  • Loading branch information
zxybazh authored and junrushao committed Nov 5, 2021
1 parent b36568d commit f151041
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions tests/python/unittest/test_meta_schedule_post_order_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from tvm.meta_schedule.space_generator import PostOrderApply
from tvm.meta_schedule.schedule_rule import PyScheduleRule
from tvm.meta_schedule.utils import _get_hex_address
from tvm.tir.schedule import trace
from tvm.tir.schedule.trace import Trace


Expand Down Expand Up @@ -274,6 +275,26 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
sch.compute_inline(block)
return [sch]

def correct_trace(a, b, c, d):
return "\n".join(
[
'b0 = sch.get_block(name="A", func_name="main")',
'b1 = sch.get_block(name="B", func_name="main")',
'b2 = sch.get_block(name="C", func_name="main")',
"sch.compute_inline(block=b1)",
'b3 = sch.get_block(name="A", func_name="main")',
'b4 = sch.get_block(name="C", func_name="main")',
"l5, l6 = sch.get_loops(block=b4)",
"l7, l8 = sch.split(loop=l5, factors=" + str(a) + ")",
"l9, l10 = sch.split(loop=l6, factors=" + str(b) + ")",
"sch.reorder(l7, l9, l8, l10)",
"l11, l12 = sch.get_loops(block=b3)",
"l13, l14 = sch.split(loop=l11, factors=" + str(c) + ")",
"l15, l16 = sch.split(loop=l12, factors=" + str(d) + ")",
"sch.reorder(l13, l15, l14, l16)",
]
)

mod = TrinityMatmul
context = TuneContext(
mod=mod,
Expand All @@ -291,6 +312,12 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
match="ScheduleError: An error occurred in the schedule primitive 'get-block'.",
):
sch.get_block("B", "main")
assert (
str(sch.trace) == correct_trace([16, 64], [64, 16], [2, 512], [2, 512])
or str(sch.trace) == correct_trace([2, 512], [2, 512], [2, 512], [2, 512])
or str(sch.trace) == correct_trace([16, 64], [64, 16], [16, 64], [64, 16])
or str(sch.trace) == correct_trace([2, 512], [2, 512], [16, 64], [64, 16])
)


if __name__ == "__main__":
Expand Down

0 comments on commit f151041

Please sign in to comment.