From 616f8572c6161f9c92402ada8d531ad18d59a93d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 19 Dec 2021 00:07:11 -0800 Subject: [PATCH] [M3c][MetaScheduler] Add ScheduleRule class & PostOrderApply space generator. (#9761) * Add ScheduleRule class & PostOrderApply space generator. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng * Fix comments & docs. * Fix for mypy. * Retrigger CI. * remove get_hex_address Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- include/tvm/meta_schedule/schedule_rule.h | 195 ++++++++++ include/tvm/meta_schedule/search_strategy.h | 6 +- include/tvm/meta_schedule/space_generator.h | 16 +- include/tvm/meta_schedule/tune_context.h | 5 + include/tvm/tir/schedule/schedule.h | 6 + python/tvm/meta_schedule/__init__.py | 1 + .../meta_schedule/schedule_rule/__init__.py | 19 + .../schedule_rule/schedule_rule.py | 96 +++++ .../meta_schedule/space_generator/__init__.py | 2 +- .../space_generator/post_order_apply.py | 36 ++ .../space_generator/space_generator.py | 5 +- python/tvm/meta_schedule/tune_context.py | 27 +- .../schedule_rule/schedule_rule.cc | 55 +++ .../search_strategy/replay_trace.cc | 8 +- .../space_generator/post_order_apply.cc | 155 ++++++++ .../space_generator/space_generator_union.cc | 4 +- src/meta_schedule/tune_context.cc | 7 +- src/meta_schedule/utils.h | 4 + src/tir/schedule/concrete_schedule.h | 14 + .../test_meta_schedule_post_order_apply.py | 342 ++++++++++++++++++ 20 files changed, 961 insertions(+), 42 deletions(-) create mode 100644 include/tvm/meta_schedule/schedule_rule.h create mode 100644 python/tvm/meta_schedule/schedule_rule/__init__.py create mode 100644 python/tvm/meta_schedule/schedule_rule/schedule_rule.py create mode 100644 python/tvm/meta_schedule/space_generator/post_order_apply.py create mode 100644 src/meta_schedule/schedule_rule/schedule_rule.cc create mode 100644 src/meta_schedule/space_generator/post_order_apply.cc create mode 100644 tests/python/unittest/test_meta_schedule_post_order_apply.py diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h new file mode 100644 index 000000000000..8313da067f09 --- /dev/null +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_ +#define TVM_META_SCHEDULE_SCHEDULE_RULE_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Rules to modify a block in a schedule. */ +class ScheduleRuleNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~ScheduleRuleNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Initialize the design space generator with tuning context. + * \param context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply a schedule rule to the specific block in the given schedule. + * \param sch The schedule to be modified. + * \param block The specific block to apply the schedule rule. + * \return The list of schedules generated by applying the schedule rule. + */ + virtual runtime::Array Apply(const tir::Schedule& sch, + const tir::BlockRV& block) = 0; + + static constexpr const char* _type_key = "meta_schedule.ScheduleRule"; + TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object); +}; + +/*! \brief The schedule rule with customized methods on the python-side. */ +class PyScheduleRuleNode : public ScheduleRuleNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief The function type of `Apply` method. + * \param sch The schedule to be modified. + * \param block The specific block to apply the schedule rule. + * \return The list of schedules generated by applying the schedule rule. + */ + using FApply = + runtime::TypedPackedFunc(const tir::Schedule&, const tir::BlockRV&)>; + /*! + * \brief Get the schedule rule as string with name. + * \return The string of the schedule rule. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyScheduleRule's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final { + ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; + return this->f_apply(sch, block); + } + + static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode); +}; + +/*! + * \brief Managed reference to ScheduleRuleNode + * \sa ScheduleRuleNode + */ +class ScheduleRule : public runtime::ObjectRef { + public: + /*! 
+ * \brief Create an auto-inline rule that inlines spatial blocks if they satisfy some conditions
+ * \param into_producer Whether to allow inlining a block into its producer
+ * \param into_consumer Whether to allow inlining a block into its consumer
+ * \param into_cache_only Whether to only allow inlining into a block generated by cache_read/write
+ * \param inline_const_tensor Always inline constant tensors
+ * \param disallow_if_then_else Always disallow if-then-else-like constructs
+ * \param require_injective Always require the read-to-write mapping to be injective
+ * \param require_ordered Always require the read-to-write mapping to be ordered
+ * \param disallow_op The operators that are disallowed in auto inline
+ * \return The schedule rule created
+ */
+ TVM_DLL static ScheduleRule AutoInline(bool into_producer, //
+ bool into_consumer, //
+ bool into_cache_only, //
+ bool inline_const_tensor, //
+ bool disallow_if_then_else, //
+ bool require_injective, //
+ bool require_ordered, //
+ Optional<Array<String>> disallow_op);
+ /*!
+ * \brief Create a mega rule: multi-level tiling with data reuse
+ * \param structure The tiling structure. Recommended:
+ * - 'SSRSRS' on CPU
+ * - 'SSSRRSRS' on GPU
+ * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
+ * - NullOpt on CPU
+ * - [blockIdx.x, vthread.x, threadIdx.x] on GPU
+ * \param use_tensor_core Whether to apply the tensor core wmma intrinsic for the computation
+ * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
+ * \param vector_load_max_len The length of the vector lane in vectorized cooperative fetching.
+ * NullOpt means vectorization is disabled
+ * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
+ * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
+ * \return The schedule rule created
+ */
+ TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
+ Optional<Array<String>> tile_binds, //
+ bool use_tensor_core, //
+ Optional<Integer> max_innermost_factor, //
+ Optional<Integer> vector_load_max_len, //
+ Optional<Map<String, ObjectRef>> reuse_read, //
+ Optional<Map<String, ObjectRef>> reuse_write);
+ /*!
+ * \brief Create a rule that randomly selects a compute-at location for a free block
+ * \return The rule created
+ */
+ TVM_DLL static ScheduleRule RandomComputeLocation();
+ /*!
+ * \brief Create a rule that marks parallelization, vectorization and unrolling for each block accordingly
+ * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
+ * upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
+ * parallelism.
+ * \param max_vectorize_extent The maximum extent to be vectorized.
+ * It sets the upper limit of CPU vectorization. Use -1 to disable vectorization.
+ * \param unroll_max_steps The maximum number of unroll steps to be done.
+ * Use an empty array to disable unrolling.
+ * \param unroll_explicit Whether to explicitly unroll the loop, or just add an unroll pragma.
+ * \return The schedule rule created
+ */
+ TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
+ int max_vectorize_extent, //
+ Array<Integer> unroll_max_steps, //
+ bool unroll_explicit);
+ /*!
+ * \brief Create a schedule rule with customized methods on the Python side.
+ * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
+ * \param f_apply The packed function of `Apply`.
+ * \param f_as_string The packed function of `AsString`.
+ * \return The schedule rule created.
+ */ + TVM_DLL static ScheduleRule PyScheduleRule( + PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyScheduleRuleNode::FApply f_apply, // + PyScheduleRuleNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_H_ diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 0f3e9298d11a..e1c68c8a1a11 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -104,10 +104,10 @@ class SearchStrategyNode : public runtime::Object { /*! * \brief Initialize the search strategy with tuning context. - * \param tune_context The tuning context for initialization. + * \param context The tuning context for initialization. * \note This method is supposed to be called only once before every other method. */ - virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0; + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; /*! * \brief Pre-tuning for the search strategy. @@ -146,7 +146,7 @@ class PySearchStrategyNode : public SearchStrategyNode { public: /*! * \brief The function type of `InitializeWithTuneContext` method. - * \param tune_context The tuning context for initialization. + * \param context The tuning context for initialization. */ using FInitializeWithTuneContext = runtime::TypedPackedFunc; /*! diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index f314af62b6bd..7aff6839dc55 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -71,10 +71,10 @@ class SpaceGeneratorNode : public Object { /*! * \brief Initialize the design space generator with tuning context. - * \param tune_context The tuning context for initialization. + * \param context The tuning context for initialization. * \note This method is supposed to be called only once before every other method. */ - virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0; + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; /*! * \brief Generate design spaces given a module. @@ -92,7 +92,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { public: /*! * \brief The function type of `InitializeWithTuneContext` method. - * \param tune_context The tuning context for initialization. + * \param context The tuning context for initialization. */ using FInitializeWithTuneContext = runtime::TypedPackedFunc; /*! @@ -112,10 +112,10 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { // `f_generate_design_space` is not visited } - void InitializeWithTuneContext(const TuneContext& tune_context) final { + void InitializeWithTuneContext(const TuneContext& context) final { ICHECK(f_initialize_with_tune_context != nullptr) << "PySpaceGenerator's InitializeWithTuneContext !"; - f_initialize_with_tune_context(tune_context); + f_initialize_with_tune_context(context); } Array GenerateDesignSpace(const IRModule& mod) final { @@ -153,6 +153,12 @@ class SpaceGenerator : public ObjectRef { * \return The design space generator created. */ TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array space_generators); + /*! + * \brief Create a design space generator that generates design spaces by applying schedule rules + * to blocks in post-DFS order. 
+ * \return The design space generator created. + */ + TVM_DLL static SpaceGenerator PostOrderApply(); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); }; diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index db72328c91c3..559f2da7f9d9 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -38,6 +38,8 @@ class TuneContextNode : public runtime::Object { Optional space_generator; /*! \brief The search strategy. */ Optional search_strategy; + /*! \brief The schedule rules. */ + Array sch_rules; /*! \brief The name of the tuning task. */ Optional task_name; /*! \brief The random state. */ @@ -57,6 +59,7 @@ class TuneContextNode : public runtime::Object { v->Visit("target", &target); v->Visit("space_generator", &space_generator); v->Visit("search_strategy", &search_strategy); + v->Visit("sch_rules", &sch_rules); v->Visit("task_name", &task_name); v->Visit("rand_state", &rand_state); v->Visit("num_threads", &num_threads); @@ -81,6 +84,7 @@ class TuneContext : public runtime::ObjectRef { * \param target The target to be tuned for. * \param space_generator The design space generator. * \param search_strategy The search strategy. + * \param sch_rules The schedule rules. * \param task_name The name of the tuning task. * \param rand_state The random state. * \param num_threads The number of threads to be used. @@ -89,6 +93,7 @@ class TuneContext : public runtime::ObjectRef { Optional target, // Optional space_generator, // Optional search_strategy, // + Optional> sch_rules, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads); diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 90ebf36723a3..1b64a715ba18 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -155,6 +155,12 @@ class ScheduleNode : public runtime::Object { * \return The corresponding loop sref */ virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0; + /*! + * \brief Check the existance of a specific BlockRV + * \param block_rv The BlockRV to be looked up + * \return Whether the corresponding block exists + */ + virtual bool HasBlock(const BlockRV& block_rv) const = 0; /*! * \brief Get the block/loop sref corresponding to the specific statement * \param stmt The statement to be looked up diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 47b3dda5a36e..8b6672ccc371 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -21,5 +21,6 @@ from . import runner from . import space_generator from . import search_strategy +from . import schedule_rule from . import integration from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py new file mode 100644 index 000000000000..b90780d5bfdb --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -0,0 +1,19 @@ +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.schedule_rule package. +Meta Schedule schedule rules are used for modification of +blocks in a schedule. See also PostOrderApply. +""" +from .schedule_rule import PyScheduleRule, ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py new file mode 100644 index 000000000000..b995e5acb6fc --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Meta Schedule schedule rules are used for modification of +blocks in a schedule. See also PostOrderApply. +""" +from typing import TYPE_CHECKING, List + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Schedule, BlockRV + +from ..utils import _get_hex_address, check_override +from .. import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.ScheduleRule") +class ScheduleRule(Object): + """Rules to modify a block in a schedule.""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the schedule rule with a tune context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the schedule rule. + """ + _ffi_api.ScheduleRuleInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + """Apply a schedule rule to the specific block in the given schedule. + + Parameters + ---------- + sch : Schedule + The schedule to be modified. + block : BlockRV + The specific block to apply the schedule rule. + + Returns + ------- + design_spaces : List[Schedule] + The list of schedules generated by applying the schedule rule. 
+ """ + return _ffi_api.ScheduleRuleApply( # type: ignore # pylint: disable=no-member + self, sch, block + ) + + +@register_object("meta_schedule.PyScheduleRule") +class PyScheduleRule(ScheduleRule): + """An abstract schedule rule with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, ScheduleRule) + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + @check_override(self.__class__, ScheduleRule) + def f_apply(sch: Schedule, block: BlockRV) -> List[Schedule]: + return self.apply(sch, block) + + def f_as_string() -> str: + return self.__str__() + + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRulePyScheduleRule, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/space_generator/__init__.py b/python/tvm/meta_schedule/space_generator/__init__.py index af759d43b34a..fc08cd491de7 100644 --- a/python/tvm/meta_schedule/space_generator/__init__.py +++ b/python/tvm/meta_schedule/space_generator/__init__.py @@ -19,7 +19,7 @@ Meta Schedule design space generators that generates design space for generation of measure candidates. """ - from .space_generator import SpaceGenerator, PySpaceGenerator from .space_generator_union import SpaceGeneratorUnion from .schedule_fn import ScheduleFn +from .post_order_apply import PostOrderApply diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py new file mode 100644 index 000000000000..80f372a448f5 --- /dev/null +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Post Order Apply Space Generator.""" + + +from tvm._ffi import register_object +from .space_generator import SpaceGenerator +from .. import _ffi_api + + +@register_object("meta_schedule.PostOrderApply") +class PostOrderApply(SpaceGenerator): + """ + PostOrderApply is the design space generator that generates design spaces by applying schedule + rules to blocks in post-DFS order. 
+ """ + + def __init__(self): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SpaceGeneratorPostOrderApply, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index e37fd14ba440..2172613ce1e6 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -36,10 +36,7 @@ class SpaceGenerator(Object): """The abstract design space generator interface.""" - def initialize_with_tune_context( - self, - tune_context: "TuneContext", - ) -> None: + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: """Initialize the design space generator with tuning context. Parameters diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 0f3cfac1a85f..99b8c7e869cd 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -16,7 +16,7 @@ # under the License. """Meta Schedule tuning context.""" -from typing import Optional, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING from tvm import IRModule from tvm._ffi import register_object @@ -29,6 +29,7 @@ if TYPE_CHECKING: from .space_generator import SpaceGenerator from .search_strategy import SearchStrategy + from .schedule_rule import ScheduleRule @register_object("meta_schedule.TuneContext") @@ -50,6 +51,8 @@ class TuneContext(Object): The design space generator. search_strategy : Optional[SearchStrategy] = None The search strategy. + sch_rules: Optional[List[ScheduleRule]] = None, + The schedule rules. task_name : Optional[str] = None The name of the tuning task. rand_state : int = -1 @@ -80,30 +83,11 @@ def __init__( target: Optional[Target] = None, space_generator: Optional["SpaceGenerator"] = None, search_strategy: Optional["SearchStrategy"] = None, + sch_rules: Optional[List["ScheduleRule"]] = None, task_name: Optional[str] = None, rand_state: int = -1, num_threads: Optional[int] = None, ): - """Constructor. - - Parameters - ---------- - mod : Optional[IRModule] = None - The workload to be optimized. - target : Optional[Target] = None - The target to be optimized for. - space_generator : Optional[SpaceGenerator] = None - The design space generator. - search_strategy : Optional[SearchStrategy] = None - The search strategy. - task_name : Optional[str] = None - The name of the tuning task. - rand_state : int = -1 - The random state. - Need to be in integer in [1, 2^31-1], -1 means using random number. - num_threads : Optional[int] = None - The number of threads to be used, None means using the logical cpu count. - """ if num_threads is None: num_threads = cpu_count() @@ -113,6 +97,7 @@ def __init__( target, space_generator, search_strategy, + sch_rules, task_name, rand_state, num_threads, diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc new file mode 100644 index 000000000000..f80f684dafa8 --- /dev/null +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +ScheduleRule ScheduleRule::PyScheduleRule( + PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyScheduleRuleNode::FApply f_apply, // + PyScheduleRuleNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return ScheduleRule(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyScheduleRuleNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyScheduleRule's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(ScheduleRuleNode); +TVM_REGISTER_NODE_TYPE(PyScheduleRuleNode); + +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInitializeWithTuneContext") + .set_body_method(&ScheduleRuleNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApply") + .set_body_method(&ScheduleRuleNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule") + .set_body_typed(ScheduleRule::PyScheduleRule); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 1c83aee8c0fd..200eca34133d 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -73,11 +73,11 @@ class ReplayTraceNode : public SearchStrategyNode { static constexpr const char* _type_key = "meta_schedule.ReplayTrace"; TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); - void InitializeWithTuneContext(const TuneContext& tune_context) final { - this->mod_ = tune_context->mod.value(); + void InitializeWithTuneContext(const TuneContext& context) final { + this->mod_ = context->mod.value(); this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_)); - this->num_threads_ = tune_context->num_threads; - this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->num_threads_ = context->num_threads; + this->rand_state_ = ForkSeed(&context->rand_state); this->state_.reset(); } diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc new file mode 100644 index 000000000000..bc616327eb3b --- /dev/null +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief Collecting all the non-root blocks */ +class BlockCollector : public tir::StmtVisitor { + public: + static Array Collect(const tir::Schedule& sch) { // + return BlockCollector(sch).Run(); + } + + private: + /*! \brief Entry point */ + Array Run() { + for (const auto& kv : sch_->mod()->functions) { + const GlobalVar& gv = kv.first; // `gv->name_hint` is the name of the function + const BaseFunc& base_func = kv.second; // this can be PrimFunc or relay::Function + if (const auto* func = base_func.as()) { + func_name_ = gv->name_hint; + block_names_.clear(); + blocks_to_collect_.clear(); + VisitStmt(func->body); + for (const String& block_name : blocks_to_collect_) { + results_.push_back(sch_->GetBlock(block_name, func_name_)); + } + } + } + return results_; + } + /*! \brief Constructor */ + explicit BlockCollector(const tir::Schedule& sch) : sch_(sch) {} + /*! \brief Override the Stmt visiting behaviour */ + void VisitStmt_(const tir::BlockNode* block) override { + tir::StmtVisitor::VisitStmt_(block); + CHECK(block_names_.count(block->name_hint) == 0) + << "Duplicated block name " << block->name_hint << " in function " << func_name_ + << " not supported!"; + block_names_.insert(block->name_hint); + blocks_to_collect_.push_back(block->name_hint); + } + + /*! \brief The schedule to be collected */ + const tir::Schedule& sch_; + /*! \brief The set of func name and block name pair */ + std::unordered_set block_names_; + /* \brief The list of blocks to collect in order */ + Array blocks_to_collect_; + /*! \brief Function name & blocks of collection */ + Array results_; + /*! \brief Name of the current PrimFunc */ + String func_name_; +}; + +/*! + * \brief Design Space Generator that generates design spaces by applying schedule rules to blocks + * in post-DFS order. + * */ +class PostOrderApplyNode : public SpaceGeneratorNode { + public: + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The schedule rules to be applied in order. 
*/ + Array sch_rules_{nullptr}; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `rand_state_` is not visited + // `sch_rules_` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + this->rand_state_ = ForkSeed(&context->rand_state); + CHECK(context->sch_rules.defined()) + << "ValueError: Schedules rules not given in PostOrderApply!"; + this->sch_rules_ = context->sch_rules; + } + + Array GenerateDesignSpace(const IRModule& mod_) final { + using ScheduleAndUnvisitedBlocks = std::pair>; + tir::Schedule sch = tir::Schedule::Traced( // + /*mod=*/mod_, // + /*rand_state=*/ForkSeed(&this->rand_state_), // + /*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags, // + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); + + std::vector stack; + Array result{sch}; + // Enumerate the schedule rules first because you can + // always concat multiple schedule rules as one + Array all_blocks = BlockCollector::Collect(sch); + for (ScheduleRule sch_rule : sch_rules_) { + for (const tir::Schedule& sch : result) { + stack.emplace_back(sch, all_blocks); + } + result.clear(); + + while (!stack.empty()) { + // get the stack.top() + tir::Schedule sch; + Array blocks; + std::tie(sch, blocks) = stack.back(); + stack.pop_back(); + // if all blocks are visited + if (blocks.empty()) { + result.push_back(sch); + continue; + } + // otherwise, get the last block that is not visited + tir::BlockRV block_rv = blocks.back(); + blocks.pop_back(); + if (sch->HasBlock(block_rv)) { + Array applied = sch_rule->Apply(sch, /*block=*/block_rv); + for (const tir::Schedule& sch : applied) { + stack.emplace_back(sch, blocks); + } + } else { + stack.emplace_back(sch, blocks); + } + } + } + return result; + } + static constexpr const char* _type_key = "meta_schedule.PostOrderApply"; + TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); +}; + +SpaceGenerator SpaceGenerator::PostOrderApply() { + ObjectPtr n = make_object(); + return SpaceGenerator(n); +} + +TVM_REGISTER_NODE_TYPE(PostOrderApplyNode); +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPostOrderApply") + .set_body_typed(SpaceGenerator::PostOrderApply); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index 9c2e3eeabe09..6ea61824f932 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -29,10 +29,10 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("space_generators", &space_generators); } - void InitializeWithTuneContext(const TuneContext& tune_context) final { + void InitializeWithTuneContext(const TuneContext& context) final { // Initialize each space generator. 
for (const SpaceGenerator& space_generator : space_generators) { - space_generator->InitializeWithTuneContext(tune_context); + space_generator->InitializeWithTuneContext(context); } } diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 9fc9272e33ac..ac85d43e7987 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -38,6 +38,7 @@ TuneContext::TuneContext(Optional mod, Optional target, // Optional space_generator, // Optional search_strategy, // + Optional> sch_rules, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) { @@ -46,6 +47,7 @@ TuneContext::TuneContext(Optional mod, n->target = target; n->space_generator = space_generator; n->search_strategy = search_strategy; + n->sch_rules = sch_rules.value_or({}); n->task_name = task_name; if (rand_state == -1) { rand_state = std::random_device()(); @@ -65,11 +67,12 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") Optional target, // Optional space_generator, // Optional search_strategy, // + Optional> sch_rules, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) -> TuneContext { - return TuneContext(mod, target, space_generator, search_strategy, task_name, rand_state, - num_threads); + return TuneContext(mod, target, space_generator, search_strategy, sch_rules, task_name, + rand_state, num_threads); }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 9b0a37160a13..f4f95755408c 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -46,6 +47,9 @@ namespace tvm { namespace meta_schedule { +/*! \brief The type of the random state */ +using TRandState = support::LinearCongruentialEngine::TRandState; + /*! * \brief Read lines from a json file. * \param path The path to the json file. 
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index d7056b6b061f..d420728a9e3c 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -72,6 +72,7 @@ class ConcreteScheduleNode : public ScheduleNode { inline PrimExpr Get(const ExprRV& expr_rv) const final; inline StmtSRef GetSRef(const BlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; + inline bool HasBlock(const BlockRV& block_rv) const final; inline Array GetSRefs(const Array& rvs) const; inline Array GetSRefs(const Array& rvs) const; void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } @@ -205,6 +206,19 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { return this->analyzer_->Simplify(transformed); } +inline bool ConcreteScheduleNode::HasBlock(const BlockRV& block_rv) const { + auto it = this->symbol_table_.find(block_rv); + if (it == this->symbol_table_.end()) { + return false; + } + const ObjectRef& obj = (*it).second; + const auto* sref = obj.as(); + if (sref == nullptr || sref->stmt == nullptr) { + return false; + } + return true; +} + inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { auto it = this->symbol_table_.find(block_rv); if (it == this->symbol_table_.end()) { diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py new file mode 100644 index 000000000000..b78e67817ebf --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -0,0 +1,342 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import math +import sys +from typing import List + +import pytest +import tvm +from tvm.error import TVMError +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.schedule_rule import PyScheduleRule +from tvm.meta_schedule.space_generator import PostOrderApply +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir.schedule import BlockRV, Schedule + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class DuplicateMatmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class TrinityMatmul: + @T.prim_func + def main(a: T.handle, d: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.alloc_buffer((1024, 1024), "float32") + C = T.alloc_buffer((1024, 1024), "float32") + D = T.match_buffer(d, (1024, 1024), "float32") + for i, j in T.grid(1024, 1024): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(1024, 1024): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 3.0 + for i, j in T.grid(1024, 1024): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = C[vi, vj] * 5.0 + + +@tvm.script.ir_module +class TrinityMatmulProcessedForReference: + @T.prim_func + def main(a: T.handle, d: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024], dtype="float32") + D = T.match_buffer(d, [1024, 1024], dtype="float32") + # body + # with tir.block("root") + B = T.alloc_buffer([1024, 1024], dtype="float32") + for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16): + with T.block("A"): + vi = T.axis.S(1024, i0_0 * 64 + i0_1) + vj = T.axis.S(1024, i1_0 * 16 + i1_1) + T.reads([A[vi, vj]]) + T.writes([B[vi, vj]]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16): + with T.block("C"): + vi = T.axis.S(1024, i0_0 * 64 + i0_1) + vj = T.axis.S(1024, i1_0 * 16 + i1_1) + T.reads([B[vi, vj]]) + T.writes([D[vi, vj]]) + D[vi, vj] = (B[vi, vj] + T.float32(3)) * T.float32(5) + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _is_root(sch: Schedule, block: BlockRV) -> bool: + return 
sch.get_sref(block).parent is None + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +class WowSoFancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + if _is_root(sch, block): + return [sch] + new_sch = sch.copy() + i, j, k = new_sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = new_sch.split(loop=k, factors=[32, 32]) + new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + return [new_sch] + + +class DoubleScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + if _is_root(sch, block): + return [sch] + new_sch = sch.copy() + i, j, k = new_sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[4, 64, 2, 2]) + j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[2, 4, 64, 2]) + k_0, k_1 = new_sch.split(loop=k, factors=[32, 32]) + new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + result = [new_sch] + new_sch = sch.copy() + i, j, k = new_sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[4, 64, 2, 2]) + j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[2, 4, 64, 2]) + k_0, k_1 = new_sch.split(loop=k, factors=[32, 32]) + new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + result.append(new_sch) + return result + + +class ReorderScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + if _is_root(sch, block): + return [sch] + new_sch = sch.copy() + i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3 = new_sch.get_loops(block=block) + new_sch.reorder(i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3, i_0, j_0) + result = [new_sch] + new_sch = sch.copy() + i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3 = new_sch.get_loops(block=block) + new_sch.reorder(i_1, j_3, i_0, j_0, j_1, k_0, i_2, j_2, k_1, i_3) + result.append(new_sch) + return result + + +def test_meta_schedule_post_order_apply(): + mod = Matmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Test Task", + sch_rules=[WowSoFancyScheduleRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 1 + assert not tvm.ir.structural_equal(schs[0].mod, mod) + _check_correct(schs[0]) + + +def test_meta_schedule_post_order_apply_double(): + mod = Matmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Double Rules Task", + sch_rules=[DoubleScheduleRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 2 + for sch in schs: + assert not tvm.ir.structural_equal(sch.mod, mod) + _check_correct(sch) + + +def test_meta_schedule_post_order_apply_multiple(): + mod = Matmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Double Rules Task", + sch_rules=[DoubleScheduleRule(), ReorderScheduleRule()], + ) + post_order_apply = 
PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 4 + for sch in schs: + assert not tvm.ir.structural_equal(sch.mod, mod) + _check_correct(sch) + + +def test_meta_schedule_post_order_apply_duplicate_matmul(): + mod = DuplicateMatmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Duplicate Matmul Task", + sch_rules=[WowSoFancyScheduleRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + with pytest.raises( + TVMError, + match=r".*TVMError: Check failed: \(block_names_.count\(block->name_hint\) == 0\)" + r" is false: Duplicated block name matmul in function main not supported!", + ): + post_order_apply.generate_design_space(mod) + + +def test_meta_schedule_post_order_apply_remove_block(): + class TrinityDouble(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + if _is_root(sch, block): + return [sch] + new_sch = sch.copy() + i, j = new_sch.get_loops(block=block) + i_0, i_1 = new_sch.split(loop=i, factors=[16, 64]) + j_0, j_1 = new_sch.split(loop=j, factors=[64, 16]) + new_sch.reorder(i_0, j_0, i_1, j_1) + result = [new_sch] + new_sch = sch.copy() + i, j = new_sch.get_loops(block=block) + i_0, i_1 = new_sch.split(loop=i, factors=[2, 512]) + j_0, j_1 = new_sch.split(loop=j, factors=[2, 512]) + new_sch.reorder(i_0, j_0, i_1, j_1) + result.append(new_sch) + return result + + class RemoveBlock(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + if _is_root(sch, block): + return [sch] + sch = sch.copy() + if sch.get(block).name_hint == "B": + sch.compute_inline(block) + return [sch] + + def correct_trace(a, b, c, d): + return "\n".join( + [ + 'b0 = sch.get_block(name="A", func_name="main")', + 'b1 = sch.get_block(name="B", func_name="main")', + 'b2 = sch.get_block(name="C", func_name="main")', + "sch.compute_inline(block=b1)", + "l3, l4 = sch.get_loops(block=b2)", + "l5, l6 = sch.split(loop=l3, factors=" + str(a) + ")", + "l7, l8 = sch.split(loop=l4, factors=" + str(b) + ")", + "sch.reorder(l5, l7, l6, l8)", + "l9, l10 = sch.get_loops(block=b0)", + "l11, l12 = sch.split(loop=l9, factors=" + str(c) + ")", + "l13, l14 = sch.split(loop=l10, factors=" + str(d) + ")", + "sch.reorder(l11, l13, l12, l14)", + ] + ) + + mod = TrinityMatmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Remove Block Task", + sch_rules=[RemoveBlock(), TrinityDouble()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 4 + for sch in schs: + with pytest.raises( + tvm.tir.schedule.schedule.ScheduleError, + match="ScheduleError: An error occurred in the schedule primitive 'get-block'.", + ): + sch.get_block("B", "main") + sch_trace = sch.trace.simplified(True) + assert ( + str(sch_trace) == correct_trace([16, 64], [64, 16], [2, 512], [2, 512]) + or str(sch_trace) == correct_trace([2, 512], [2, 512], [2, 512], [2, 512]) + or str(sch_trace) == correct_trace([16, 64], [64, 16], [16, 64], [64, 16]) + or str(sch_trace) == correct_trace([2, 512], [2, 512], [16, 64], [64, 16]) + ) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:]))
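
For quick reference, below is a minimal usage sketch of the API introduced in this patch, modeled on the unit tests above. It is illustrative only: `NoOpScheduleRule` and the task name are hypothetical placeholders, and `Matmul` refers to the TVMScript workload defined in test_meta_schedule_post_order_apply.py, so the snippet assumes it is run in that test module's scope.

from typing import List

from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.schedule_rule import PyScheduleRule
from tvm.meta_schedule.space_generator import PostOrderApply
from tvm.target import Target
from tvm.tir.schedule import BlockRV, Schedule


class NoOpScheduleRule(PyScheduleRule):
    """A hypothetical rule that leaves every block unchanged."""

    def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
        pass  # no per-task state is needed in this sketch

    def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
        # A real rule would return one or more transformed copies of `sch`.
        return [sch]


context = TuneContext(
    mod=Matmul,  # an IRModule, e.g. the TVMScript Matmul workload from the tests above
    target=Target("llvm"),
    task_name="Example Task",  # hypothetical task name
    sch_rules=[NoOpScheduleRule()],  # rules are applied to blocks in post-DFS order
)
space_gen = PostOrderApply()
space_gen.initialize_with_tune_context(context)
design_spaces = space_gen.generate_design_space(Matmul)  # -> List[Schedule]

As in the unit tests, a rule is invoked for every collected block, including the root block, so realistic rules typically skip the root via a check such as `sch.get_sref(block).parent is None`.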