diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py
index b41bdfec06a7..eaab8c7bd484 100644
--- a/python/tvm/meta_schedule/postproc/__init__.py
+++ b/python/tvm/meta_schedule/postproc/__init__.py
@@ -18,4 +18,5 @@
 from .postproc import Postproc, PyPostproc
 from .disallow_dynamic_loop import DisallowDynamicLoop
 from .rewrite_reduction_block import RewriteReductionBlock
+from .rewrite_unbound_block import RewriteUnboundBlock
 from .verify_gpu_code import VerifyGPUCode
diff --git a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py
new file mode 100644
index 000000000000..f4113e5173c9
--- /dev/null
+++ b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A postprocessor that adds thread binding to unbound blocks"""
+
+from tvm._ffi.registry import register_object
+from .. import _ffi_api
+from .postproc import Postproc
+
+
+@register_object("meta_schedule.RewriteUnboundBlock")
+class RewriteUnboundBlock(Postproc):
+    """A postprocessor that binds the loops of unbound blocks to blockIdx.x/threadIdx.x"""
+
+    def __init__(self) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.PostprocRewriteUnboundBlock,  # type: ignore # pylint: disable=no-member
+        )
diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc
new file mode 100644
index 000000000000..624e6d27e844
--- /dev/null
+++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief The rewrite type for an unbound block */
+enum class BindType : int32_t {
+  /*! \brief No additional thread binding is needed */
+  kNoBind = 0,
+  /*! \brief Need to bind to blockIdx only (threadIdx is already bound) */
+  kBindBlock = 1,
+  /*! \brief Need to bind to both blockIdx and threadIdx (neither is bound yet) */
+  kBindBlockThread = 2,
+};
+
+/*!
+ * \brief Check the combination of bindings to be added to the block
+ * \param block_sref The block to be checked
+ * \param fuse_first_num [out] How many leading loops to fuse; unset for kNoBind
+ * \return The type of binding to be added to the block
+ */
+BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) {
+  Array<StmtSRef> loops = tir::GetLoops(block_sref);
+  int n = loops.size();
+  if (n == 0) {
+    return BindType::kNoBind;
+  }
+  int i_block_idx = -1;     // index of the first loop bound to blockIdx, -1 if none
+  int i_thread_idx = -1;    // index of the first loop bound to threadIdx, -1 if none
+  int i_multi_child = -1;   // 1 + index of the first loop whose body is not a single stmt
+  int i_spatial_loop = -1;  // index of the last loop in the leading run of kDataPar loops
+  for (int i = 0; i < n; ++i) {
+    const StmtSRef& loop_sref = loops[i];
+    const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+    runtime::ThreadScope thread_scope = GetThreadScope(loop);
+    if (IsBlockIdx(thread_scope)) {
+      if (i_block_idx == -1) {
+        i_block_idx = i;
+      }
+    }
+    if (IsThreadIdx(thread_scope)) {
+      if (i_thread_idx == -1) {
+        i_thread_idx = i;
+      }
+    }
+    if (!IsSingleStmt(loop->body)) {
+      if (i_multi_child == -1) {
+        i_multi_child = i + 1;
+      }
+    }
+    if (tir::GetLoopIterType(loop_sref) == IterVarType::kDataPar) {
+      if (i_spatial_loop == i - 1) {  // extend only while the run is contiguous from loop 0
+        ++i_spatial_loop;
+      }
+    }
+  }
+  if (i_multi_child == -1) {
+    i_multi_child = n;  // every loop has a single-statement body
+  }
+  if ((i_block_idx != -1 && i_thread_idx != -1) || i_spatial_loop == -1) {
+    return BindType::kNoBind;
+  } else if (i_block_idx != -1 && i_thread_idx == -1) {
+    ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not";
+    throw;  // presumably unreachable after the failed ICHECK -- TODO confirm
+  } else if (i_block_idx == -1 && i_thread_idx != -1) {
+    *fuse_first_num = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1);
+    return BindType::kBindBlock;
+  } else {  // i_block_idx == -1 && i_thread_idx == -1
+    *fuse_first_num = std::min(i_multi_child, i_spatial_loop + 1);
+    return BindType::kBindBlockThread;
+  }
+}
+
+/*!
\brief Find all the blocks that are not bound */
+class UnboundBlockFinder : private StmtVisitor {
+ public:
+  static std::vector<std::pair<StmtSRef, String>> Find(const ScheduleState& self) {
+    UnboundBlockFinder finder(self);
+    for (const auto& kv : self->mod->functions) {
+      GlobalVar g_var = kv.first;
+      BaseFunc base_func = kv.second;
+      if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
+        finder.global_var_name_ = g_var->name_hint;
+        finder(Downcast<BlockRealize>(prim_func->body)->block->body);
+      }
+    }
+    return std::move(finder.blocks_);
+  }
+
+ private:
+  void VisitStmt_(const ForNode* loop) final {
+    runtime::ThreadScope thread_scope = GetThreadScope(loop);
+    if (IsBlockIdx(thread_scope)) {
+      ++n_block_idx_;
+    } else if (IsThreadIdx(thread_scope)) {
+      ++n_thread_idx_;
+    }
+    if (n_block_idx_ == 0 || n_thread_idx_ == 0) {  // descend unless both kinds enclose it
+      StmtVisitor::VisitStmt_(loop);
+    }
+    if (IsBlockIdx(thread_scope)) {
+      --n_block_idx_;
+    } else if (IsThreadIdx(thread_scope)) {
+      --n_thread_idx_;
+    }
+  }
+
+  void VisitStmt_(const BlockNode* block) final {  // record only; children are not visited
+    blocks_.emplace_back(self_->stmt2ref.at(block), global_var_name_);
+  }
+
+  explicit UnboundBlockFinder(const ScheduleState& self)
+      : self_{self}, blocks_{}, n_block_idx_{0}, n_thread_idx_{0} {}
+
+  /*! \brief The schedule state */
+  const ScheduleState& self_;
+  /*! \brief The unbound blocks found, each paired with the name of its function */
+  std::vector<std::pair<StmtSRef, String>> blocks_;
+  /*! \brief The number of blockIdx above the current stmt */
+  int n_block_idx_;
+  /*! \brief The number of threadIdx above the current stmt */
+  int n_thread_idx_;
+  /*! \brief The name of the global var of the function being visited */
+  String global_var_name_;
+};
+
+}  // namespace tir
+}  // namespace tvm
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
\brief Add thread binding to unbound blocks */
+class RewriteUnboundBlockNode : public PostprocNode {
+ public:
+  // Inherited from PostprocNode
+  void InitializeWithTuneContext(const TuneContext& context) final {
+    CHECK(context->target.defined()) << "ValueError: target is not defined";
+    Optional<Integer> warp_size = context->target.value()->GetAttr<Integer>("thread_warp_size");
+    CHECK(warp_size.defined()) << "ValueError: missing attribute `thread_warp_size` in the target";
+    this->warp_size_ = warp_size.value();
+  }
+
+  // Inherited from PostprocNode
+  bool Apply(const tir::Schedule& sch) final;
+
+ public:
+  /*! \brief The target's `thread_warp_size`; -1 until InitializeWithTuneContext is called */
+  int warp_size_ = -1;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `warp_size_` is not visited
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.RewriteUnboundBlock";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RewriteUnboundBlockNode, PostprocNode);
+};
+
+bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) {
+  using tir::BlockRV;
+  using tir::LoopRV;
+  using tir::Schedule;
+  ICHECK_NE(this->warp_size_, -1);  // InitializeWithTuneContext must have run first
+  std::vector<std::pair<tir::StmtSRef, String>> unbound_blocks =
+      tir::UnboundBlockFinder::Find(sch->state());
+  for (const auto& kv : unbound_blocks) {
+    tir::StmtSRef block_sref = kv.first;
+    String global_var_name = kv.second;
+    int fuse_first_num = 0;  // filled in by GetBindType when a binding is needed
+    tir::BindType bind_type = tir::GetBindType(block_sref, &fuse_first_num);
+    if (bind_type == tir::BindType::kNoBind) {
+      continue;
+    }
+    BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name);
+    Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
+    LoopRV fused = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + fuse_first_num});
+    if (bind_type == tir::BindType::kBindBlock) {
+      sch->Bind(fused, "blockIdx.x");
+    } else if (bind_type == tir::BindType::kBindBlockThread) {
+      Array<LoopRV> splits = sch->Split(fused, {NullOpt, Integer(this->warp_size_)});
+      ICHECK_EQ(splits.size(), 2);
+      sch->Bind(splits[0], "blockIdx.x");   // outer part of the split
+      sch->Bind(splits[1], "threadIdx.x");  // inner part, extent `warp_size_`
+    }
+  }
+  return true;
+}
+
+Postproc
Postproc::RewriteUnboundBlock() {
+  ObjectPtr<RewriteUnboundBlockNode> n = make_object<RewriteUnboundBlockNode>();
+  n->warp_size_ = -1;  // placeholder; the real value is read in InitializeWithTuneContext
+  return Postproc(n);
+}
+
+TVM_REGISTER_NODE_TYPE(RewriteUnboundBlockNode);
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock")
+    .set_body_typed(Postproc::RewriteUnboundBlock);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index bb34c6aadaba..673813b0f140 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -193,6 +193,18 @@ inline Array<Stmt> AsArray(const Stmt& stmt) {
   return {stmt};
 }
 
+/*!
+ * \brief Checks if a statement is not a SeqStmt, or is a SeqStmt holding exactly one statement
+ * \param stmt The statement to be checked
+ * \return false only for a SeqStmt whose size is not 1; true otherwise
+ */
+inline bool IsSingleStmt(const Stmt& stmt) {
+  if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
+    return seq_stmt->seq.size() == 1;
+  }
+  return true;
+}
+
 /******** IterVar ********/
 
 /*!
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py
new file mode 100644
index 000000000000..4ab2741da181
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+
+import tvm
+from tvm import tir
+from tvm.meta_schedule import TuneContext
+from tvm.meta_schedule.postproc import RewriteUnboundBlock
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def _target() -> Target:
+    return Target("cuda", host="llvm")
+
+
+def _create_context(mod, target) -> TuneContext:
+    ctx = TuneContext(
+        mod=mod,
+        target=target,
+        postprocs=[
+            RewriteUnboundBlock(),
+        ],
+        task_name="test",
+    )
+    for rule in ctx.postprocs:
+        rule.initialize_with_tune_context(ctx)
+    return ctx
+
+
+# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
+
+
+@tvm.script.ir_module
+class Before_cooperative_fetch:
+    @T.prim_func
+    def main(var_A: T.handle, var_B: T.handle) -> None:
+        A = T.match_buffer(var_A, [512, 512], dtype="float32")
+        B = T.match_buffer(var_B, [512, 512], dtype="float32")
+        for i, j in T.grid(512, 512):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] + 1.0
+
+
+@tvm.script.ir_module
+class After_cooperative_fetch:
+    @T.prim_func
+    def main(var_A: T.handle, var_B: T.handle) -> None:
+        A = T.match_buffer(var_A, [512, 512], dtype="float32")
+        B = T.match_buffer(var_B, [512, 512], dtype="float32")
+        for i_j_fused_0 in T.thread_binding(0, 8192, thread="blockIdx.x"):  # 512 * 512 // 32
+            for i_j_fused_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+                with T.block("C"):
+                    vi = T.axis.spatial(512, (i_j_fused_0 * 32 + i_j_fused_1) // 512)
+                    vj = T.axis.spatial(512, (i_j_fused_0 * 32 + i_j_fused_1) % 512)
+                    B[vi, vj] = A[vi, vj] + 1.0
+
+
+@tvm.script.ir_module
+class Before_norm_bmn:
+    @T.prim_func
+    def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None:
+        C = T.alloc_buffer([1], dtype="float32")
+        for i0, i1, i2 in T.grid(1, 256, 256):
+            with T.block("C"):
+                b, i, j = T.axis.remap("SRR", [i0, i1, i2])
+                with T.init():
+                    C[b] = T.float32(0)
+                C[b] = C[b] + A[b, i, j] * A[b, i, j]
+        for i0 in T.serial(1):
+            with T.block("D"):
+                b = T.axis.S(1, i0)
+                D[b] = T.sqrt(C[b], dtype="float32")
+
+
+@tvm.script.ir_module
+class After_norm_bmn:
+    @T.prim_func
+    def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None:
+        C = T.alloc_buffer([1], dtype="float32")
+        for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+            for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
+                for i1, i2 in T.grid(256, 256):
+                    with T.block("C"):
+                        b = T.axis.S(1, 0)
+                        i, j = T.axis.remap("RR", [i1, i2])
+                        T.where(i0_fused_0 * 32 + i0_fused_1 < 1)  # extent 1 split by 32
+                        with T.init():
+                            C[b] = T.float32(0)
+                        C[b] = C[b] + A[b, i, j] * A[b, i, j]
+        for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+            for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
+                with T.block("D"):
+                    b = T.axis.S(1, 0)
+                    T.where(i0_fused_0 * 32 + i0_fused_1 < 1)
+                    D[b] = T.sqrt(C[b], dtype="float32")
+
+
+# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
+# fmt: on
+
+
+def test_rewrite_cooperative_fetch():
+    mod = Before_cooperative_fetch
+    target = _target()
+    ctx = _create_context(mod, target)
+    sch = tir.Schedule(mod, debug_mask="all")
+    sch.enter_postproc()
+    assert ctx.postprocs[0].apply(sch)
+    tvm.ir.assert_structural_equal(sch.mod, After_cooperative_fetch)
+
+
+def test_rewrite_norm_bmn():
+    mod = Before_norm_bmn
+    target = _target()
+    ctx = _create_context(mod, target)
+    sch = tir.Schedule(mod, debug_mask="all")
+    sch.enter_postproc()
+    assert ctx.postprocs[0].apply(sch)
+    tvm.ir.assert_structural_equal(sch.mod, After_norm_bmn)
+
+
+if __name__ == "__main__":
+    test_rewrite_cooperative_fetch()
+    test_rewrite_norm_bmn()