diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py index 5b16be16fca7..b41bdfec06a7 100644 --- a/python/tvm/meta_schedule/postproc/__init__.py +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -16,5 +16,6 @@ # under the License. """The tvm.meta_schedule.postproc package.""" from .postproc import Postproc, PyPostproc +from .disallow_dynamic_loop import DisallowDynamicLoop from .rewrite_reduction_block import RewriteReductionBlock from .verify_gpu_code import VerifyGPUCode diff --git a/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py new file mode 100644 index 000000000000..5515d288e0e7 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that checks if the IRModule has any loop with non-constant extent""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.DisallowDynamicLoop") +class DisallowDynamicLoop(Postproc): + """A postprocessor that checks if the IRModule has any loop with non-constant extent""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocDisallowDynamicLoop, # type: ignore # pylint: disable=no-member + ) diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc new file mode 100644 index 000000000000..85a81f10fdcd --- /dev/null +++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief Check if an IRModule has any dynamic loop. */ +struct DynamicExtentFinder : private StmtVisitor { + public: + static bool Find(const IRModule& mod) { + DynamicExtentFinder finder; + for (const auto& kv : mod->functions) { + const BaseFunc& func = kv.second; + if (const auto* prim_func = func.as()) { + finder(prim_func->body); + if (finder.found_) { + return true; + } + } + } + return false; + } + + private: + void VisitStmt_(const ForNode* loop) final { + if (!loop->extent->IsInstance()) { + found_ = true; + } else { + StmtVisitor::VisitStmt_(loop); + } + } + + void VisitStmt(const Stmt& stmt) final { + if (!found_) { + StmtVisitor::VisitStmt(stmt); + } + } + + bool found_ = false; +}; + +} // namespace tir + +namespace meta_schedule { + +/*! \brief Check if the IRModule has any loop with non-constant extent. */ +class DisallowDynamicLoopNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); } + + static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop"; + TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode); +}; + +Postproc Postproc::DisallowDynamicLoop() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(DisallowDynamicLoopNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop") + .set_body_typed(Postproc::DisallowDynamicLoop); + +} // namespace meta_schedule +} // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py b/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py new file mode 100644 index 000000000000..d27e3e61084f --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import DisallowDynamicLoop +from tvm.script import tir as T +from tvm.target import Target + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + DisallowDynamicLoop(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class DynamicLoop: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j in T.grid(1024, 1024): + for k in T.serial(0, i): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_postproc_disallow_dynamic_loops(): + mod = Matmul + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert ctx.postprocs[0].apply(sch) + + +def test_postproc_disallow_dynamic_loops_fail(): + mod = DynamicLoop + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +if __name__ == "__main__": + test_postproc_disallow_dynamic_loops() + test_postproc_disallow_dynamic_loops_fail()