diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 058a2afe27..f382cd219e 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -28,10 +28,7 @@ namespace meta_schedule { class TuneContext; /*! - * \brief Rules to apply a post processing to a schedule. - * \note Post processing is designed to deal with the problem of undertermined schedule validity - * after applying some schedule primitves at runtime. E.g., Fuse the first X loops to reach the - * maximum number below 1024, X is only decided at runtime. + * \brief Rules to apply a postprocessor to a schedule. */ class PostprocNode : public runtime::Object { public: @@ -47,9 +44,9 @@ class PostprocNode : public runtime::Object { virtual void InitializeWithTuneContext(const TuneContext& context) = 0; /*! - * \brief Apply a post processing to the given schedule. + * \brief Apply a postprocessor to the given schedule. * \param sch The schedule to be post processed. - * \return Whether the post processing was successfully applied. + * \return Whether the postprocessor was successfully applied. */ virtual bool Apply(const tir::Schedule& sch) = 0; @@ -57,7 +54,7 @@ class PostprocNode : public runtime::Object { TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); }; -/*! \brief The post processing with customized methods on the python-side. */ +/*! \brief The postprocessor with customized methods on the python-side. */ class PyPostprocNode : public PostprocNode { public: /*! @@ -66,22 +63,22 @@ class PyPostprocNode : public PostprocNode { */ using FInitializeWithTuneContext = runtime::TypedPackedFunc; /*! - * \brief Apply a post processing to the given schedule. + * \brief Apply a postprocessor to the given schedule. * \param sch The schedule to be post processed. - * \return Whether the post processing was successfully applied. + * \return Whether the postprocessor was successfully applied. */ using FApply = runtime::TypedPackedFunc; /*! - * \brief Get the post processing function as string with name. - * \return The string of the post processing function. + * \brief Get the postprocessor function as string with name. + * \return The string of the postprocessor function. */ using FAsString = runtime::TypedPackedFunc; - /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `Apply` funcion. */ + /*! \brief The packed function to the `Apply` function. */ FApply f_apply; - /*! \brief The packed function to the `AsString` funcion. */ + /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; void VisitAttrs(tvm::AttrVisitor* v) { @@ -112,15 +109,31 @@ class PyPostprocNode : public PostprocNode { class Postproc : public runtime::ObjectRef { public: /*! - * \brief Create a post processing with customized methods on the python-side. + * \brief Create a postprocessor with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_apply The packed function of `Apply`. - * \return The post processing created. + * \return The postprocessor created. */ TVM_DLL static Postproc PyPostproc( PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PyPostprocNode::FApply f_apply, // PyPostprocNode::FAsString f_as_string); + /*! 
+ * \brief Create a postprocessor that rewrites the cooperative fetch annotation to + * actual vectorized cooperative fetching in loop bindings. + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteCooperativeFetch(); + /*! + * \brief Create a postprocessor that rewrites reduction block by moving the init block out. + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteReductionBlock(); + /*! + * \brief Create a postprocessor that adds thread binding to unbound blocks + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteUnboundBlock(); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); }; diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 53e05a0b98..68c3c3b1a8 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -44,7 +44,7 @@ class TuneContextNode : public runtime::Object { Optional search_strategy; /*! \brief The schedule rules. */ Optional> sch_rules; - /*! \brief The post processings. */ + /*! \brief The postprocessors. */ Optional> postprocs; /*! \brief The mutators. */ Optional> mutators; @@ -95,7 +95,7 @@ class TuneContext : public runtime::ObjectRef { * \param space_generator The design space generator. * \param search_strategy The search strategy. * \param sch_rules The schedule rules. - * \param postprocs The post processings. + * \param postprocs The postprocessors. * \param mutators The mutators. * \param task_name The name of the tuning task. * \param rand_state The random state. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 170fc8662e..2a748a4653 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1359,7 +1359,7 @@ constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint"; * \brief Mark that the loop should be further skip and bound to environment threads to enable * cooperative fetching. */ -constexpr const char* meta_schedule_lazy_cooperative_fetch = "meta_schedule.lazy_cooperative_fetch"; +constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch"; /*! * \brief Mark a block as generated by cache_read or cache_write block. diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py index 5316eb4663..440312812e 100644 --- a/python/tvm/meta_schedule/postproc/__init__.py +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -14,10 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -The tvm.meta_schedule.postproc package. -Meta Schedule post processings that deal with the problem of -undertermined schedule validity after applying some schedule -primitves at runtime. -""" +"""The tvm.meta_schedule.postproc package.""" from .postproc import Postproc, PyPostproc +from .rewrite_cooperative_fetch import RewriteCooperativeFetch +from .rewrite_reduction_block import RewriteReductionBlock +from .rewrite_unbound_block import RewriteUnboundBlock diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index e05cc9a527..8e3b332c77 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -31,29 +31,22 @@ @register_object("meta_schedule.Postproc") class Postproc(Object): - """Rules to apply a post processing to a schedule. 
- - Note - ---- - Post processing is designed to deal with the problem of undertermined schedule validity after - applying some schedule primitves at runtime. E.g., Fuse the first X loops to reach the maximum - number below 1024, X is only decided at runtime. - """ + """Rules to apply a postprocessor to a schedule.""" def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: - """Initialize the post processing with a tune context. + """Initialize the postprocessor with a tune context. Parameters ---------- tune_context : TuneContext - The tuning context for initializing the post processing. + The tuning context for initializing the postprocessor. """ _ffi_api.PostprocInitializeWithTuneContext( # type: ignore # pylint: disable=no-member self, tune_context ) def apply(self, sch: Schedule) -> bool: - """Apply a post processing to the given schedule. + """Apply a postprocessor to the given schedule. Parameters ---------- @@ -63,9 +56,9 @@ def apply(self, sch: Schedule) -> bool: Returns ------- result : bool - Whether the post processing was successfully applied. + Whether the postprocessor was successfully applied. """ - return _ffi_api.PostprocApply(self, sch) + return _ffi_api.PostprocApply(self, sch) # type: ignore # pylint: disable=no-member @register_object("meta_schedule.PyPostproc") diff --git a/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py new file mode 100644 index 0000000000..e2d7c22123 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that rewrites the cooperative fetch annotation to actual +vectorized cooperative fetching in loop bindings.""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteCooperativeFetch") +class RewriteCooperativeFetch(Postproc): + """A postprocessor that rewrites the cooperative fetch annotation to actual vectorized + cooperative fetching in loop bindings. + """ + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteCooperativeFetch, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py new file mode 100644 index 0000000000..7e15ed493c --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that rewrites reduction block by moving the init block out.""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteReductionBlock") +class RewriteReductionBlock(Postproc): + """A postprocessor that rewrites reduction block by moving the init block out.""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteReductionBlock, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py new file mode 100644 index 0000000000..f4113e5173 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that adds thread binding to unbound blocks""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteUnboundBlock") +class RewriteUnboundBlock(Postproc): + """A postprocessor that adds thread binding to unbound blocks""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteUnboundBlock, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index abf43bfa8b..5d242290b1 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -57,7 +57,7 @@ class TuneContext(Object): sch_rules: Optional[List[ScheduleRule]] = None, The schedule rules. postproc: Optional[List[Postproc"]] = None, - The post processings. + The postprocessors. mutator: Optional[List[Mutator]] = None, The mutators. task_name : Optional[str] = None @@ -115,7 +115,7 @@ def __init__( sch_rules : List[ScheduleRule] = [] The schedule rules. postproc : List[Postproc] = [] - The post-processors. + The postprocessors. mutator : List[Mutator] = [] The mutators. 
task_name : Optional[str] = None diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc new file mode 100644 index 0000000000..e256f4d0cd --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Parse instruction: sch.bind(..., "threadIdx.x") + * \param sch The schedule + * \param inst The instruction to be parsed + * \return NullOpt if parsing fails; Otherwise, the extent of thread axis + */ +Optional ParseThreadBinding(const Schedule& sch, const Instruction& inst) { + static InstructionKind inst_kind_bind = InstructionKind::Get("Bind"); + if (!inst->kind.same_as(inst_kind_bind)) { + return NullOpt; + } + ICHECK_EQ(inst->inputs.size(), 1); + ICHECK_EQ(inst->attrs.size(), 1); + String thread_axis = Downcast(inst->attrs[0]); + if (thread_axis != "threadIdx.x") { + return NullOpt; + } + return Downcast(sch->Get(Downcast(inst->inputs[0]))->extent); +} + +/*! + * \brief Parse instruction: sch.annotate(..., attr::meta_schedule_cooperative_fetch) + * \param sch The schedule + * \param inst The instruction to be parsed + * \param vector_lane The length of vector lane in vectorized cooperative fetching + * \return NullOpt if parsing fails; Otherwise, the annotated block + */ +Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, int* vector_lane) { + static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_kind_annotate)) { + return NullOpt; + } + ICHECK_EQ(inst->inputs.size(), 2); + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + if (ann_key != attr::meta_schedule_cooperative_fetch) { + return NullOpt; + } + *vector_lane = Downcast(sch->Get(Downcast(inst->inputs[1])))->value; + return Downcast(inst->inputs[0]); +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +/*! + * \brief Rewrite the cooperative fetch annotation to actual vectorized cooperative fetching + * in loop bindings. 
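+ *
+ * Given a block annotated with `meta_schedule.cooperative_fetch = v` that is
+ * nested under a `threadIdx.x` binding of extent `E`, the innermost fused
+ * loop of the block is split into `[outer, E, v]`; the middle piece is bound
+ * to `threadIdx.x` and, when `v > 1`, the innermost piece is vectorized.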
+ */ +class RewriteCooperativeFetchNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode); +}; + +bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { + using tir::BlockRV; + using tir::Instruction; + using tir::LoopRV; + using tir::Schedule; + using tir::Trace; + Trace trace = sch->trace().value(); + int thread_extent = -1; + int vector_lane = -1; + std::vector> tasks; + for (const Instruction& inst : trace->insts) { + if (Optional new_thread_extent = tir::ParseThreadBinding(sch, inst)) { + thread_extent = new_thread_extent.value()->value; + } + if (Optional block_rv = tir::ParseAnnotate(sch, inst, &vector_lane)) { + ICHECK_NE(thread_extent, -1); + if (vector_lane > 1) { + tasks.push_back([thread_extent, vector_lane, sch, block = block_rv.value()]() -> void { + LoopRV fused = sch->GetLoops(block).back(); + Array split = sch->Split(fused, {NullOpt, // + Integer(thread_extent), // + Integer(vector_lane)}); + sch->Vectorize(split[2]); + sch->Bind(split[1], "threadIdx.x"); + }); + } else { + tasks.push_back([thread_extent, sch, block = block_rv.value()]() -> void { + LoopRV fused = sch->GetLoops(block).back(); + Array split = sch->Split(fused, {NullOpt, Integer(thread_extent)}); + sch->Bind(split[1], "threadIdx.x"); + }); + } + } + } + for (auto&& task : tasks) { + task(); + } + return true; +} + +Postproc Postproc::RewriteCooperativeFetch() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteCooperativeFetchNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch") + .set_body_typed(Postproc::RewriteCooperativeFetch); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc new file mode 100644 index 0000000000..d1a5492361 --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief The visitor that finds all the reduction block to be decomposed */ +struct ReductionBlockFinder : private StmtVisitor { + public: + /*! 
\brief Find all the reduction blocks that should be decomposed */ + static std::vector> Find(const ScheduleState& self) { + std::vector> results; + for (const auto& kv : self->mod->functions) { + GlobalVar g_var = kv.first; + BaseFunc base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + ReductionBlockFinder finder; + finder(prim_func->body); + for (const BlockNode* block : finder.results_) { + results.emplace_back(self->stmt2ref.at(block), g_var->name_hint); + } + } + } + return results; + } + + private: + void VisitStmt_(const ForNode* loop) final { + runtime::ThreadScope thread_scope = GetThreadScope(loop); + if (IsThreadIdx(thread_scope) || IsBlockIdx(thread_scope)) { + thread_bound_loop_vars_.insert(loop->loop_var.get()); + } + StmtVisitor::VisitStmt_(loop); + } + + void VisitStmt_(const BlockRealizeNode* realize) final { + if (realize->block->init.defined() && AllReductionIterVarAreUnbound(realize)) { + results_.push_back(realize->block.get()); + } + StmtVisitor::VisitStmt_(realize); + } + + bool AllReductionIterVarAreUnbound(const BlockRealizeNode* realize) const { + if (thread_bound_loop_vars_.empty()) { + return true; + } + auto f_find = [this](const VarNode* var) -> bool { return thread_bound_loop_vars_.count(var); }; + const BlockNode* block = realize->block.get(); + int n = block->iter_vars.size(); + for (int i = 0; i < n; ++i) { + IterVar iter_var = block->iter_vars[i]; + PrimExpr binding = realize->iter_values[i]; + if (iter_var->iter_type == tir::kCommReduce) { + if (UsesVar(binding, f_find)) { + return false; + } + } + } + return true; + } + + /*! \brief The results of the collection */ + std::vector results_; + /*! \brief Loop variables that are bound to threads */ + std::unordered_set thread_bound_loop_vars_; +}; + +/*! + * \brief Find the innermost loop that could be decomposed to + * \param block_sref The block to be decomposed + * \return The index of the innermost loop that could be decomposed + */ +int FindDecomposePoint(const StmtSRef& block_sref) { + Array loop_srefs = GetLoops(block_sref); + int n = loop_srefs.size(); + for (int i = 0; i < n; ++i) { + if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) { + return i; + } + } + return -1; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +/*! 
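+ * A reduction block qualifies only if it has an init statement and none of
+ * its reduction block vars are bound to threads; it is then decomposed at
+ * the outermost loop whose iter type is not data-parallel, separating the
+ * init and update blocks (e.g. `C_init` / `C_update` in the test below).
+ */
+/*!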
\brief Rewrite reduction block by moving the init block out */ +class RewriteReductionBlockNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteReductionBlockNode, PostprocNode); +}; + +bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { + for (;;) { + std::vector> results = + tir::ReductionBlockFinder::Find(sch->state()); + int rewritten = 0; + for (const auto& kv : results) { + const tir::StmtSRef& block_sref = kv.first; + const String& global_var_name = kv.second; + int decompose_point = tir::FindDecomposePoint(block_sref); + if (decompose_point == -1) { + continue; + } + tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); + Array loop_rvs = sch->GetLoops(block_rv); + sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); + ++rewritten; + } + if (rewritten == 0) { + break; + } + } + return true; +} + +Postproc Postproc::RewriteReductionBlock() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock") + .set_body_typed(Postproc::RewriteReductionBlock); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc new file mode 100644 index 0000000000..2608dce19a --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief The rewrite type for an unbound block */ +enum class BindType : int32_t { + /*! \brief No additional thread binding is needed */ + kNoBind = 0, + /*! \brief Need to bind to blockIdx */ + kBindBlock = 1, + /*! \brief Need to bind to both blockIdx and threadIdx */ + kBindBlockThread = 2, +}; + +/*! 
+ * \brief Check the combination of bindings to be added to the block + * \param block_sref The block to be checked + * \param fuse_first_num The number of loops to be fused + * \return The type of binding to be added to the block + */ +BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) { + Array loops = tir::GetLoops(block_sref); + int n = loops.size(); + if (n == 0) { + return BindType::kNoBind; + } + int i_block_idx = -1; + int i_thread_idx = -1; + int i_multi_child = -1; + for (int i = 0; i < n; ++i) { + const StmtSRef& loop_sref = loops[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + runtime::ThreadScope thread_scope = GetThreadScope(loop); + if (IsBlockIdx(thread_scope)) { + if (i_block_idx == -1) { + i_block_idx = i; + } + } + if (IsThreadIdx(thread_scope)) { + if (i_thread_idx == -1) { + i_thread_idx = i; + } + } + if (!IsSingleStmt(loop->body)) { + if (i_multi_child == -1) { + i_multi_child = i + 1; + } + } + } + if (i_multi_child == -1) { + i_multi_child = n; + } + if (i_block_idx != -1 && i_thread_idx != -1) { + return BindType::kNoBind; + } else if (i_block_idx != -1 && i_thread_idx == -1) { + ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not"; + throw; + } else if (i_block_idx == -1 && i_thread_idx != -1) { + *fuse_first_num = std::min(i_multi_child, i_thread_idx); + return BindType::kBindBlock; + } else { // i_block_idx == -1 && i_thread_idx == -1 + *fuse_first_num = i_multi_child; + return BindType::kBindBlockThread; + } +} + +/*! \brief Find all the blocks that are not bound */ +class UnboundBlockFinder : private StmtVisitor { + public: + static std::vector> Find(const ScheduleState& self) { + UnboundBlockFinder finder(self); + for (const auto& kv : self->mod->functions) { + GlobalVar g_var = kv.first; + BaseFunc base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + finder.global_var_name_ = g_var->name_hint; + finder(Downcast(prim_func->body)->block->body); + } + } + return std::move(finder.blocks_); + } + + private: + void VisitStmt_(const ForNode* loop) final { + runtime::ThreadScope thread_scope = GetThreadScope(loop); + if (IsBlockIdx(thread_scope)) { + ++n_block_idx_; + } else if (IsThreadIdx(thread_scope)) { + ++n_thread_idx_; + } + if (n_block_idx_ == 0 || n_thread_idx_ == 0) { + StmtVisitor::VisitStmt_(loop); + } + if (IsBlockIdx(thread_scope)) { + --n_block_idx_; + } else if (IsThreadIdx(thread_scope)) { + --n_thread_idx_; + } + } + + void VisitStmt_(const BlockNode* block) final { + blocks_.emplace_back(self_->stmt2ref.at(block), global_var_name_); + } + + explicit UnboundBlockFinder(const ScheduleState& self) + : self_{self}, blocks_{}, n_block_idx_{0}, n_thread_idx_{0} {} + + /*! \brief The schedule state */ + const ScheduleState& self_; + /*! \brief The list of unbound blocks */ + std::vector> blocks_; + /*! \brief The number of blockIdx above the current stmt */ + int n_block_idx_; + /*! \brief The number of threadIdx above the current stmt */ + int n_thread_idx_; + /*! \brief The name of the global var */ + String global_var_name_; +}; + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +/*! 
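+ * For each unbound block, the outer loops (up to the first thread-bound
+ * loop, or the first loop whose body holds multiple statements) are fused;
+ * the fused loop is bound to `blockIdx.x` when a `threadIdx` binding already
+ * exists below it, or split by the target's `thread_warp_size` and bound to
+ * `blockIdx.x` and `threadIdx.x` otherwise.
+ */
+/*!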
\brief Add thread binding to unbound blocks */ +class RewriteUnboundBlockNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final { + CHECK(context->target.defined()) << "ValueError: target is not defined"; + Optional warp_size = context->target.value()->GetAttr("thread_warp_size"); + CHECK(warp_size.defined()) << "ValueError: missing attribute `thread_warp_size` in the target"; + this->warp_size_ = warp_size.value(); + } + + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final; + + public: + /*! \brief The cached warp size from Target */ + int warp_size_ = -1; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `warp_size_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.RewriteUnboundBlock"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteUnboundBlockNode, PostprocNode); +}; + +bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) { + using tir::BlockRV; + using tir::LoopRV; + using tir::Schedule; + ICHECK_NE(this->warp_size_, -1); + std::vector> unbound_blocks = + tir::UnboundBlockFinder::Find(sch->state()); + for (const auto& kv : unbound_blocks) { + tir::StmtSRef block_sref = kv.first; + String global_var_name = kv.second; + int fuse_first_num = 0; + tir::BindType bind_type = tir::GetBindType(block_sref, &fuse_first_num); + if (bind_type == tir::BindType::kNoBind) { + continue; + } + BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); + Array loop_rvs = sch->GetLoops(block_rv); + LoopRV fused = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + fuse_first_num}); + if (bind_type == tir::BindType::kBindBlock) { + sch->Bind(fused, "blockIdx.x"); + } else if (bind_type == tir::BindType::kBindBlockThread) { + Array splits = sch->Split(fused, {NullOpt, Integer(this->warp_size_)}); + ICHECK_EQ(splits.size(), 2); + sch->Bind(splits[0], "blockIdx.x"); + sch->Bind(splits[1], "threadIdx.x"); + } + } + return true; +} + +Postproc Postproc::RewriteUnboundBlock() { + ObjectPtr n = make_object(); + n->warp_size_ = -1; + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteUnboundBlockNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock") + .set_body_typed(Postproc::RewriteUnboundBlock); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index 711401591c..ae8fa1f73c 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -21,9 +21,6 @@ namespace tvm { namespace meta_schedule { -using tir::BlockRV; -using tir::Schedule; - /*! \brief The type of inline to be performed on a specific block */ enum class InlineType : int32_t { /*! \brief No inline opportunity */ @@ -38,13 +35,13 @@ enum class InlineType : int32_t { class AutoInlineNode : public ScheduleRuleNode { public: /*! 
\brief Checks if the specific block should be inlined */ - inline InlineType CheckInline(const Schedule& sch, const BlockRV& block_rv); + inline InlineType CheckInline(const tir::Schedule& sch, const tir::BlockRV& block_rv); // Inherited from ScheduleRuleNode void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from ScheduleRuleNode - Array Apply(const Schedule& sch, const BlockRV& block_rv) final { + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { InlineType inline_type = CheckInline(sch, block_rv); if (inline_type == InlineType::kInlineIntoConsumer) { sch->ComputeInline(block_rv); @@ -87,7 +84,8 @@ class AutoInlineNode : public ScheduleRuleNode { TVM_DECLARE_FINAL_OBJECT_INFO(AutoInlineNode, ScheduleRuleNode); }; -inline InlineType AutoInlineNode::CheckInline(const Schedule& sch, const BlockRV& block_rv) { +inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, + const tir::BlockRV& block_rv) { using namespace tvm::tir; StmtSRef block_sref = sch->GetSRef(block_rv); ScheduleState state = sch->state(); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 62884b6bc9..a74c2e05cf 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -333,18 +333,14 @@ inline std::vector MultiLevelTilingNode::AddReadReuse(State state) const Array buffer_loops = sch->GetLoops(cache_read_block); LoopRV fused = sch->Fuse(Array{buffer_loops.end() - buffer_ndim, // buffer_loops.end()}); - // Do cooperative fetching + // Annotate cooperative fetching if (vector_load_max_len > 0) { // cooperative fetch + vectorized loading - // Split into inner and outer + // Split into inner and outer, vectorize the inner loop Array factors = sch->SamplePerfectTile(fused, 2, vector_load_max_len); - Array splits = sch->Split(fused, {factors[0], factors[1]}); - // Vectorize the inner loop - sch->Vectorize(splits[1]); - fused = splits[0]; + // Add cooperative fetching + sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch, factors[1]); } - // Add cooperative fetching - sch->Annotate(fused, tir::attr::meta_schedule_lazy_cooperative_fetch, Integer(1)); } State new_state = state; new_state.sch = sch; diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index d8e96d0156..df335c4a15 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -222,6 +222,19 @@ inline IRModule DeepCopyIRModule(IRModule mod) { return Downcast(LoadJSON(SaveJSON(mod))); } +/*! + * \brief Get the BlockRV from a block StmtSRef + * \param sch The schedule + * \param block_sref The block StmtSRef + * \param globla_var_name The global variable name + * \return The BlockRV + */ +inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref, + const String& global_var_name) { + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + return sch->GetBlock(block->name_hint, global_var_name); +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index d8327df295..c0ead80a4a 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -292,6 +292,13 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self */ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); +/*! 
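+ * A loop is data-parallel if its loop var appears only in bindings of
+ * data-parallel block vars beneath it, a reduction loop if it appears only
+ * in bindings of reduction block vars, and opaque otherwise.
+ */
+/*!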
+ * \brief Get the IterVarType of the specific loop, according to the blocks it's bound to + * \param loop_sref The loop to be checked + * \return The IterVarType of the specific loop + */ +IterVarType GetLoopIterType(const StmtSRef& loop_sref); + /******** Producer-consumer relation ********/ /*! diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 8d960d68ee..b69bfd730b 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -727,6 +727,53 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } } +IterVarType GetLoopIterType(const StmtSRef& loop_sref) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const Var& loop_var = loop->loop_var; + int n_spatial = 0; + int n_reduce = 0; + int n_other = 0; + auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& obj) -> bool { + if (const auto* realize = obj.as()) { + const BlockNode* block = realize->block.get(); + // Number of block vars and their bindings + ICHECK_EQ(realize->iter_values.size(), block->iter_vars.size()); + int n = realize->iter_values.size(); + for (int i = 0; i < n; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& binding = realize->iter_values[i]; + // Categorize the current block var + int* ref = nullptr; + if (iter_var->iter_type == IterVarType::kDataPar) { + ref = &n_spatial; + } else if (iter_var->iter_type == IterVarType::kCommReduce) { + ref = &n_reduce; + } else { + ref = &n_other; + } + // Visit the binding to see if `loop_var` appears + PostOrderVisit(binding, [&ref, &loop_var](const ObjectRef& obj) -> void { + if (obj.same_as(loop_var)) { + (*ref) += 1; + } + }); + } + return false; + } + return true; + }; + PreOrderVisit(loop->body, f_visit); + if (n_other) { + return IterVarType::kOpaque; + } else if (n_spatial && n_reduce) { + return IterVarType::kOpaque; + } else if (n_reduce) { + return IterVarType::kCommReduce; + } else { + return IterVarType::kDataPar; + } +} + /******** Producer-consumer relation ********/ Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope) { @@ -1368,11 +1415,11 @@ bool HasOp(const Stmt& stmt, const Array& ops) { op_set.insert(op.operator->()); } bool found = false; - tir::PreOrderVisit(stmt, [&found, &op_set](const ObjectRef& obj) -> bool { + PreOrderVisit(stmt, [&found, &op_set](const ObjectRef& obj) -> bool { if (found) { return false; } - if (const auto* call = obj.as()) { + if (const auto* call = obj.as()) { if (op_set.count(call->op.operator->())) { found = true; } @@ -1389,15 +1436,15 @@ bool HasIfThenElse(const Stmt& stmt) { // stop visiting return false; } - if (const auto* realize = obj.as()) { + if (const auto* realize = obj.as()) { // Case 1: BlockRealize if (!is_one(realize->predicate)) { has_branch = true; } - } else if (obj->IsInstance() || obj->IsInstance()) { + } else if (obj->IsInstance() || obj->IsInstance()) { // Case 2: IfThenElse / Select has_branch = true; - } else if (const auto* call = obj.as()) { + } else if (const auto* call = obj.as()) { // Case 3: Call static const Op& op_if_then_else = Op::Get("tir.if_then_else"); if (call->op.same_as(op_if_then_else)) { @@ -1406,7 +1453,7 @@ bool HasIfThenElse(const Stmt& stmt) { } return !has_branch; }; - tir::PreOrderVisit(stmt, f_visit); + PreOrderVisit(stmt, f_visit); return has_branch; } diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 614b0d0640..a9d773756a 100644 --- 
a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -594,8 +594,8 @@ void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key if (const auto* str = ann_val.as()) { tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, GetRef(str)); } else if (const auto* expr = ann_val.as()) { - ICHECK(ann_val.as() == nullptr) - << "TypeError: runtime::String is expected, but gets tir::StringImm"; + ICHECK(!ann_val->IsInstance()) + << "TypeError: runtime::String is expected, but gets StringImm"; tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->Get(GetRef(expr))); } else { LOG(FATAL) @@ -620,8 +620,8 @@ void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_k if (const auto* str = ann_val.as()) { tir::Annotate(state_, this->GetSRef(block_rv), ann_key, GetRef(str)); } else if (const auto* expr = ann_val.as()) { - ICHECK(ann_val.as() == nullptr) - << "TypeError: runtime::String is expected, but gets tir::StringImm"; + ICHECK(!ann_val->IsInstance()) + << "TypeError: runtime::String is expected, but gets StringImm"; tir::Annotate(state_, this->GetSRef(block_rv), ann_key, this->Get(GetRef(expr))); } else { LOG(FATAL) diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 95d636467a..f842f75763 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -43,7 +43,7 @@ namespace tir { * * // Convertible to `InstructionKindNode::FInstructionApply` * static Array ApplyToSchedule( - * const tir::Schedule& sch, + * const Schedule& sch, * const Array& inputs, * const Array& attrs, * const Optional& decision); diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc index 9412222a5b..09b7a47e8e 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/tir/schedule/primitive/annotate.cc @@ -58,9 +58,9 @@ void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key) { // Extract annotation const Map* annotations = nullptr; - if (const auto* loop = sref->StmtAs()) { + if (const auto* loop = sref->StmtAs()) { annotations = &loop->annotations; - } else if (const auto* block = sref->StmtAs()) { + } else if (const auto* block = sref->StmtAs()) { annotations = &block->annotations; } else { LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); @@ -71,15 +71,15 @@ void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key) Map new_ann(*annotations); new_ann.erase(ann_key); // Create the new stmt - if (const auto* loop = sref->StmtAs()) { - ObjectPtr n = make_object(*loop); + if (const auto* loop = sref->StmtAs()) { + ObjectPtr n = make_object(*loop); n->annotations = std::move(new_ann); - self->Replace(sref, tir::For(n), {}); - } else if (const auto* block = sref->StmtAs()) { - ObjectPtr n = make_object(*block); + self->Replace(sref, For(n), {}); + } else if (const auto* block = sref->StmtAs()) { + ObjectPtr n = make_object(*block); n->annotations = std::move(new_ann); - tir::Block p(n); - self->Replace(sref, p, {{GetRef(block), p}}); + Block p(n); + self->Replace(sref, p, {{GetRef(block), p}}); } else { LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); throw; diff --git a/src/tir/schedule/primitive/read_write_at.cc b/src/tir/schedule/primitive/read_write_at.cc index cb693c77cd..2656fe7ba9 100644 --- a/src/tir/schedule/primitive/read_write_at.cc 
+++ b/src/tir/schedule/primitive/read_write_at.cc @@ -20,10 +20,6 @@ #include #include "../utils.h" -#include "tvm/runtime/memory.h" -#include "tvm/runtime/object.h" -#include "tvm/tir/schedule/block_scope.h" -#include "tvm/tir/stmt_functor.h" namespace tvm { namespace tir { diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 0312b924fc..b3ae553cf5 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -20,7 +20,6 @@ #include #include "../utils.h" -#include "tvm/support/random_engine.h" namespace tvm { namespace tir { @@ -298,7 +297,7 @@ std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandS std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // - const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, + const StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, Optional>* decision) { const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); const int64_t* extent = GetLoopIntExtent(loop); diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index faeb0b9907..1be5ed06ac 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -897,7 +897,7 @@ class ChildReplacer : private StmtMutator { int seq_index_; }; -void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt, +void ScheduleStateNode::Replace(const StmtSRef& _src_sref, const Stmt& tgt_stmt, const Map& _block_sref_reuse) { if (this->debug_mask != 0) { const StmtNode* src_stmt = _src_sref->stmt; diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 71a8bdf5b8..300d5ae96c 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -178,6 +178,18 @@ inline Array AsArray(const Stmt& stmt) { return {stmt}; } +/*! + * \brief Checks of a statement is a SeqStmt that contains multiple statements + * \param stmt The statement to be checked + * \return A boolean indicating the result + */ +inline bool IsSingleStmt(const Stmt& stmt) { + if (const auto* seq_stmt = stmt.as()) { + return seq_stmt->seq.size() == 1; + } + return true; +} + /******** IterVar ********/ /*! @@ -192,6 +204,36 @@ inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_va Var(std::move(name), loop->loop_var.dtype()), iter_var_type); } +/*! + * \brief Get the thread scope bound to the specific loop + * \param loop The loop to be inspected + * \return The thread scope bound to the loop + */ +inline runtime::ThreadScope GetThreadScope(const ForNode* loop) { + if (loop->kind == ForKind::kThreadBinding) { + return runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag); + } + return runtime::ThreadScope{-1, -1}; +} + +/*! + * \brief Check if the thread scope is blockIdx + * \param thread_scope The thread scope to be checked + * \return True if the thread scope is blockIdx + */ +inline bool IsBlockIdx(const runtime::ThreadScope& thread_scope) { + return thread_scope.rank == 0; // The rank of blockIdx is 0 +} + +/*! + * \brief Check if the thread scope is threadIdx + * \param thread_scope The thread scope to be checked + * \return True if the thread scope is threadIdx + */ +inline bool IsThreadIdx(const runtime::ThreadScope& thread_scope) { + return thread_scope.rank == 1 && thread_scope.dim_index >= 0; +} + /******** Integer set ********/ /*! 
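
The unit tests below share one pattern: build a TuneContext carrying a
postprocessor, initialize it, enter the postprocessing stage of a schedule,
and apply it. A minimal sketch distilled from those tests (the helper name
`apply_postprocs` and the "demo" task name are illustrative):

from tvm import tir
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.postproc import RewriteReductionBlock
from tvm.target import Target


def apply_postprocs(mod, postprocs):
    """Initialize `postprocs` against a fresh TuneContext and apply them to a
    schedule of `mod`, mirroring the reduction-block and unbound-block tests."""
    ctx = TuneContext(
        mod=mod,
        target=Target("cuda", host="llvm"),
        postproc=postprocs,
        task_name="demo",
    )
    for postproc in ctx.postprocs:
        postproc.initialize_with_tune_context(ctx)
    sch = tir.Schedule(mod, debug_mask="all")
    sch.enter_postproc()
    for postproc in ctx.postprocs:
        assert postproc.apply(sch)
    return sch


# e.g., given the `Before` module from the reduction-block test:
#   sch = apply_postprocs(Before, [RewriteReductionBlock()])
#   tvm.ir.assert_structural_equal(sch.mod, After)
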
diff --git a/tests/python/unittest/test_meta_schedule_postproc.py b/tests/python/unittest/test_meta_schedule_postproc.py index 52f07fdff0..7a448ec09f 100644 --- a/tests/python/unittest/test_meta_schedule_postproc.py +++ b/tests/python/unittest/test_meta_schedule_postproc.py @@ -79,7 +79,7 @@ def apply(self, sch: Schedule) -> bool: assert postproc.apply(sch) try: tvm.ir.assert_structural_equal(sch.mod, mod) - raise Exception("The post processing did not change the schedule.") + raise Exception("The postprocessors did not change the schedule.") except (ValueError): _check_correct(sch) diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py new file mode 100644 index 0000000000..70efb402c3 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import RewriteCooperativeFetch +from tvm.meta_schedule.testing import te_workload +from tvm.script import tir as T +from tvm.target import Target +from tvm.te import create_prim_func + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postproc=[ + RewriteCooperativeFetch(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class AfterRewrite: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i2_0 in T.serial(0, 1): + for ax0_ax1_fused_0 in T.serial(0, 32768): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + 
ax0_ax1_fused_1) // 512) + v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 1024): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(0, 2): + with T.block("B_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":2}) + B_shared[v0, v1] = B[v0, v1] + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): + with T.block("C"): + i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) + j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) + k = T.axis.reduce(512, i2_1 * 32 + i2_2) + T.reads([C_local[i, j], A_shared[i, k], B_shared[k, j]]) + T.writes([C_local[i, j]]) + with T.init(): + C_local[i, j] = T.float32(0) + C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] + for ax0, ax1 in T.grid(32, 4): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def test_rewrite_cooperative_fetch(): + mod = create_prim_func(te_workload.matmul(n=512, m=512, k=512)) + target = _target() + ctx = _create_context(mod, target) + + sch = tir.Schedule(mod, debug_mask="all") + # fmt: off + # pylint: disable=line-too-long,invalid-name + b0 = sch.get_block(name="C", func_name="main") + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 16, 1, 2, 16]) + l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9]) + v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 8, 2, 2]) + l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19]) + v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64, decision=[1, 16, 32]) + l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27]) + sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24) + l31 = sch.fuse(l10, l20) + sch.bind(loop=l31, thread_axis="blockIdx.x") + l32 = sch.fuse(l11, l21) + sch.bind(loop=l32, thread_axis="vthread.x") + l33 = sch.fuse(l12, l22) + sch.bind(loop=l33, thread_axis="threadIdx.x") + b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared") + sch.compute_at(block=b34, loop=l28, preserve_unit_loops=1) + _, _, _, _, l39, l40 = sch.get_loops(block=b34) + l41 = sch.fuse(l39, l40) + _, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4, decision=[262144, 1]) + sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43) + b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared") + sch.compute_at(block=b44, loop=l28, preserve_unit_loops=1) + _, _, _, 
_, l49, l50 = sch.get_loops(block=b44) + l51 = sch.fuse(l49, l50) + _, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4, decision=[8192, 2]) + sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53) + sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=1) + # pylint: enable=line-too-long,invalid-name + # fmt: on + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, AfterRewrite) + + +if __name__ == "__main__": + test_rewrite_cooperative_fetch() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py new file mode 100644 index 0000000000..ef654d0874 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py @@ -0,0 +1,172 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import RewriteReductionBlock +from tvm.script import tir as T +from tvm.target import Target + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postproc=[ + RewriteReductionBlock(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Before: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i2_0 in T.serial(0, 1): + for ax0_ax1_fused_0 in T.serial(0, 32768): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512) + v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) + T.reads([A[v0, v1]]) + 
+                                    T.writes([A_shared[v0, v1]])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":1})
+                                    A_shared[v0, v1] = A[v0, v1]
+                        for ax0_ax1_fused_0 in T.serial(0, 1024):
+                            for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"):
+                                for ax0_ax1_fused_2 in T.vectorized(0, 2):
+                                    with T.block("B_shared"):
+                                        v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32)
+                                        v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32)
+                                        T.reads([B[v0, v1]])
+                                        T.writes([B_shared[v0, v1]])
+                                        T.block_attr({"meta_schedule.cooperative_fetch":2})
+                                        B_shared[v0, v1] = B[v0, v1]
+                        for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2):
+                            with T.block("C"):
+                                i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
+                                j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4)
+                                k = T.axis.reduce(512, i2_1 * 32 + i2_2)
+                                T.reads([C_local[i, j], A_shared[i, k], B_shared[k, j]])
+                                T.writes([C_local[i, j]])
+                                with T.init():
+                                    C_local[i, j] = T.float32(0)
+                                C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j]
+                    for ax0, ax1 in T.grid(32, 4):
+                        with T.block("C_local"):
+                            v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
+                            v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1)
+                            T.reads([C_local[v0, v1]])
+                            T.writes([C[v0, v1]])
+                            C[v0, v1] = C_local[v0, v1]
+
+
+@tvm.script.ir_module
+class After:
+    @T.prim_func
+    def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
+        A = T.match_buffer(var_A, [512, 512], dtype="float32")
+        B = T.match_buffer(var_B, [512, 512], dtype="float32")
+        C = T.match_buffer(var_C, [512, 512], dtype="float32")
+        C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
+        A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
+        B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
+        for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
+            for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
+                for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"):
+                    for i2_0 in T.serial(0, 1):
+                        for ax0_ax1_fused_0 in T.serial(0, 32768):
+                            for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"):
+                                with T.block("A_shared"):
+                                    v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512)
+                                    v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512)
+                                    T.reads([A[v0, v1]])
+                                    T.writes([A_shared[v0, v1]])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":1})
+                                    A_shared[v0, v1] = A[v0, v1]
+                        for ax0_ax1_fused_0 in T.serial(0, 1024):
+                            for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"):
+                                for ax0_ax1_fused_2 in T.vectorized(0, 2):
+                                    with T.block("B_shared"):
+                                        v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32)
+                                        v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32)
+                                        T.reads([B[v0, v1]])
+                                        T.writes([B_shared[v0, v1]])
+                                        T.block_attr({"meta_schedule.cooperative_fetch":2})
+                                        B_shared[v0, v1] = B[v0, v1]
+                        for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(2, 2, 16, 2):
+                            with T.block("C_init"):
+                                i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3_init * 16 + i0_4_init)
+                                j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3_init * 2 + i1_4_init)
+                                T.reads([])
+                                T.writes([C_local[i, j]])
+                                C_local[i, j] = T.float32(0)
+                        for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2):
+                            with T.block("C_update"):
+                                i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
+                                j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4)
+                                k = T.axis.reduce(512, i2_1 * 32 + i2_2)
+                                T.reads([C_local[i, j], A_shared[i, k], B_shared[k, j]])
+                                T.writes([C_local[i, j]])
+                                C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j]
+                    for ax0, ax1 in T.grid(32, 4):
+                        with T.block("C_local"):
+                            v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
+                            v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1)
+                            T.reads([C_local[v0, v1]])
+                            T.writes([C[v0, v1]])
+                            C[v0, v1] = C_local[v0, v1]
+
+
+# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
+# fmt: on
+
+
+def test_rewrite_reduction_block():
+    mod = Before
+    target = _target()
+    ctx = _create_context(mod, target)
+    sch = tir.Schedule(mod, debug_mask="all")
+    sch.enter_postproc()
+    assert ctx.postprocs[0].apply(sch)
+    tvm.ir.assert_structural_equal(sch.mod, After)
+
+
+if __name__ == "__main__":
+    test_rewrite_reduction_block()
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py
new file mode 100644
index 0000000000..efe0a41172
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
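+# Tests RewriteUnboundBlock: fuse unbound spatial loops and bind them to GPU threads.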
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+
+import tvm
+from tvm import tir
+from tvm.meta_schedule import TuneContext
+from tvm.meta_schedule.postproc import RewriteUnboundBlock
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def _target() -> Target:
+    return Target("cuda", host="llvm")
+
+
+def _create_context(mod, target) -> TuneContext:
+    ctx = TuneContext(
+        mod=mod,
+        target=target,
+        postprocs=[
+            RewriteUnboundBlock(),
+        ],
+        task_name="test",
+    )
+    for rule in ctx.postprocs:
+        rule.initialize_with_tune_context(ctx)
+    return ctx
+
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
+
+
+@tvm.script.ir_module
+class Before:
+    @T.prim_func
+    def main(var_A: T.handle, var_B: T.handle) -> None:
+        A = T.match_buffer(var_A, [512, 512], dtype="float32")
+        B = T.match_buffer(var_B, [512, 512], dtype="float32")
+        for i, j in T.grid(512, 512):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] + 1.0
+
+
+@tvm.script.ir_module
+class After:
+    @T.prim_func
+    def main(var_A: T.handle, var_B: T.handle) -> None:
+        A = T.match_buffer(var_A, [512, 512], dtype="float32")
+        B = T.match_buffer(var_B, [512, 512], dtype="float32")
+        for i_j_fused_0 in T.thread_binding(0, 8192, thread="blockIdx.x"):
+            for i_j_fused_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+                with T.block("C"):
+                    vi = T.axis.spatial(512, (i_j_fused_0 * 32 + i_j_fused_1) // 512)
+                    vj = T.axis.spatial(512, (i_j_fused_0 * 32 + i_j_fused_1) % 512)
+                    B[vi, vj] = A[vi, vj] + 1.0
+
+
+# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
+# fmt: on
+
+
+def test_rewrite_unbound_block():
+    mod = Before
+    target = _target()
+    ctx = _create_context(mod, target)
+    sch = tir.Schedule(mod, debug_mask="all")
+    sch.enter_postproc()
+    assert ctx.postprocs[0].apply(sch)
+    tvm.ir.assert_structural_equal(sch.mod, After)
+
+
+if __name__ == "__main__":
+    test_rewrite_unbound_block()
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
index a05ddaf568..0273a2bdf6 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
@@ -142,7 +142,7 @@ def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.hand
             for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(0, 2, thread="vthread.x"):
                 for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"):
                     for i4_0, i5_0, i6_0 in T.grid(1, 3, 1):
-                        for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 40960, annotations={"meta_schedule.lazy_cooperative_fetch":1}):
+                        for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 40960, annotations={"meta_schedule.cooperative_fetch":1}):
                             for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 3):
                                 with T.block("pad_temp_shared"):
                                     v0 = T.axis.spatial(1, 0)
@@ -153,7 +153,7 @@ def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.hand
                                     T.writes([pad_temp_shared[v0, v1, v2, v3]])
                                     T.block_attr({"meta_schedule.cache_type":0})
                                     pad_temp_shared[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3]
-                        for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 12288, annotations={"meta_schedule.lazy_cooperative_fetch":1}):
+                        for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 12288, annotations={"meta_schedule.cooperative_fetch":1}):
                             for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 4):
                                 with T.block("W_shared"):
                                     v0 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 1536)
@@ -212,7 +212,7 @@ def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.hand
             for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(0, 2, thread="vthread.x"):
                 for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"):
                     for i4_0, i5_0, i6_0 in T.grid(1, 3, 1):
-                        for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 40960, annotations={"meta_schedule.lazy_cooperative_fetch":1}):
+                        for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 40960, annotations={"meta_schedule.cooperative_fetch":1}):
                             for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 3):
                                 with T.block("pad_temp_shared"):
                                     v0 = T.axis.spatial(1, 0)
@@ -223,7 +223,7 @@ def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.hand
                                     T.writes([pad_temp_shared[v0, v1, v2, v3]])
                                     T.block_attr({"meta_schedule.cache_type":0})
                                     pad_temp_shared[v0, v1, v2, v3] = T.if_then_else(v2 >= 1 and v2 < 57 and v3 >= 1 and v3 < 57, X[v0, v1, v2 - 1, v3 - 1], T.float32(0), dtype="float32")
-                        for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 12288, annotations={"meta_schedule.lazy_cooperative_fetch":1}):
+                        for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 12288, annotations={"meta_schedule.cooperative_fetch":1}):
                             for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 4):
                                 with T.block("W_shared"):
                                     v0 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 1536)
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index 03f35749f0..240e4eb86f 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -180,17 +180,13 @@ def test_cuda_matmul():
             "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
             "l41 = sch.fuse(l39, l40)",
             "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)",
-            "l44, l45 = sch.split(loop=l41, factors=[v42, v43])",
-            "sch.vectorize(loop=l45)",
-            'sch.annotate(block_or_loop=l44, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)',
-            'b46 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
-            "sch.compute_at(block=b46, loop=l28, preserve_unit_loops=1)",
-            "l47, l48, l49, l50, l51, l52 = sch.get_loops(block=b46)",
-            "l53 = sch.fuse(l51, l52)",
-            "v54, v55 = sch.sample_perfect_tile(loop=l53, n=2, max_innermost_factor=4)",
-            "l56, l57 = sch.split(loop=l53, factors=[v54, v55])",
-            "sch.vectorize(loop=l57)",
-            'sch.annotate(block_or_loop=l56, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)',
+            'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)',
+            'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
+            "sch.compute_at(block=b44, loop=l28, preserve_unit_loops=1)",
+            "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)",
+            "l51 = sch.fuse(l49, l50)",
+            "v52, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4)",
+            'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)',
             "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=1)",
         ]
     ]
@@ -237,17 +233,13 @@ def test_cuda_matmul_relu():
             "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
             "l41 = sch.fuse(l39, l40)",
             "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)",
-            "l44, l45 = sch.split(loop=l41, factors=[v42, v43])",
-            "sch.vectorize(loop=l45)",
-            'sch.annotate(block_or_loop=l44, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)',
-            'b46 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
-            "sch.compute_at(block=b46, loop=l28, preserve_unit_loops=1)",
-            "l47, l48, l49, l50, l51, l52 = sch.get_loops(block=b46)",
-            "l53 = sch.fuse(l51, l52)",
-            "v54, v55 = sch.sample_perfect_tile(loop=l53, n=2, max_innermost_factor=4)",
-            "l56, l57 = sch.split(loop=l53, factors=[v54, v55])",
-            "sch.vectorize(loop=l57)",
-            'sch.annotate(block_or_loop=l56, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)',
+            'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)',
+            'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
+            "sch.compute_at(block=b44, loop=l28, preserve_unit_loops=1)",
+            "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)",
+            "l51 = sch.fuse(l49, l50)",
+            "v52, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4)",
+            'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)',
             "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=1)",
         ]
     ]
diff --git a/tests/python/unittest/test_meta_schedule_sketch_cuda.py b/tests/python/unittest/test_meta_schedule_sketch_cuda.py
index 0e6be26fa0..b691c23437 100644
--- a/tests/python/unittest/test_meta_schedule_sketch_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_sketch_cuda.py
@@ -51,17 +51,13 @@ def test_meta_schedule_cuda_sketch_matmul():
            "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
            "l41 = sch.fuse(l39, l40)",
            "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)",
-            "l44, l45 = sch.split(loop=l41, factors=[v42, v43])",
-            "sch.vectorize(loop=l45)",
-            'sch.annotate(block_or_loop=l44, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)',
-            'b46 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
-            "sch.compute_at(block=b46, loop=l28, preserve_unit_loops=1)",
-            "l47, l48, l49, l50, l51, l52 = sch.get_loops(block=b46)",
-            "l53 = sch.fuse(l51, l52)",
-            "v54, v55 = sch.sample_perfect_tile(loop=l53, n=2, max_innermost_factor=4)",
-            "l56, l57 = sch.split(loop=l53, factors=[v54, v55])",
-            "sch.vectorize(loop=l57)",
-            'sch.annotate(block_or_loop=l56, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)',
+            'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)',
+            'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
+            "sch.compute_at(block=b44, loop=l28, preserve_unit_loops=1)",
+            "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)",
+            "l51 = sch.fuse(l49, l50)",
+            "v52, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4)",
+            'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)',
             "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=1)",
         ]
     ]
@@ -106,20 +102,16 @@ def test_meta_schedule_cuda_sketch_matmul_relu():
            "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
            "l41 = sch.fuse(l39, l40)",
            "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)",
-            "l44, l45 = sch.split(loop=l41, factors=[v42, v43])",
- "sch.vectorize(loop=l45)", - 'sch.annotate(block_or_loop=l44, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', - 'b46 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b46, loop=l28, preserve_unit_loops=1)", - "l47, l48, l49, l50, l51, l52 = sch.get_loops(block=b46)", - "l53 = sch.fuse(l51, l52)", - "v54, v55 = sch.sample_perfect_tile(loop=l53, n=2, max_innermost_factor=4)", - "l56, l57 = sch.split(loop=l53, factors=[v54, v55])", - "sch.vectorize(loop=l57)", - 'sch.annotate(block_or_loop=l56, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)', + 'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b44, loop=l28, preserve_unit_loops=1)", + "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)", + "l51 = sch.fuse(l49, l50)", + "v52, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)', "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=1)", - 'b58 = sch.get_block(name="compute", func_name="main")', - "sch.reverse_compute_inline(block=b58)", + 'b54 = sch.get_block(name="compute", func_name="main")', + "sch.reverse_compute_inline(block=b54)", ] ] # pylint: enable=line-too-long @@ -171,20 +163,16 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw(): "l71, l72, l73, l74, l75, l76, l77, l78, l79, l80 = sch.get_loops(block=b70)", "l81 = sch.fuse(l77, l78, l79, l80)", "v82, v83 = sch.sample_perfect_tile(loop=l81, n=2, max_innermost_factor=4)", - "l84, l85 = sch.split(loop=l81, factors=[v82, v83])", - "sch.vectorize(loop=l85)", - 'sch.annotate(block_or_loop=l84, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', - 'b86 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b86, loop=l64, preserve_unit_loops=1)", - "l87, l88, l89, l90, l91, l92, l93, l94, l95, l96 = sch.get_loops(block=b86)", - "l97 = sch.fuse(l93, l94, l95, l96)", - "v98, v99 = sch.sample_perfect_tile(loop=l97, n=2, max_innermost_factor=4)", - "l100, l101 = sch.split(loop=l97, factors=[v98, v99])", - "sch.vectorize(loop=l101)", - 'sch.annotate(block_or_loop=l100, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + 'sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v83)', + 'b84 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b84, loop=l64, preserve_unit_loops=1)", + "l85, l86, l87, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84)", + "l95 = sch.fuse(l91, l92, l93, l94)", + "v96, v97 = sch.sample_perfect_tile(loop=l95, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b84, ann_key="meta_schedule.cooperative_fetch", ann_val=v97)', "sch.reverse_compute_at(block=b1, loop=l69, preserve_unit_loops=1)", - 'b102 = sch.get_block(name="pad_temp", func_name="main")', - "sch.compute_inline(block=b102)", + 'b98 = sch.get_block(name="pad_temp", func_name="main")', + "sch.compute_inline(block=b98)", ] ] # pylint: enable=line-too-long @@ -249,22 +237,18 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disabl "l74, l75, l76, l77, l78, l79, l80, l81, l82, l83 = sch.get_loops(block=b73)", "l84 = sch.fuse(l80, l81, l82, l83)", "v85, v86 = sch.sample_perfect_tile(loop=l84, n=2, max_innermost_factor=4)", - "l87, l88 = 
sch.split(loop=l84, factors=[v85, v86])", - "sch.vectorize(loop=l88)", - 'sch.annotate(block_or_loop=l87, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', - 'b89 = sch.cache_read(block=b3, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b89, loop=l67, preserve_unit_loops=1)", - "l90, l91, l92, l93, l94, l95, l96, l97, l98, l99 = sch.get_loops(block=b89)", - "l100 = sch.fuse(l96, l97, l98, l99)", - "v101, v102 = sch.sample_perfect_tile(loop=l100, n=2, max_innermost_factor=4)", - "l103, l104 = sch.split(loop=l100, factors=[v101, v102])", - "sch.vectorize(loop=l104)", - 'sch.annotate(block_or_loop=l103, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + 'sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v86)', + 'b87 = sch.cache_read(block=b3, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b87, loop=l67, preserve_unit_loops=1)", + "l88, l89, l90, l91, l92, l93, l94, l95, l96, l97 = sch.get_loops(block=b87)", + "l98 = sch.fuse(l94, l95, l96, l97)", + "v99, v100 = sch.sample_perfect_tile(loop=l98, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b87, ann_key="meta_schedule.cooperative_fetch", ann_val=v100)', "sch.reverse_compute_at(block=b4, loop=l72, preserve_unit_loops=1)", - 'b105 = sch.get_block(name="pad_temp", func_name="main")', - 'b106 = sch.get_block(name="compute_1", func_name="main")', - "sch.reverse_compute_inline(block=b106)", - "sch.compute_inline(block=b105)", + 'b101 = sch.get_block(name="pad_temp", func_name="main")', + 'b102 = sch.get_block(name="compute_1", func_name="main")', + "sch.reverse_compute_inline(block=b102)", + "sch.compute_inline(block=b101)", ] ] # pylint: enable=line-too-long @@ -284,7 +268,6 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disabl ), target=_target(), ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) assert len(spaces) == 1 check_trace(spaces, expected)
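
The three new test files above share one driver pattern: build a TuneContext that owns the
postprocessor, initialize it against the context, switch the schedule into its postproc phase,
and apply. The sketch below restates that flow outside the test harness; it is a minimal
illustration assuming a TVM build that includes this patch, and the trivial workload and the
task name "demo" are stand-ins, not part of the patch:

    import tvm
    from tvm import tir
    from tvm.meta_schedule import TuneContext
    from tvm.meta_schedule.postproc import RewriteUnboundBlock
    from tvm.script import tir as T
    from tvm.target import Target


    @tvm.script.ir_module
    class Mod:
        @T.prim_func
        def main(var_A: T.handle, var_B: T.handle) -> None:
            # An elementwise block with no thread bindings yet.
            A = T.match_buffer(var_A, [16, 16], dtype="float32")
            B = T.match_buffer(var_B, [16, 16], dtype="float32")
            for i, j in T.grid(16, 16):
                with T.block("C"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] + 1.0


    ctx = TuneContext(
        mod=Mod,
        target=Target("cuda", host="llvm"),
        postprocs=[RewriteUnboundBlock()],
        task_name="demo",
    )
    for postproc in ctx.postprocs:
        # Let the postprocessor pick up target-specific information from the context.
        postproc.initialize_with_tune_context(ctx)
    sch = tir.Schedule(Mod, debug_mask="all")
    sch.enter_postproc()  # postprocessors run only in the postproc phase
    assert ctx.postprocs[0].apply(sch)  # returns False if the rewrite cannot be applied
    # sch.mod now has block "C" fused and bound to blockIdx.x / threadIdx.x,
    # mirroring the After module in the unbound-block test.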