From ab025aba1abc60d917522511dfb75da30832054f Mon Sep 17 00:00:00 2001 From: Hongyi Jin <3231950289@qq.com> Date: Mon, 20 Dec 2021 00:17:08 +0800 Subject: [PATCH] [MemHammer] Rewrite Rules (#554) * rewrite rules * format * move files * change type for constraint and output * fix * format * address comments * address comments * fix --- include/tvm/arith/int_set.h | 17 + src/arith/int_set.cc | 12 + src/tir/transforms/memhammer_coalesce.cc | 227 +++++++++ .../memhammer_intermediate_stage.cc | 437 ++++++++++++++++++ src/tir/transforms/memhammer_rewrite_rule.h | 188 ++++++++ .../memhammer_tensorcore_rewrite.cc | 259 +++++++++++ 6 files changed, 1140 insertions(+) create mode 100644 src/tir/transforms/memhammer_coalesce.cc create mode 100644 src/tir/transforms/memhammer_intermediate_stage.cc create mode 100644 src/tir/transforms/memhammer_rewrite_rule.h create mode 100644 src/tir/transforms/memhammer_tensorcore_rewrite.cc diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index f55a0651a8..a9c46db7b5 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -164,6 +164,14 @@ Map ConvertDomMap(const std::unordered_map& * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(PrimExpr e, const Map& dom_map); +/*! + * \brief Same as EvalSet, but takes Map + * + * \param e The expression to be evaluated. + * \param dom_map The domain of each variable. + * \return An integer set that can cover all the possible values of e. + */ +IntSet EvalSet(PrimExpr e, const Map& dom_map); /*! * \brief Same as EvalSet, but takes unordered_map * @@ -172,6 +180,15 @@ IntSet EvalSet(PrimExpr e, const Map& dom_map); * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); +/*! + * \brief Same as EvalSet, but takes Array + * + * \param exprs The expressions to be evaluated. + * \param dom_map The domain of each variable. + * \return An array of integer sets that can cover all the possible values. + */ +Array EvalSet(const Array& exprs, const Map& dom_map); + /*! * \brief Find an symbolic integer set that contains is union over * all the possible conditional values in dom_map. diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index fe3a37f88f..ac414bfcb9 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -762,6 +762,18 @@ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom return EvalSet(e, ConvertDomMap(dom_map)); } +Array EvalSet(const Array& exprs, const Map& dom_map) { + Array result; + result.reserve(exprs.size()); + Analyzer ana; + IntervalSetEvaluator m(&ana, dom_map); + for (const PrimExpr& e : exprs) { + result.push_back(m.Eval(e)); + } + return result; +} + + IntSet EvalSet(Range r, const Map& dom_map) { Analyzer ana; IntervalSetEvaluator m(&ana, dom_map); diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/tir/transforms/memhammer_coalesce.cc new file mode 100644 index 0000000000..4a21d1894a --- /dev/null +++ b/src/tir/transforms/memhammer_coalesce.cc @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../../runtime/thread_storage_scope.h" +#include "memhammer_rewrite_rule.h" +namespace tvm { +namespace tir { +/*! + * \brief Fuse consecutive loops + * \param stmt the outer-most loop + * \return the fused loop + */ +Stmt FuseNestLoops(const Stmt& stmt) { + std::vector loops; + Stmt body = stmt; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + std::string suffix; + int n = loops.size(); + for (int i = 1; i < n; i++) { + suffix += "_" + loops[i]->loop_var->name_hint; + } + suffix += "_fused"; + Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); + Map subst_map; + PrimExpr tot = fused_var; + for (int i = n - 1; i >= 0; i--) { + subst_map.Set(loops[i]->loop_var, floormod(tot, loops[i]->extent)); + tot = floordiv(tot, loops[i]->extent); + } + auto f_substitute = [&](const Var& v) -> Optional { + return subst_map.Get(v).value_or(v); + }; + PrimExpr fused_extent = 1; + for (int i = 0; i < n; i++) { + fused_extent *= loops[i]->extent; + } + Stmt new_stmt = Substitute(body, f_substitute); + new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt); + return new_stmt; +} + +/*! + * \brief a combination of split, bind, vectorize, + * a helper function to perform coalesced load/store + * \param stmt the stmt to do transformation + * \param constraints The constraints, including thread extents, vector bytes, and data bits. + * \return The stmt after transformation + */ +Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { + Stmt body = stmt; + const ForNode* loop = body.as(); + PrimExpr vector_bytes = constraints.vector_bytes; + PrimExpr threadIdx_x = constraints.thread_extent.Get("threadIdx.x").value_or(Integer(1)); + PrimExpr threadIdx_y = constraints.thread_extent.Get("threadIdx.y").value_or(Integer(1)); + PrimExpr threadIdx_z = constraints.thread_extent.Get("threadIdx.z").value_or(Integer(1)); + PrimExpr tot_threads = threadIdx_x * threadIdx_y * threadIdx_z; + PrimExpr data_bits = constraints.data_bits; + PrimExpr vector_len = max(1, vector_bytes * 8 / data_bits); + if (!loop || !is_zero(indexmod(loop->extent, (vector_len * tot_threads)))) { + LOG(FATAL) << "the number of elements must be a multiple of thread num"; + } + PrimExpr outer_loop_extent = indexdiv(loop->extent, tot_threads * vector_len); + Array factors{outer_loop_extent}; + std::vector thread_axis; + // generate thread binding loops + if (!is_one(threadIdx_z)) { + factors.push_back(threadIdx_z); + thread_axis.push_back("threadIdx.z"); + } + if (!is_one(threadIdx_y)) { + factors.push_back(threadIdx_y); + thread_axis.push_back("threadIdx.y"); + } + if (!is_one(threadIdx_x)) { + factors.push_back(threadIdx_x); + thread_axis.push_back("threadIdx.x"); + } + // generate vectorized loop + factors.push_back(vector_len); + int n = factors.size(); + std::vector new_loop_vars; + new_loop_vars.reserve(n); + for (int i = 0; i < n; i++) { + new_loop_vars.push_back(loop->loop_var.copy_with_suffix("_" + std::to_string(i))); + } + + PrimExpr substitute_value = 0; + for (int i = 0; i < n; i++) { + substitute_value *= factors[i]; + substitute_value += new_loop_vars[i]; + } + body = Substitute(loop->body, [&](const Var& v) -> Optional { + if (v.same_as(loop->loop_var)) { + return substitute_value; + } else { + return NullOpt; + } + }); + + For new_loop = For(new_loop_vars.back(), 0, vector_len, ForKind::kVectorized, body); + + for (int i = n - 2; i >= 1; i--) { + new_loop = + For(new_loop_vars[i], 0, factors[i], ForKind::kThreadBinding, std::move(new_loop), + IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1])); + } + + new_loop = For(new_loop_vars[0], 0, outer_loop_extent, ForKind::kSerial, new_loop); + return std::move(new_loop); +} + +Stmt CoalescedAccess::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_fuse = FuseNestLoops(stmt); + Stmt after_split = SplitBindVectorize(after_fuse, constraints); + return after_split; +} + +/*! + * \brief Get the index mapping of a specific stmt. + * The stmt is like: + * for i0: + * ... + * for in: + * A[f(i0, ..., in])] = B[i0, ..., in], + * where f is the index mapping we want to get. + * \param constraints The constraints, including the write region that is required to calculate + * the index mapping + * \return The mapping in the form of j0, ..., jm, where j0, ... jm = f(i0, ..., in) + */ +Array GetMapping(const Stmt& stmt, const ConstraintSet& constraints) { + Stmt body = stmt; + while (const ForNode* loop = body.as()) { + body = loop->body; + } + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + BufferRegion write_region = constraints.write_region; + const Array& write_index = buf_store->indices; + ICHECK(write_region->region.size() == write_index.size() && + write_region->buffer.same_as(buf_store->buffer)); + Array result; + arith::Analyzer analyzer; + for (int i = 0; i < static_cast(write_region->region.size()); i++) { + PrimExpr pattern = analyzer.Simplify(write_index[i] - write_region->region[i]->min); + if (!is_zero(pattern)) { + result.push_back(pattern); + } + } + return result; +} + +Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body = stmt; + Map var_range; + Array loop_vars; + // Step 1. Get index mapping + Array mapping_pattern = GetMapping(stmt, constraints); + while (const ForNode* loop = body.as()) { + var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + loop_vars.push_back(loop->loop_var); + body = loop->body; + } + // Step 2. Get Inverse mapping + arith::Analyzer analyzer; + Array iter_map = + arith::DetectIterMap(mapping_pattern, var_range, Bool(true), true, &analyzer); + CHECK_EQ(iter_map.size(), loop_vars.size()); + Map inverse_mapping = arith::InverseAffineIterMap(iter_map, loop_vars); + // Step 3. Generate new body + BufferRegion read_region = constraints.read_region; + BufferRegion write_region = constraints.write_region; + Array write_index; + Array read_index; + Array new_loop_vars; + Map substitute_map; + // Step 3.1 construct target buffer indices + for (int i = 0, j = 0; i < static_cast(write_region->region.size()); i++) { + if (is_one(write_region->region[i]->extent)) { + write_index.push_back(write_region->region[i]->min); + } else { + Var var = runtime::Downcast(loop_vars[j]).copy_with_suffix("_inverse"); + new_loop_vars.push_back(var); + substitute_map.Set(runtime::Downcast(loop_vars[j++]), var); + write_index.push_back(write_region->region[i]->min + var); + } + } + // Step 3.2 construct source buffer indices + for (int i = 0, j = 0; i < static_cast(read_region->region.size()); i++) { + if (is_one(read_region->region[i]->extent)) { + read_index.push_back(read_region->region[i]->min); + } else { + read_index.push_back( + read_region->region[i]->min + + Substitute(inverse_mapping[Downcast(loop_vars[j++])], substitute_map)); + } + } + BufferLoad new_buf_load = BufferLoad(read_region->buffer, read_index); + BufferStore new_buf_store = BufferStore(write_region->buffer, new_buf_load, write_index); + Stmt ret = new_buf_store; + // Step 3.3 construct loop body + for (int i = static_cast(new_loop_vars.size()) - 1; i >= 0; i--) { + PrimExpr extent = write_region->region[i]->extent; + ret = For(new_loop_vars[i], 0, extent, ForKind::kSerial, std::move(ret)); + } + return ret; +} +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_intermediate_stage.cc b/src/tir/transforms/memhammer_intermediate_stage.cc new file mode 100644 index 0000000000..3503b371bd --- /dev/null +++ b/src/tir/transforms/memhammer_intermediate_stage.cc @@ -0,0 +1,437 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +Stmt CopyLoopChain(const std::vector loops, const Stmt& inner_body, int ith = -1, + Stmt* ith_loop = nullptr) { + Stmt ret = inner_body; + for (int i = static_cast(loops.size() - 1); i >= 0; i--) { + ObjectPtr new_loop = make_object(*loops[i]); + new_loop->body = ret; + ret = For(new_loop); + if (ith == i) { + *ith_loop = ret; + } + } + return ret; +} + +/*! + * \brief lift all the thread binding loops + * \param stmt the top loop + * \return a pair. The first is the transformed stmt. + * The second is the lowest thread binding loop. + */ +std::pair LiftThreadBindingLoops(Stmt stmt) { + std::vector normal_loops; + std::vector thread_binding_loops; + Stmt body = stmt; + while (const ForNode* loop = body.as()) { + if (loop->kind == ForKind::kThreadBinding) { + thread_binding_loops.push_back(loop); + } else { + normal_loops.push_back(loop); + } + body = loop->body; + } + body = CopyLoopChain(normal_loops, body); + For compute_location; + body = CopyLoopChain(thread_binding_loops, body, + static_cast(thread_binding_loops.size()) - 1, &compute_location); + + return std::make_pair(body, compute_location); +} + +/*! + * \brief Analyze the access pattern for buffer rank promotion. + * Rank promotion is a transformation that reshapes the buffer + * but doesn't change its underlying data layout. + * After the reshape, we expect that all dimensions of the access indices + * will be in the form of floormod(floordiv(x, a), b). + * Rank promotion removes strided access, thus enabling further buffer compacting + */ +class IndexPatternFinder : public ExprVisitor { + public: + IndexPatternFinder(const Map& var_range, Array* resulting_index) + : var_range_(var_range), resulting_index_(resulting_index) {} + + /*! + * \brief Calculate the new buffer shape after rank promotion. + * For each dimension of original shape, it will be split into multiple parts. + * The inner array represents the multiple parts of one original dimension, + * and the outer array represents the original dimensions + * For example, original shape [4, 8] may be split into [[2, 2], [2, 4]] + * \param indices The access indices of the buffer + * \param var_range The iter range of the vars in the indices + * \param rewrite_indices The access indices after rank promotion + * \return The new buffer shape after rank promotion. + */ + static Array> getRankPromotedShape(Array indices, + const Map& var_range, + Array* rewrite_indices) { + Map var_dom = AsIntSet(var_range); + Array> new_shape; + for (const PrimExpr& expr : indices) { + IndexPatternFinder extractor(var_range, rewrite_indices); + arith::IntSet intset = arith::EvalSet(expr, var_dom); + extractor.mod_ = intset.max() + 1; + extractor.div_ = 1; + extractor.offset_ = 0; + extractor(expr); + Array access_shape = extractor.access_shape_; + for (int i = static_cast(access_shape.size()) - 1; i >= 1; i--) { + if (!is_zero(floormod(extractor.offset_, access_shape[i]))) { + return {}; + } else { + extractor.offset_ = floordiv(extractor.offset_, access_shape[i]); + } + } + access_shape.Set(0, extractor.offset_ + access_shape[0]); + new_shape.push_back(access_shape); + } + return new_shape; + } + + private: + void VisitExpr_(const VarNode* op) final { + arith::Analyzer analyzer; + PrimExpr extent = var_range_[GetRef(op)]->extent; + PrimExpr access_iter_range = min(mod_, (max(1, floordiv(extent, div_)))); + if (!analyzer.CanProveEqual(1, access_iter_range)) { + access_shape_.push_back(access_iter_range); + resulting_index_->push_back(floormod(floordiv(GetRef(op), div_), mod_)); + } + } + + void VisitExpr_(const FloorDivNode* op) final { + PrimExpr old_div = div_; + div_ *= op->b; + ExprVisitor::VisitExpr_(op); + div_ = old_div; + } + + void VisitExpr_(const FloorModNode* op) final { + PrimExpr old_mod = mod_; + mod_ = max(1, min(floordiv(op->b, div_), mod_)); + ExprVisitor::VisitExpr_(op); + mod_ = old_mod; + } + + void VisitExpr_(const MulNode* op) final { + PrimExpr old_mod = mod_; + PrimExpr old_div = div_; + div_ = max(1, floordiv(div_, op->b)); + mod_ = max(1, floordiv(mod_, floordiv(op->b, floordiv(old_div, div_)))); + ExprVisitor::VisitExpr_(op); + mod_ = old_mod; + div_ = old_div; + } + + void VisitExpr_(const AddNode* op) final { + if (is_const_int(op->b)) { + offset_ += floormod(floordiv(op->b, div_), mod_); + } + ExprVisitor::VisitExpr_(op); + } + + PrimExpr div_; + PrimExpr mod_; + PrimExpr offset_; + Map var_range_; + Array access_shape_; + Array* resulting_index_; +}; + +/*! + * \brief Utilities to perform rank promotion + */ +class RankPromoter : public StmtExprMutator { + public: + /*! + * \brief Flatten the buffer shape like performing inverse rank promotion. + * For example, [[i0, i1], [j0, j1]] to [i0 * i1, j0 * j1] + * \param new_shape The buffer shape in the special form as returned by getRankPromotedShape + * \return The buffer shape after flatten + */ + static Array FlattenNewShape(const Array>& new_shape) { + Array ret; + ret.reserve(new_shape.size()); + for (int i = 0; i < static_cast(new_shape.size()); i++) { + PrimExpr prod = 1; + for (int j = 0; j < static_cast(new_shape[i].size()); j++) { + prod *= new_shape[i][j]; + } + ret.push_back(prod); + } + return ret; + } + /** + * \brief Rewrite the index given the shape after rank promotion + * \param indices The original indices + * \param new_shape The buffer shape after rank promotion + * \return The new indices + */ + static Array RewriteIndex(const Array& indices, + const Array>& new_shape) { + Array new_indices; + ICHECK_EQ(indices.size(), new_shape.size()); + for (int i = 0; i < static_cast(indices.size()); i++) { + PrimExpr index = indices[i]; + // The indices transformed from one original dimension + Array index_dim(new_shape[i].size(), 0); + for (int j = static_cast(new_shape[i].size()) - 1; j >= 0; j--) { + index_dim.Set(j, floormod(index, new_shape[i][j])); + index = floordiv(index, new_shape[i][j]); + } + for (int j = 0; j < static_cast(new_shape[i].size()); j++) { + new_indices.push_back(index_dim[j]); + } + } + return new_indices; + } + /*! + * \brief Rewrite the index after buffer flattening + * \param indices The original indices + * \param new_shape The shape before buffer flattening + * \return The indices after buffer flattening + */ + static Array RewriteBackIndex(const Array& indices, + const Array>& new_shape) { + Array new_indices; + int offset = 0; + for (int i = 0; i < static_cast(new_shape.size()); i++) { + PrimExpr index = 0; + for (int j = 0; j < static_cast(new_shape[i].size()); j++) { + index *= new_shape[i][j]; + index += indices[offset + j]; + } + new_indices.push_back(index); + offset += new_shape[i].size(); + } + return new_indices; + } + RankPromoter(const Buffer& src, const Buffer& dst, const Array>& new_shape, + const Array>& relaxed_new_shape, const Array& relaxed_region) + : src_(src), + dst_(dst), + new_shape_(new_shape), + relaxed_new_shape_(relaxed_new_shape), + relaxed_region_(relaxed_region) {} + + static Stmt RewriteBody(Stmt stmt, const Buffer& src, const Buffer& dst, + const Array>& new_shape, + const Array>& relaxed_new_shape, + const Array& relaxed_region) { + RankPromoter promoter(src, dst, new_shape, relaxed_new_shape, relaxed_region); + return promoter(stmt); + } + + private: + Stmt VisitStmt_(const BufferStoreNode* _store) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); + if (store->buffer.same_as(src_)) { + ObjectPtr new_store = make_object(*store.get()); + new_store->buffer = dst_; + new_store->indices = ConvertIndices(new_store->indices); + return BufferStore(new_store); + } + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + if (load->buffer.same_as(src_)) { + ObjectPtr new_load = make_object(*load.get()); + new_load->buffer = dst_; + new_load->indices = ConvertIndices(new_load->indices); + return BufferLoad(new_load); + } + return std::move(load); + } + + /*! + * \brief Rewrite the indices after performing buffer rank promotion + + * buffer compacting + buffer flattening. + * \param indices The origina indices + * \return The indices after these transformations + */ + Array ConvertIndices(const Array& indices) { + Array rewrite_indices = RewriteIndex(indices, new_shape_); + arith::Analyzer analyzer; + for (int i = 0; i < static_cast(rewrite_indices.size()); i++) { + rewrite_indices.Set(i, analyzer.Simplify(rewrite_indices[i] - relaxed_region_[i]->min)); + } + return RewriteBackIndex(rewrite_indices, relaxed_new_shape_); + } + + const Buffer& src_; + const Buffer& dst_; + Array> new_shape_; + Array> relaxed_new_shape_; + Array relaxed_region_; +}; + +/*! + * \brief Insert a cache stage to the compute location + * \param stmt the stmt + * \param is_write_cache whether to write a read cache or write cache + * \param storage_scope the storage scope of the new cache + * \param compute_location the compute location. + * \param outer_loops the outer loops of this stmt + * \param alloc_buffer the new cache block + * \return a pair. The first is the stmt after transformation. + * The second is the SeqStmt that contains 2 stages (one original and another inserted). + */ +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, + For compute_location, const Array& outer_loops, + Buffer* alloc_buffer) { + Stmt body = stmt; + std::vector loops; + bool need_relax = !compute_location.defined(); + Map relax_var_range; + Map all_var_range; + PrimExpr vector_bytes = -1; + // Step 1. Perform rank promotion on the buffer access, turning a strided-changing dimension into + // several contiguous-changing dimensions + // Step 1.1 collect loop var range for rank promotion + while (const ForNode* loop = body.as()) { + if (need_relax) { + relax_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } else { + loops.push_back(loop); + } + all_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + if (loop == compute_location.get()) { + need_relax = true; + } + if (loop->kind == ForKind::kVectorized) { + vector_bytes = loop->extent; + } + body = loop->body; + } + for (const For& loop : outer_loops) { + if (loop->kind == ForKind::kThreadBinding) { + const String& thread_tag = loop->thread_binding.value()->thread_tag; + if (CanRelaxStorageUnderThread(runtime::StorageScope::Create(storage_scope), + runtime::ThreadScope::Create(thread_tag))) { + relax_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + } + all_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode); + Buffer orig_buffer = is_write_cache ? buf_store->buffer : buf_load->buffer; + Array indices = is_write_cache ? buf_store->indices : buf_load->indices; + // Step 1.2 get the new shape and new access indices after rank promotion + Array rewrite_indices; + Array> new_shape = + IndexPatternFinder::getRankPromotedShape(indices, all_var_range, &rewrite_indices); + // Step 2. relax the access region after rank promotion + Region relaxed_region; + auto relax_var_intset = AsIntSet(relax_var_range); + arith::Analyzer analyzer; + analyzer.Bind(all_var_range); + for (const PrimExpr& index : rewrite_indices) { + auto int_set = arith::EvalSet(index, relax_var_intset); + relaxed_region.push_back( + Range::FromMinExtent(int_set.min(), analyzer.Simplify(int_set.max() - int_set.min() + 1))); + } + // Step 3. generate the data copy bodies + // preparation work + Array new_loop_vars; + Array orig_buf_indices, new_buf_indices; + Array> relaxed_new_shape; + for (int i = 0; i < static_cast(relaxed_region.size()); i++) { + Var new_loop_var = Var("ax" + std::to_string(i)); + new_loop_vars.push_back(new_loop_var); + orig_buf_indices.push_back(relaxed_region[i]->min + new_loop_var); + new_buf_indices.push_back(new_loop_var); + } + relaxed_new_shape.reserve(new_shape.size()); + for (int i = 0, ct = 0; i < static_cast(new_shape.size()); i++) { + Array layer; + for (int j = 0; j < static_cast(new_shape[i].size()); j++, ct++) { + layer.push_back(relaxed_region[ct]->extent); + } + relaxed_new_shape.push_back(layer); + } + // Step 3.1 create a buffer for the cache + Buffer new_buffer = WithScope(orig_buffer, storage_scope); + BufferNode* buffer_ptr = new_buffer.CopyOnWrite(); + buffer_ptr->shape = RankPromoter::FlattenNewShape(relaxed_new_shape); + *alloc_buffer = new_buffer; + Array real_orig_buf_indices = + RankPromoter::RewriteBackIndex(orig_buf_indices, new_shape); + Array real_new_buf_indices = + RankPromoter::RewriteBackIndex(new_buf_indices, relaxed_new_shape); + // Step 3.2 generate a body that writes to the cache + Stmt generate_body = is_write_cache + ? BufferStore(orig_buffer, BufferLoad(new_buffer, real_new_buf_indices), + real_orig_buf_indices) + : BufferStore(new_buffer, BufferLoad(orig_buffer, real_orig_buf_indices), + real_new_buf_indices); + for (int i = static_cast(relaxed_region.size()) - 1; i >= 0; i--) { + if (i == static_cast(relaxed_region.size()) - 1 && !is_const_int(vector_bytes, -1)) { + ICHECK(analyzer.CanProve(vector_bytes == relaxed_region[i]->extent)); + generate_body = + For(new_loop_vars[i], 0, relaxed_region[i]->extent, ForKind::kVectorized, generate_body); + } else { + generate_body = + For(new_loop_vars[i], 0, relaxed_region[i]->extent, ForKind::kSerial, generate_body); + } + } + // Step 3.3 rewrite the original body to load from cache + Stmt rewrite_body; + if (compute_location.defined()) { + rewrite_body = compute_location->body; + } else { + rewrite_body = stmt; + } + rewrite_body = RankPromoter::RewriteBody(rewrite_body, orig_buffer, new_buffer, new_shape, + relaxed_new_shape, relaxed_region); + SeqStmt insert_location; + if (is_write_cache) { + generate_body = insert_location = SeqStmt({rewrite_body, generate_body}); + } else { + generate_body = insert_location = SeqStmt({generate_body, rewrite_body}); + } + generate_body = CopyLoopChain(loops, generate_body); + return std::make_pair(generate_body, insert_location); +} + +Stmt CreateLocalStage::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body; + For compute_location; + std::tie(body, compute_location) = LiftThreadBindingLoops(std::move(stmt)); + Buffer cache_buffer; + Stmt after_caching = InsertCacheStage(body, false, "local", compute_location, + constraints.outer_loops, &cache_buffer) + .first; + output->alloc_buffer.push_back(cache_buffer); + return after_caching; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_rewrite_rule.h b/src/tir/transforms/memhammer_rewrite_rule.h new file mode 100644 index 0000000000..c20afecb40 --- /dev/null +++ b/src/tir/transforms/memhammer_rewrite_rule.h @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../../../include/tvm/arith/iter_affine_map.h" +#include "../../../include/tvm/runtime/registry.h" +#include "../../../include/tvm/target/target.h" +#include "../../../include/tvm/tir/expr.h" +#include "../../../include/tvm/tir/op.h" +#include "../../../include/tvm/tir/stmt_functor.h" +#include "../../../include/tvm/tir/transform.h" +#include "../schedule/utils.h" + +namespace tvm { +namespace tir { + +/*! \brief The set containing all possible constraints of a data copy*/ +struct ConstraintSet { + /*! \brief The extents of the thread binding loops*/ + Map thread_extent; + /*! \brief The outer loops surrounding the data copy*/ + Array outer_loops; + /*! \brief The read region of the data copy*/ + BufferRegion read_region; + /*! \brief The write region of the data copy*/ + BufferRegion write_region; + /*! \brief The dtype size in bits*/ + Integer data_bits; + /*! \brief Whether to insert a local stage in the data copy*/ + Integer add_local_stage = Integer(0); + /*! \brief The vectorization length in bytes*/ + Integer vector_bytes = 1; +}; + +/*! \brief The set containing all possible outpus of a rewrite rule*/ +struct OutputSet { + /*! \brief New buffers allocated after rewrite*/ + Array alloc_buffer; + /*! \brief The minimal padding size of a buffer in base 2 logarithm*/ + Map padding_min; +}; + +/*! + * \brief Rules to rewrite a data copy. + */ +class RewriteRule { + private: + /*! + * \brief Rewrite the stmt under certain constraints + * \param stmt The stmt + * \param constraints The constraints of the rewrite + * \param output Some additional information that the rewrite rule produces. (including the new + * buffer to be allocated, etc.) + * \return the stmt after rewrite + */ + virtual Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const = 0; + /*! + * \brief Whether the rewrite rule can be applied to the stmt under certain constraints + * \param stmt The stmt + * \param constraints The constraints of the rewrite + * \return A boolean flag indicating whether the rule can be applied + */ + virtual bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const { return true; } + + public: + inline Stmt Apply(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { + if (CanApply(stmt, constraints)) { + return Rewrite(stmt, constraints, output); + } else { + return stmt; + } + } +}; + +inline bool IsCopyBetweenScope(const Buffer& src_buffer, const Buffer& tgt_buffer, + runtime::StorageRank src_rank, runtime::StorageRank tgt_rank) { + runtime::StorageScope src_scope = runtime::StorageScope::Create(src_buffer.scope()); + runtime::StorageScope tgt_scope = runtime::StorageScope::Create(tgt_buffer.scope()); + return src_scope.rank == src_rank && tgt_scope.rank == tgt_rank; +} + +/*! + * \brief Coalesce and vectorize memory access. + */ +class CoalescedAccess : public RewriteRule { + public: + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kGlobal, + runtime::StorageRank::kShared) || + IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Transform from A[f(i,j)] = B[i,j] to A[i,j] = B[f^{-1}(i,j)] + */ +class InverseMapping : public RewriteRule { + public: + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Create a local stage when loading from global memory to shared memory. + */ +class CreateLocalStage : public RewriteRule { + public: + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kGlobal, + runtime::StorageRank::kShared) && + is_one(constraints.add_local_stage); + } +}; + +/*! + * \brief Add a cache stage in shared memory. Perform tensor core rewrite for wmma->shared, and + * perform coalescing and vectorizing for shared->global. + */ +class WmmaToGlobal : public RewriteRule { + public: + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kWMMAAccumulator, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Rewrite shared->wmma data copy with load_matrix_sync + */ +class SharedToWmma : public RewriteRule { + public: + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kWMMAMatrixA) || + IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kWMMAMatrixB); + } +}; + +/*! + * \brief Rewrite wmma->shared data copy with store_matrix_sync + */ +class WmmaToShared : public RewriteRule { + public: + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kWMMAAccumulator, + runtime::StorageRank::kShared); + } +}; + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc new file mode 100644 index 0000000000..ea19fbeedd --- /dev/null +++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "memhammer_rewrite_rule.h" +namespace tvm { +namespace tir { +/*! + * \brief Tile the 2 innermost loops to extent=16. This helps further tensor core rewrite. + * \param stmt The stmt + * \return A pair. The first is the stmt after transformation. + * The second is the compute location where we may add write cache. + */ +std::pair TileWmmaBlock(Stmt stmt) { + Stmt body = stmt; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + arith::Analyzer analyzer; + PrimExpr extent_last1 = loops[loops.size() - 1]->extent, + extent_last2 = loops[loops.size() - 2]->extent; + + if (!analyzer.CanProve(floormod(extent_last1, 16) == 0) || + !analyzer.CanProve(floormod(extent_last2, 16) == 0)) { + return std::make_pair(stmt, For()); + } + std::vector new_loop_vars; + Array factor{floordiv(extent_last2, 16), floordiv(extent_last1, 16), 16, 16}; + new_loop_vars.reserve(4); + for (int i = 0; i < 2; i++) { + new_loop_vars.push_back(loops[loops.size() - 2]->loop_var.copy_with_suffix(std::to_string(i))); + new_loop_vars.push_back(loops[loops.size() - 1]->loop_var.copy_with_suffix(std::to_string(i))); + } + Map substitue_value; + substitue_value.Set(loops[loops.size() - 2]->loop_var, new_loop_vars[0] * 16 + new_loop_vars[2]); + substitue_value.Set(loops[loops.size() - 1]->loop_var, new_loop_vars[1] * 16 + new_loop_vars[3]); + body = Substitute(body, substitue_value); + for (int i = static_cast(new_loop_vars.size()) - 1; i >= 0; i--) { + body = For(new_loop_vars[i], 0, factor[i], ForKind::kSerial, body); + } + For compute_location = Downcast(body); + for (int i = static_cast(loops.size()) - 3; i >= 0; i--) { + body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, body, + loops[i]->thread_binding, loops[i]->annotations); + } + return std::make_pair(body, compute_location); +} + +/*! + * \brief Rewrite the data copy that stores to wmma fragment with wmma::load_matrix_sync + * \param stmt The stmt to rewrite + * \return The stmt after rewrite + */ +Stmt RewriteWmmaLoad(Stmt stmt) { + Array match_buffers; + Stmt body = stmt; + Map var_range; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + for (int i = 1; i <= 2; i++) { + const ForNode* loop = loops[loops.size() - i]; + var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode); + Buffer src_buffer = buf_load->buffer; + Buffer tgt_buffer = buf_store->buffer; + + DataType dtype = DataType::Float(16); + Var new_src_var("src", PointerType(PrimType(dtype), src_buffer.scope())); + Type int32 = PrimType(DataType::Int(32)); + Buffer new_src_buffer(new_src_var, dtype, {Integer(16), Integer(16)}, + {Var("s1", int32), Var("s0", int32)}, Var("src_elem_offset", int32), "src", + 128, 16, kDefault); + auto read_int_set = arith::EvalSet(buf_load->indices, AsIntSet(var_range)); + Array read_region; + for (int i = 0; i < static_cast(read_int_set.size()); i++) { + read_region.push_back( + read_int_set[i].CoverRange(Range::FromMinExtent(0, src_buffer->shape[i]))); + } + match_buffers.push_back(MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region))); + Var new_tgt_var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())); + Buffer new_tgt_buffer(new_tgt_var, dtype, {Integer(16), Integer(16)}, {}, + Var("tgt_elem_offset", int32), "tgt", 128, 16, kDefault); + auto write_int_set = arith::EvalSet(buf_store->indices, AsIntSet(var_range)); + Array write_region; + for (int i = 0; i < static_cast(write_int_set.size()); i++) { + write_region.push_back( + write_int_set[i].CoverRange(Range::FromMinExtent(0, tgt_buffer->shape[i]))); + } + match_buffers.push_back( + MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region))); + + PrimExpr frag_index = floordiv(new_tgt_buffer->elem_offset, 256) + + floordiv(floormod(new_tgt_buffer->elem_offset, 256), 16); + + auto new_src_pointer = Call( + runtime::DataType::Handle(), builtin::tvm_access_ptr(), + {TypeAnnotation(new_src_buffer->dtype), new_src_buffer->data, new_src_buffer->elem_offset, + new_src_buffer->strides[new_src_buffer->strides.size() - 2] * 16, 1}); + + Stmt wmma_body = Evaluate( + Call(runtime::DataType::Handle(), builtin::tvm_load_matrix_sync(), + {new_tgt_buffer->data, 16, 16, 16, frag_index, new_src_pointer, + new_src_buffer->strides[new_src_buffer->strides.size() - 2], StringImm("row_major")})); + wmma_body = BlockRealize( + {}, Bool(true), Block({}, {}, {}, "wmma_load", wmma_body, NullOpt, {}, match_buffers, {})); + for (int i = static_cast(loops.size()) - 3; i >= 0; i--) { + wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, wmma_body, + loops[i]->thread_binding, loops[i]->annotations); + } + return wmma_body; +} + +/*! + * \brief Rewrite the data copy that loads from wmma fragment with wmma::store_matrix_sync + * \param stmt The stmt to rewrite + * \return The stmt after rewrite + */ +Stmt RewriteWmmaStore(Stmt stmt) { + Array match_buffers; + Stmt body = stmt; + Map var_range; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + for (int i = 1; i <= 2; i++) { + const ForNode* loop = loops[loops.size() - i]; + var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode); + Buffer src_buffer = buf_load->buffer; + Buffer tgt_buffer = buf_store->buffer; + + DataType dtype = DataType::Float(32); + Type int32 = PrimType(DataType::Int(32)); + Var new_src_var("src", PointerType(PrimType(dtype), src_buffer.scope())); + Buffer new_src_buffer(new_src_var, dtype, {Integer(16), Integer(16)}, {}, + Var("src_elem_offset", int32), "src", 128, 16, kDefault); + auto read_int_set = arith::EvalSet(buf_load->indices, AsIntSet(var_range)); + Array read_region; + for (int i = 0; i < static_cast(read_int_set.size()); i++) { + read_region.push_back( + read_int_set[i].CoverRange(Range::FromMinExtent(0, src_buffer->shape[i]))); + } + match_buffers.push_back(MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region))); + Var new_tgt_var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())); + Buffer new_tgt_buffer(new_tgt_var, dtype, {Integer(16), Integer(16)}, + {Var("s1", int32), Var("s0", int32)}, Var("tgt_elem_offset", int32), "tgt", + 128, 16, kDefault); + auto write_int_set = arith::EvalSet(buf_store->indices, AsIntSet(var_range)); + Array write_region; + for (int i = 0; i < static_cast(write_int_set.size()); i++) { + write_region.push_back( + write_int_set[i].CoverRange(Range::FromMinExtent(0, tgt_buffer->shape[i]))); + } + match_buffers.push_back( + MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region))); + + PrimExpr frag_index = floordiv(new_src_buffer->elem_offset, 256) + + floordiv(floormod(new_src_buffer->elem_offset, 256), 16); + + auto new_tgt_pointer = Call(runtime::DataType::Handle(), builtin::tvm_access_ptr(), + {TypeAnnotation(new_tgt_buffer->dtype), new_tgt_buffer->data, + new_tgt_buffer->elem_offset, new_tgt_buffer->strides[0] * 16, 2}); + + Stmt wmma_body = Evaluate(Call(runtime::DataType::Handle(), builtin::tvm_store_matrix_sync(), + {new_src_buffer->data, 16, 16, 16, frag_index, new_tgt_pointer, + new_tgt_buffer->strides[0], StringImm("row_major")})); + wmma_body = BlockRealize( + {}, Bool(true), Block({}, {}, {}, "wmma_store", wmma_body, NullOpt, {}, match_buffers, {})); + for (int i = static_cast(loops.size()) - 3; i >= 0; i--) { + wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, wmma_body, + loops[i]->thread_binding, loops[i]->annotations); + } + return wmma_body; +} + +Stmt SharedToWmma::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_tiling = TileWmmaBlock(stmt).first; + output->padding_min.Set(constraints.read_region->buffer, 3); + return RewriteWmmaLoad(after_tiling); +} + +Stmt WmmaToShared::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_tiling = TileWmmaBlock(stmt).first; + output->padding_min.Set(constraints.write_region->buffer, 3); + return RewriteWmmaStore(after_tiling); +} + +class WmmaToGlobalRewriter : public StmtExprMutator { + public: + WmmaToGlobalRewriter(const SeqStmtNode* tgt_stmt, const ConstraintSet& constraints) + : tgt_stmt_(tgt_stmt), constraints_(constraints) {} + + private: + Stmt VisitStmt_(const SeqStmtNode* op) final { + if (op == tgt_stmt_) { + ICHECK_EQ(op->seq.size(), 2); + Stmt wmma_to_shared = RewriteWmmaStore(op->seq[0]); + Stmt shared_to_global = CoalescedAccess().Rewrite(op->seq[1], constraints_, nullptr); + return SeqStmt({wmma_to_shared, shared_to_global}); + } else { + return StmtMutator::VisitStmt_(op); + } + } + + const SeqStmtNode* tgt_stmt_; + const ConstraintSet& constraints_; +}; + +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, + For compute_location, const Array& outer_loops, + Buffer* alloc_buffer); + +Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body; + For compute_location; + std::tie(body, compute_location) = TileWmmaBlock(stmt); + SeqStmt seq; + Array outer_loops = constraints.outer_loops; + Buffer cache_buffer; + // Step 1. add a shared memory cache + std::tie(body, seq) = + InsertCacheStage(body, true, "shared.dyn", compute_location, outer_loops, &cache_buffer); + output->alloc_buffer.push_back(cache_buffer); + output->padding_min.Set(cache_buffer, 3); + // Step 2. do coalesced rewrite and tensor core rewrite respectively for 2 parts + WmmaToGlobalRewriter rewriter(seq.get(), constraints); + return rewriter(body); +} + +} // namespace tir +} // namespace tvm