diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 2008fe5e47b8..d6af9936ca40 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -167,7 +167,7 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } -Array CreatePassList(bool disable_loop_partition, bool for_te_schedule) { +Array CreatePassList(bool disable_loop_partition) { transform::PassContext pass_ctx = transform::PassContext::Current(); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); @@ -214,17 +214,14 @@ Array CreatePassList(bool disable_loop_partition, bool for Array pass_list = user_lower_phase0; // PHASE 1 - if (for_te_schedule) { - pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); - } else { - pass_list.push_back(tir::transform::LowerInitBlock()); - pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); - pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); - pass_list.push_back(tir::transform::CompactBufferAllocation()); - pass_list.push_back(tir::transform::LowerMatchBuffer()); - pass_list.push_back(tir::transform::FlattenBuffer()); - } + pass_list.push_back(tir::transform::InjectPrefetch()); + pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + pass_list.push_back(tir::transform::LowerInitBlock()); + pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::LowerMatchBuffer()); + pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -288,6 +285,10 @@ IRModule ScheduleToModule(te::Schedule sch, const Array& args, const tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + // Mark this schedule as being converted from an TE schedule. Makes sure that + // the correct TE passes are run. + f = WithAttr(std::move(f), "from_legacy_te_schedule", Bool(true)); + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); if (noalias) { @@ -311,7 +312,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") }); IRModule LowerModule(IRModule mod, bool simple_mode) { - Array pass_list = CreatePassList(simple_mode, false); + Array pass_list = CreatePassList(simple_mode); return LowerWithPassList(std::move(mod), pass_list); } @@ -331,7 +332,7 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_ IRModule mod = IRModule(Map({{GlobalVar(name), f}})); // Get the pass list - Array pass_list = CreatePassList(simple_mode, false); + Array pass_list = CreatePassList(simple_mode); return LowerWithPassList(std::move(mod), pass_list); } @@ -353,7 +354,7 @@ IRModule LowerSchedule(te::Schedule sch, const Array& args, const std const std::unordered_map& binds, bool simple_mode) { IRModule mod = ScheduleToModule(std::move(sch), args, name, binds); // Get the legacy TE pass list - Array pass_list = CreatePassList(simple_mode, true); + Array pass_list = CreatePassList(simple_mode); return LowerWithPassList(mod, pass_list); } diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 2063fc7cad6a..439d0ff17255 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -170,7 +170,9 @@ PrimFunc SchedulePostProcToPrimFunc(Array arg_list, Stmt body, } body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body)); - return tir::PrimFunc(params, body, VoidType(), buffer_map); + // We mark this PrimFunc as coming from a TE schedule + return WithAttr(tir::PrimFunc(params, body, VoidType(), buffer_map), "from_legacy_te_schedule", + Bool(true)); } TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc") diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index bd1fa9bce836..b1a4fd45ef0d 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -32,6 +32,7 @@ #include "../../support/arena.h" #include "../../support/utils.h" #include "../schedule/utils.h" +#include "ir_utils.h" namespace tvm { namespace tir { @@ -452,11 +453,16 @@ class BufferCompactor : public StmtExprMutator { }; PrimFunc CompactBufferAllocation(PrimFunc f) { - PrimFuncNode* fptr = f.CopyOnWrite(); - std::unordered_map region = - BufferAccessRegionCollector::Collect(f); - fptr->body = BufferCompactor::Compact(f, region); - return f; + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + std::unordered_map region = + BufferAccessRegionCollector::Collect(f); + fptr->body = BufferCompactor::Compact(f, region); + return f; + } else { + return f; + } } namespace transform { diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc index 4c5e1dd5125b..f7629d100645 100644 --- a/src/tir/transforms/convert_blocks_to_opaque.cc +++ b/src/tir/transforms/convert_blocks_to_opaque.cc @@ -25,6 +25,8 @@ #include #include +#include "ir_utils.h" + namespace tvm { namespace tir { @@ -83,9 +85,14 @@ class OpaqueBlockConverter : public StmtExprMutator { }; PrimFunc ConvertBlocksToOpaque(PrimFunc f) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = OpaqueBlockConverter::Substitute(f); - return f; + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = OpaqueBlockConverter::Substitute(f); + return f; + } else { + return f; + } } namespace transform { diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index f1f914fa2f5c..85c412346056 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -28,6 +28,7 @@ #include #include "../../support/utils.h" +#include "ir_utils.h" namespace tvm { namespace tir { @@ -151,9 +152,14 @@ class BufferFlattener : public StmtExprMutator { }; PrimFunc FlattenBuffer(PrimFunc f) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = BufferFlattener::Flatten(f); - return f; + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = BufferFlattener::Flatten(f); + return f; + } else { + return f; + } } namespace transform { diff --git a/src/tir/transforms/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc index 4ce9c7639b77..f20577e3a01b 100644 --- a/src/tir/transforms/inject_prefetch.cc +++ b/src/tir/transforms/inject_prefetch.cc @@ -31,6 +31,8 @@ #include +#include "ir_utils.h" + namespace tvm { namespace tir { @@ -96,9 +98,14 @@ namespace transform { Pass InjectPrefetch() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - n->body = PrefetchInjector()(std::move(n->body)); - return f; + // Only apply this pass to TIR from TE schedules + if (IsFromLegacyTESchedule(f)) { + auto* n = f.CopyOnWrite(); + n->body = PrefetchInjector()(std::move(n->body)); + return f; + } else { + return f; + } }; return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {}); } diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 7248bd4e663f..a41905c148bf 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -244,5 +244,10 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region return result; } +Bool IsFromLegacyTESchedule(PrimFunc f) { + Optional from_legacy_te_schedule = f->GetAttr("from_legacy_te_schedule", Bool(false)); + return from_legacy_te_schedule.value(); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 79c5f0609243..9be18b790b41 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -213,6 +214,16 @@ Array ConvertIndices(const MatchBufferRegion& match_buffer, */ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region); +/*! + * \brief Check if a given PrimFunc originated from a TE schedule. + * + * Internally this checks for the `from_legacy_te_schedule` attr of the PrimFunc. + * + * \param f PrimFunc to check + * \return Whether or not the PrimFunc was created from a te schedule + */ +Bool IsFromLegacyTESchedule(PrimFunc f); + } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc index c8aca5195085..d8621ac3b3e6 100644 --- a/src/tir/transforms/lower_init_block.cc +++ b/src/tir/transforms/lower_init_block.cc @@ -25,6 +25,8 @@ #include #include +#include "ir_utils.h" + namespace tvm { namespace tir { @@ -63,9 +65,14 @@ class InitBlockLower : public StmtMutator { }; PrimFunc LowerInitBlock(PrimFunc func) { - auto fptr = func.CopyOnWrite(); - fptr->body = InitBlockLower()(std::move(fptr->body)); - return func; + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(func)) { + auto fptr = func.CopyOnWrite(); + fptr->body = InitBlockLower()(std::move(fptr->body)); + return func; + } else { + return func; + } } namespace transform { diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 949c955b2dfe..bee11ad72280 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -26,6 +26,8 @@ #include #include +#include "ir_utils.h" + namespace tvm { namespace tir { @@ -145,10 +147,15 @@ class BufferAllocationLocator : public StmtExprMutator { }; PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) { - auto fptr = func.CopyOnWrite(); - BufferAllocationLocator locator(func); - fptr->body = locator(fptr->body); - return func; + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(func)) { + auto fptr = func.CopyOnWrite(); + BufferAllocationLocator locator(func); + fptr->body = locator(fptr->body); + return func; + } else { + return func; + } } namespace transform { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 38b3a77b1a0c..2c32cc7f0883 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -500,13 +500,19 @@ class StorageFlattener : public StmtExprMutator { }; PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { - auto fptr = func.CopyOnWrite(); - - IRVisitorWithAnalyzer bound_analyzer; - bound_analyzer(fptr->body); - fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, - &bound_analyzer)(std::move(fptr->body)); - return func; + // Only apply this pass to TIR from TE schedules + Optional from_legacy_te_schedule = func->GetAttr("from_legacy_te_schedule", Bool(false)); + if (from_legacy_te_schedule.value()) { + auto fptr = func.CopyOnWrite(); + + IRVisitorWithAnalyzer bound_analyzer; + bound_analyzer(fptr->body); + fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, + &bound_analyzer)(std::move(fptr->body)); + return func; + } else { + return func; + } } namespace transform { diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index 4505a7bed244..e5528a8c4756 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -50,6 +50,25 @@ def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: @tvm.script.tir class LoweredModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + # function attr dict + tir.func_attr( + {"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True} + ) + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + # body + for x, y in tir.grid(128, 128): + C.data[x * 128 + y] = 0.0 + for k in tir.serial(0, 128): + C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y) + tir.load( + "float32", A.data, x * 128 + k + ) * tir.load("float32", B.data, y * 128 + k) + + +@tvm.script.tir +class LoweredTIRModule: def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # function attr dict tir.func_attr({"global_symbol": "main", "tir.noalias": True}) @@ -83,7 +102,7 @@ def test_lower_build_te_schedule(): def test_lower_build_tir_func(): # check lowering ir_mod = tvm.lower(matmul) - tvm.ir.assert_structural_equal(ir_mod, LoweredModule()) + tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule()) # check building mod = tvm.build(matmul, target="llvm") _check_module_with_numpy(mod) @@ -95,7 +114,7 @@ def test_lower_build_tir_module(): ir_mod = IRModule({"main": func}) # check lowering lowered_mod = tvm.lower(ir_mod) - tvm.ir.assert_structural_equal(lowered_mod, LoweredModule()) + tvm.ir.assert_structural_equal(lowered_mod, LoweredTIRModule()) # check building mod = tvm.build(ir_mod, target="llvm") _check_module_with_numpy(mod) @@ -103,8 +122,8 @@ def test_lower_build_tir_module(): def test_lower_build_lowered_module(): # check lowering - ir_mod = tvm.lower(LoweredModule()) - tvm.ir.assert_structural_equal(ir_mod, LoweredModule()) + ir_mod = tvm.lower(LoweredTIRModule()) + tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule()) # check building mod = tvm.build(ir_mod, target="llvm") _check_module_with_numpy(mod) diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index a469c6d0cc13..fb53b420f4ce 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir +from tvm import tir, te from tvm.script import ty @@ -371,6 +371,15 @@ def test_match_buffer(): _check(match_buffer_func, compacted_match_buffer_func) +def test_lower_te(): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + mod = tvm.tir.transform.CompactBufferAllocation()(orig_mod) + tvm.ir.assert_structural_equal(mod, orig_mod) # CompactBufferAllocation should do nothing on TE + + if __name__ == "__main__": test_elementwise() test_unschedulable_block() diff --git a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py index 38fe1c967456..708f1af0c064 100644 --- a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir +from tvm import tir, te from tvm.script import ty @@ -73,5 +73,14 @@ def test_elementwise(): _check(elementwise_func, substituted_elementwise_func) +def test_lower_te(): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + mod = tvm.tir.transform.ConvertBlocksToOpaque()(orig_mod) + tvm.ir.assert_structural_equal(mod, orig_mod) # ConvertBlocksToOpaque should do nothing on TE + + if __name__ == "__main__": test_elementwise() diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 6929a329ac0f..3b2b3cf2f55b 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir +from tvm import tir, te from tvm.script import ty @@ -234,6 +234,15 @@ def test_multi_alloc(): _check(compacted_multi_alloc_func, flattened_multi_alloc_func) +def test_lower_te(): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + mod = tvm.tir.transform.FlattenBuffer()(orig_mod) + tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do nothing on TE + + if __name__ == "__main__": test_elementwise() test_gpu_workload() diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py index badf5e0e4d10..8499c9334e46 100644 --- a/tests/python/unittest/test_tir_transform_lower_init_block.py +++ b/tests/python/unittest/test_tir_transform_lower_init_block.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir +from tvm import tir, te from tvm.script import ty # pylint: disable=no-self-argument @@ -85,6 +85,15 @@ def test_lower_match_buffer(): tvm.ir.assert_structural_equal(mod, BranchWithMatchBuffer(), True) +def test_lower_te(): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + mod = tvm.tir.transform.LowerInitBlock()(orig_mod) + tvm.ir.assert_structural_equal(mod, orig_mod) # LowerInitBlock should do nothing on TE + + if __name__ == "__main__": test_lower_reduction() test_lower_match_buffer() diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index 022c964df0c7..72a2f5ebc240 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir +from tvm import tir, te from tvm.script import ty @@ -149,6 +149,17 @@ def test_match_buffer_allocation(): _check(match_buffer_func, transformed_match_buffer_func) +def test_lower_te(): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(orig_mod) + tvm.ir.assert_structural_equal( + mod, orig_mod + ) # PlanAndUpdateBufferAllocationLocation should do nothing on TE + + if __name__ == "__main__": test_elementwise() test_locate_buffer_allocation() diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 0e9ab862a9c8..1dd4a4852938 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -16,6 +16,8 @@ # under the License. import tvm from tvm import te +from tvm.script import ty +from tvm.relay import GlobalVar def test_flatten2(): @@ -102,7 +104,9 @@ def test_flatten_double_buffer(): stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt)) + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A, C], stmt).with_attr("from_legacy_te_schedule", True) + ) with tvm.transform.PassContext(config={"tir.InjectDoubleBuffer": {"split_loop": 2}}): mod = tvm.transform.Sequential( @@ -130,6 +134,21 @@ def count_sync(op): assert count[0] == 4 +@tvm.script.tir +def tir_func(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [2, 2]) + B = tir.match_buffer(a, [2, 2]) + A[0, 1] = B[1, 1] + + +def test_flatten_tir(): + orig_mod = tvm.IRModule({GlobalVar("main"): tir_func}) + mod = tvm.tir.transform.StorageFlatten(64)(orig_mod) + tvm.ir.assert_structural_equal( + orig_mod, mod + ) # StorageFlatten should do nothing to TIR functions + + if __name__ == "__main__": test_flatten2() test_flatten_storage_align()