From 49cb371214ce6b1a41a198f606a93afe703c7ce7 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 26 Nov 2021 10:56:34 -0800 Subject: [PATCH 1/5] ... --- include/tvm/meta_schedule/feature_extractor.h | 14 +- .../feature_extractor/per_store_feature.py | 46 + .../feature_extractor/per_store_feature.cc | 1180 +++++++++++++++++ src/meta_schedule/utils.h | 1 + src/support/nd_int_set.h | 23 + src/tir/schedule/utils.h | 30 +- .../transforms/convert_blocks_to_opaque.cc | 2 +- 7 files changed, 1285 insertions(+), 11 deletions(-) create mode 100644 python/tvm/meta_schedule/feature_extractor/per_store_feature.py create mode 100644 src/meta_schedule/feature_extractor/per_store_feature.cc diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h index 30e2f0fe62..ee5d94c13c 100644 --- a/include/tvm/meta_schedule/feature_extractor.h +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -91,6 +91,18 @@ class PyFeatureExtractorNode : public FeatureExtractorNode { */ class FeatureExtractor : public runtime::ObjectRef { public: + /*! + * \brief Create a feature extractor that extracts features from each BufferStore + * \param buffers_per_store The number of buffers in each BufferStore; Pad or truncate if + * necessary. + * \param arith_intensity_curve_num_samples The number of samples used in the arithmetic intensity + * curve. + * \param cache_line_bytes The number of bytes in a cache line. + * \return The feature extractor created. + */ + TVM_DLL static FeatureExtractor PerStoreFeature(int buffers_per_store = 5, + int arith_intensity_curve_num_samples = 10, + int cache_line_bytes = 64); /*! * \brief Create a feature extractor with customized methods on the python-side. * \param f_extract_from The packed function of `ExtractFrom`. @@ -98,7 +110,7 @@ class FeatureExtractor : public runtime::ObjectRef { * \return The feature extractor created. */ TVM_DLL static FeatureExtractor PyFeatureExtractor( - PyFeatureExtractorNode::FExtractFrom f_extract_from, // + PyFeatureExtractorNode::FExtractFrom f_extract_from, PyFeatureExtractorNode::FAsString f_as_string); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FeatureExtractor, ObjectRef, FeatureExtractorNode); }; diff --git a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py new file mode 100644 index 0000000000..a5283d2fce --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""We extract one feature vector per BufferStoreNode statement in a TIR Stmt, +so we call this feature as "per-store" feature. +""" +from tvm._ffi import register_object + +from .. 
import _ffi_api +from .feature_extractor import FeatureExtractor + + +@register_object("meta_schedule.PerStoreFeature") +class PerStoreFeature(FeatureExtractor): + """PerStoreFeature extracts one feature vector per BufferStoreNode""" + + buffers_per_store: int + arith_intensity_curve_num_samples: int # pylint: disable=invalid-name + cache_line_bytes: int + feature_vector_length: int + + def __init__( + self, + buffers_per_store: int = 5, + arith_intensity_curve_num_samples: int = 10, + cache_line_bytes: int = 64, + ): + self.__init_handle_by_constructor__( + _ffi_api.FeatureExtractorPerStoreFeature, # type: ignore # pylint: disable=no-member + buffers_per_store, + arith_intensity_curve_num_samples, + cache_line_bytes, + ) diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc new file mode 100644 index 0000000000..cc7a1a128d --- /dev/null +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -0,0 +1,1180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+#include <tvm/tir/transform.h>
+
+#include <cmath>
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+using support::NDIntSet;
+
+using MultiIndex = std::vector<PrimExpr>;
+using IntVec = std::vector<int64_t>;
+using ForVec = std::vector<const ForNode*>;
+
+template <class T>
+using ForBufferMap = std::unordered_map<const ForNode*, std::unordered_map<const BufferNode*, T>>;
+
+inline double slog(double x) {
+  if (x < 0) {
+    x = -x;
+  }
+  return std::log2(x + 1);
+}
+
+namespace utils {
+
+int64_t GetPragmaAutoUnroll(const ForNode* loop) {
+  if (Optional<Integer> auto_unroll = GetAnn<Integer>(loop, tir::attr::pragma_auto_unroll_max_step)) {
+    return auto_unroll.value()->value;
+  }
+  return -1;
+}
+
+int64_t ProdLoopExtent(const ForVec& loops) {
+  int64_t prod = 1;
+  for (const ForNode* loop : loops) {
+    if (const int64_t* extent = GetLoopIntExtent(loop)) {
+      prod *= *extent;
+    }
+  }
+  return prod;
+}
+
+int64_t FirstLoopExtent(const ForVec& loops, int64_t default_value) {
+  if (!loops.empty()) {
+    if (const int64_t* extent = GetLoopIntExtent(loops[0])) {
+      return *extent;
+    }
+  }
+  return default_value;
+}
+
+IntVec UnionAndGetRelaxedSize(const std::vector<MultiIndex>& multi_indices, int64_t* numel,
+                              arith::Analyzer* analyzer) {
+  if (multi_indices.empty()) {
+    return {};
+  }
+  int n_indices = multi_indices.size();
+  int ndim = multi_indices[0].size();
+  IntVec access_shape(ndim, 0);
+  for (int i = 0; i < ndim; ++i) {
+    int64_t minimum = arith::ConstIntBound::kPosInf;
+    int64_t maximum = arith::ConstIntBound::kNegInf;
+    for (int j = 0; j < n_indices; ++j) {
+      arith::ConstIntBound bound = analyzer->const_int_bound(multi_indices[j][i]);
+      minimum = std::min(minimum, bound->min_value);
+      maximum = std::max(maximum, bound->max_value);
+    }
+    *numel *= maximum - minimum + 1;
+    access_shape[i] = maximum - minimum + 1;
+  }
+  return access_shape;
+}
+
+int64_t GetVarStride(const std::vector<MultiIndex>& multi_indices, const IntVec& buffer_stride,
+                     const Var& var) {
+  class CoefficientExtractor : private ExprVisitor {
+   public:
+    static int64_t Extract(const PrimExpr& expr, const Var& var) {
+      CoefficientExtractor extractor(var);
+      extractor.VisitExpr(expr);
+      return (extractor.visited_var && !extractor.visited_mul && !extractor.visited_add)
+                 ? 1
+                 : (extractor.visited_var ? extractor.stride : 0);
+    }
+
+   private:
+    explicit CoefficientExtractor(const Var& var)
+        : var(var), stride(0), visited_var(false), visited_add(false), visited_mul(false) {}
+
+    void VisitExpr_(const MulNode* node) override {
+      ExprVisitor::VisitExpr_(node);
+      if (visited_var && !visited_add) {
+        if (const auto* a = node->a.as<IntImmNode>()) {
+          visited_mul = true;
+          stride = a->value;
+        } else if (const auto* b = node->b.as<IntImmNode>()) {
+          visited_mul = true;
+          stride = b->value;
+        }
+      }
+    }
+
+    void VisitExpr_(const AddNode* node) override {
+      ExprVisitor::VisitExpr_(node);
+      if (visited_var && !visited_mul) {
+        visited_add = true;
+        stride = 1;
+      }
+    }
+
+    void VisitExpr_(const VarNode* node) override {
+      if (node == var.get()) {
+        visited_var = true;
+        stride = 2;
+      }
+    }
+
+    const Var& var;
+    int64_t stride;
+    bool visited_var;
+    bool visited_add;
+    bool visited_mul;
+  };
+
+  constexpr int64_t kNotFound = std::numeric_limits<int64_t>::max();
+  int ndim = buffer_stride.size();
+  // Calculate the min stride possible
+  int64_t result = kNotFound;
+  for (const MultiIndex& multi_index : multi_indices) {
+    ICHECK_EQ(multi_index.size(), buffer_stride.size());
+    // Find the rightmost dimension that contains the given variable
+    for (int i = ndim - 1; i >= 0; --i) {
+      int64_t coef = CoefficientExtractor::Extract(multi_index[i], var);
+      if (coef != 0) {
+        result = std::min(result, std::abs(coef) * buffer_stride[i]);
+        break;
+      }
+    }
+  }
+  return (result == kNotFound) ? 0 : result;
+}
+
+runtime::NDArray AsNDArray(const std::vector<std::vector<double>>& src) {
+  ICHECK(!src.empty());
+  int n = src.size();
+  int m = src[0].size();
+  runtime::NDArray tgt = runtime::NDArray::Empty(
+      /*shape=*/{n, m},
+      /*dtype=*/DLDataType{kDLFloat, 64, 1},  //
+      /*ctx=*/DLDevice{kDLCPU, 0});
+  double* data = static_cast<double*>(tgt->data);
+  for (const std::vector<double>& row : src) {
+    for (double v : row) {
+      *data++ = v;
+    }
+  }
+  return tgt;
+}
+
+} // namespace utils
+
+namespace transform {
+
+Pass SimplifyConstMatrix() {
+  class Simplifier : private StmtExprMutator {
+   public:
+    static Stmt Run(Stmt stmt) { return Simplifier()(std::move(stmt)); }
+
+   private:
+    PrimExpr VisitExpr_(const SelectNode* node) { return make_const(node->dtype, 1.0); }
+  };
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    PrimFuncNode* n = f.CopyOnWrite();
+    n->body = Simplifier::Run(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.SimplifyConstMatrix", {});
+}
+
+Sequential PassListForPerStoreFeature() {
+  return Sequential({
+      tir::transform::SimplifyConstMatrix(),
+      tir::transform::LowerCrossThreadReduction(),
+      tir::transform::LowerInitBlock(),
+      tir::transform::PlanAndUpdateBufferAllocationLocation(),
+      tir::transform::ConvertBlocksToOpaque(),
+      tir::transform::UnifyThreadBinding(),
+      tir::transform::CompactBufferAllocation(),
+      tir::transform::LowerMatchBuffer(),
+  });
+}
+
+} // namespace transform
+
+struct LoopNest {
+  int64_t prod = 1;
+  ForVec loops;
+  IntVec auto_unroll;
+  ForVec parallel;
+  ForVec vectorize;
+  ForVec unroll;
+  ForVec blockIdx_x;
+  ForVec blockIdx_y;
+  ForVec blockIdx_z;
+  ForVec threadIdx_x;
+  ForVec threadIdx_y;
+  ForVec threadIdx_z;
+  ForVec vthread;
+
+  ForVec* Push(const ForNode* loop, int64_t auto_unroll_attr) {
+    if (const int64_t* extent = GetLoopIntExtent(loop)) {
+      this->prod *= *extent;
+    }
+    this->loops.push_back(loop);
+    if (auto_unroll_attr > 0) {
+      this->auto_unroll.push_back(auto_unroll_attr);
+    }
+    ForVec* ref_loops = nullptr;
+    if (loop->kind == ForKind::kParallel) {
+      ref_loops = &parallel;
+    } else if
(loop->kind == ForKind::kVectorized) { + ref_loops = &vectorize; + } else if (loop->kind == ForKind::kUnrolled) { + ref_loops = &unroll; + } else if (loop->kind == ForKind::kThreadBinding) { + std::string thread_tag = loop->thread_binding.value()->thread_tag; + if (thread_tag == "blockIdx.x") { + ref_loops = &blockIdx_x; + } else if (thread_tag == "blockIdx.y") { + ref_loops = &blockIdx_y; + } else if (thread_tag == "blockIdx.z") { + ref_loops = &blockIdx_z; + } else if (thread_tag == "threadIdx.x") { + ref_loops = &threadIdx_x; + } else if (thread_tag == "threadIdx.y") { + ref_loops = &threadIdx_y; + } else if (thread_tag == "threadIdx.z") { + ref_loops = &threadIdx_z; + } else if (support::StartsWith(thread_tag, "vthread")) { + ref_loops = &vthread; + } else { + LOG(FATAL) << "ValueError: Unable to recognize thread tag: " << thread_tag; + } + } + if (ref_loops != nullptr) { + ref_loops->push_back(loop); + } + return ref_loops; + } + + void Pop(const ForNode* loop, ForVec* ref_loops, int auto_unroll_attr) { + if (ref_loops) { + ref_loops->pop_back(); + } + if (auto_unroll_attr > 0) { + this->auto_unroll.pop_back(); + } + if (const int64_t* extent = GetLoopIntExtent(loop)) { + this->prod /= *extent; + } + } +}; + +/****** Group 1: Computation related features ******/ + +namespace group1 { + +struct Feature { + struct ArithOps { + // Float-point arithmetic features + int64_t float_mad = 0; // The number of float MAD (Multiply–add) ops + int64_t float_add_sub = 0; // The number of float add and sub ops + int64_t float_mul = 0; // The number of float multiply ops + int64_t float_div_mod = 0; // The number of float div and mod ops + int64_t float_cmp = 0; // The number of float comparison ops + int64_t float_math_func = 0; // The number of float math func calls + int64_t float_other_func = 0; // The number of other float func calls + // Integer arithmetic features + int64_t int_mad = 0; // The number of integer MAD (Multiply–add) ops + int64_t int_add_sub = 0; // The number of integer add and sub ops + int64_t int_mul = 0; // The number of integer multiply ops + int64_t int_div_mod = 0; // The number of integer div and mod ops + int64_t int_cmp = 0; // The number of integer comparison ops + int64_t int_math_func = 0; // The number of integer math func calls + int64_t int_other_func = 0; // The number of other integer func calls + // Other arithmetic features + int64_t bool_op = 0; // The number of bool ops + int64_t select_op = 0; // The number of select ops + + static constexpr int64_t kCount = 16; + + ArithOps() = default; + ArithOps(const BufferStoreNode* store, int64_t prod_loop_extent); + + void Export(std::vector* v) const { + double vs[] = { + slog(float_mad), slog(float_add_sub), slog(float_mul), slog(float_div_mod), + slog(float_cmp), slog(float_math_func), slog(float_other_func), // + slog(int_mad), slog(int_add_sub), slog(int_mul), slog(int_div_mod), + slog(int_cmp), slog(int_math_func), slog(int_other_func), // + slog(bool_op), slog(select_op), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } + }; + + struct ForKindFeature { + enum class Pos : int { + kPosNone = 0, // Does not have this kind of annotation + kPosInnerSpatial = 1, // The annotated iterator is the innermost spatial iterator + kPosMiddleSpatial = 2, // The annotated iterator is a middle spatial iterator + kPosOuterSpatial = 3, // The annotated iterator is the outermost spatial iterator + kPosInnerReduce = 4, // The annotated iterator is the innermost reduce iterator + kPosMiddleReduce = 5, // The annotated 
iterator is a middle reduce iterator + kPosOuterReduce = 6, // The annotated iterator is the outermost reduce iterator + kPosMixed = 7, // The annotated iterator is a mixed space and reduce iterator + kEnd = 8, + }; + int64_t num = 0; // The number of iterators with the annotation + int64_t prod = 0; // The product of the lengths of iterators with the annotation + int64_t len = 0; // The length of the innermost iterator with the annotation + Pos pos = Pos::kPosMixed; // The position of the iterators with the annotation + + static constexpr int64_t kCount = 11; + + explicit ForKindFeature(const ForVec& loops); + + void Export(std::vector* v) const { + double vs[] = { + slog(num), + slog(prod), + slog(len), + static_cast(static_cast(pos) == 0), + static_cast(static_cast(pos) == 1), + static_cast(static_cast(pos) == 2), + static_cast(static_cast(pos) == 3), + static_cast(static_cast(pos) == 4), + static_cast(static_cast(pos) == 5), + static_cast(static_cast(pos) == 6), + static_cast(static_cast(pos) == 7), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } + }; + + ArithOps arith_ops; + ForKindFeature vectorize; + ForKindFeature unroll; + ForKindFeature parallel; + bool is_gpu = false; + int64_t blockIdx_x_len = 1; // The length of blockIdx.x + int64_t blockIdx_y_len = 1; // The length of blockIdx.y + int64_t blockIdx_z_len = 1; // The length of blockIdx.z + int64_t threadIdx_x_len = 1; // The length of threadIdx.x + int64_t threadIdx_y_len = 1; // The length of threadIdx.y + int64_t threadIdx_z_len = 1; // The length of threadIdx.z + int64_t vthread_len = 1; // The length of virtual thread + + static constexpr int64_t kCount = ArithOps::kCount + ForKindFeature::kCount * 3 + 7; + + explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest, bool is_gpu) + : arith_ops(store, loop_nest.prod), + vectorize(loop_nest.vectorize), + unroll(loop_nest.unroll), + parallel(loop_nest.parallel) { + if (is_gpu) { + this->is_gpu = true; + this->blockIdx_x_len = utils::FirstLoopExtent(loop_nest.blockIdx_x, 1); + this->blockIdx_y_len = utils::FirstLoopExtent(loop_nest.blockIdx_y, 1); + this->blockIdx_z_len = utils::FirstLoopExtent(loop_nest.blockIdx_z, 1); + this->threadIdx_x_len = utils::FirstLoopExtent(loop_nest.threadIdx_x, 1); + this->threadIdx_y_len = utils::FirstLoopExtent(loop_nest.threadIdx_y, 1); + this->threadIdx_z_len = utils::FirstLoopExtent(loop_nest.threadIdx_z, 1); + this->vthread_len = utils::FirstLoopExtent(loop_nest.vthread, 1); + } + } + + void Export(std::vector* v) const { + double vs[] = { + static_cast(is_gpu), // + slog(blockIdx_x_len), slog(blockIdx_y_len), slog(blockIdx_z_len), + slog(threadIdx_x_len), slog(threadIdx_y_len), slog(threadIdx_z_len), + slog(vthread_len), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } +}; + +Feature::ArithOps::ArithOps(const BufferStoreNode* store, int64_t prod_loop_extent) { + class ArithOpCounter : public ExprVisitor { + public: +#define TVM_FEATURE_SIMPLE(Type, Counter) \ + void VisitExpr_(const Type* op) final { \ + result_.Counter += this->prod_loop_extent_; \ + ExprVisitor::VisitExpr_(op); \ + } +#define TVM_FEATURE_BINARY(Type, FloatCounter, IntCounter) \ + void VisitExpr_(const Type* op) final { \ + if (op->dtype.is_float()) { \ + result_.FloatCounter += this->prod_loop_extent_; \ + } else { \ + result_.IntCounter += this->prod_loop_extent_; \ + } \ + ExprVisitor::VisitExpr_(op); \ + } + TVM_FEATURE_SIMPLE(AndNode, bool_op); + TVM_FEATURE_SIMPLE(OrNode, bool_op); + TVM_FEATURE_SIMPLE(NotNode, bool_op); + 
TVM_FEATURE_SIMPLE(SelectNode, select_op); + TVM_FEATURE_BINARY(AddNode, float_add_sub, int_add_sub); + TVM_FEATURE_BINARY(SubNode, float_add_sub, int_add_sub); + TVM_FEATURE_BINARY(MulNode, float_mul, int_mul); + TVM_FEATURE_BINARY(DivNode, float_div_mod, int_div_mod); + TVM_FEATURE_BINARY(ModNode, float_div_mod, int_div_mod); + TVM_FEATURE_BINARY(FloorDivNode, float_div_mod, int_div_mod); + TVM_FEATURE_BINARY(FloorModNode, float_div_mod, int_div_mod); + TVM_FEATURE_BINARY(MaxNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(MinNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(EQNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(NENode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(LTNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(LENode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(GTNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(GENode, float_cmp, int_cmp); +#undef TVM_FEATURE_BINARY +#undef TVM_FEATURE_SIMPLE + + void VisitExpr_(const CallNode* op) final { + static auto op_call_effect_ = Op::GetAttrMap("TCallEffectKind"); + TCallEffectKind effect_kind = op_call_effect_[Downcast(op->op)]; + bool is_pure = + effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation; + if (is_pure) { + if (op->dtype.is_float()) { + result_.float_math_func += prod_loop_extent_; + } else { + result_.int_math_func += prod_loop_extent_; + } + } else { + if (op->dtype.is_float()) { + result_.float_other_func += prod_loop_extent_; + } else { + result_.int_other_func += prod_loop_extent_; + } + } + ExprVisitor::VisitExpr_(op); + } + + int64_t prod_loop_extent_; + ArithOps result_; + }; + ArithOpCounter counter; + counter.prod_loop_extent_ = prod_loop_extent; + counter(store->value); + for (const PrimExpr& expr : store->indices) { + counter(expr); + } + *this = counter.result_; +} + +Feature::ForKindFeature::ForKindFeature(const ForVec& loops) { + if (loops.empty()) { + this->num = 0; + this->prod = 1; + this->len = 0; + this->pos = ForKindFeature::Pos::kPosNone; + } else { + const int64_t* last_loop_extent = GetLoopIntExtent(loops.back()); + this->num = loops.size(); + this->prod = utils::ProdLoopExtent(loops); + this->len = last_loop_extent ? *last_loop_extent : 1; + this->pos = ForKindFeature::Pos::kPosMixed; + } +} + +} // namespace group1 + +namespace group2 { + +struct Feature { + enum class AccessType : int { + kRead = 0, // The buffer is read but not written + kWrite = 1, // The buffer is written but not read + kReadWrite = 2, // The buffer is both read and written + kUnknownRW = 3, // Unknown type + kEnd = 4, + }; + enum class ReuseType : int { + kLoopMultipleRead = 0, // Buffer reuse because accessed on each iteration of a loop + kSerialMultipleReadWrite = 1, // Buffer reuse because it is serially accessed + kNoReuse = 2, // No buffer reuse + kEnd = 3, + }; + + struct SubFeature { + // + const BufferNode* buffer = nullptr; + AccessType access_type = AccessType::kUnknownRW; + std::vector multi_indices = {}; + // + /*! \brief loop_accessed_numel[i][...] 
means the number of elements accessed by loops[i] */ + std::vector> loop_accessed_numel = {}; + IntVec access_shape; + int64_t num_continuous_bytes = 1; + // Stride information + int64_t min_stride = 0; + int64_t innermost_stride = 0; + int64_t prod_non_strided_loop_extent = 0; + // Reuse information + ReuseType reuse_type = ReuseType::kNoReuse; + double reuse_dis_iter = 0.0; + double reuse_dis_bytes = 0.0; + int64_t reuse_ct = 0; + // Features + double bytes; // The touched memory in bytes + double unique_bytes; // The touched unique memory in bytes + double lines; // The number of touched cache lines + double unique_lines; // The number touched unique cache lines + double bytes_d_reuse_ct; // bytes / reuse_ct + double unique_bytes_d_reuse_ct; // unique_bytes / reuse_ct + double lines_d_reuse_ct; // lines / reuse_ct + double unique_lines_d_reuse_ct; // unique_lines / reuse_ct + double stride; // The stride in access + + static constexpr int64_t kCount = 18; + + void Export(std::vector* v) const { + double vs[] = { + static_cast(static_cast(access_type) == 0), + static_cast(static_cast(access_type) == 1), + static_cast(static_cast(access_type) == 2), + // FeatureSet::BufferAccess::AccessType::kUnknownRW is ignored + slog(bytes), + slog(unique_bytes), + slog(lines), + slog(unique_lines), + static_cast(static_cast(reuse_type) == 0), + static_cast(static_cast(reuse_type) == 1), + static_cast(static_cast(reuse_type) == 2), + slog(reuse_dis_iter), + slog(reuse_dis_bytes), + slog(reuse_ct), + slog(bytes_d_reuse_ct), + slog(unique_bytes_d_reuse_ct), + slog(lines_d_reuse_ct), + slog(unique_lines_d_reuse_ct), + slog(stride), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } + + static void Pad(std::vector* v) { v->insert(v->end(), 18, 0.0); } + + void SetStride(const LoopNest& loop_nest); + + void SetReuse(const LoopNest& loop_nest, // + int64_t top_loop_touch_bytes, // + const ForBufferMap& buffer_touched_under_loop); + + void SetFeature(const LoopNest& loop_nest, int64_t cache_line_bytes); + + explicit SubFeature(const BufferNode* buffer, AccessType access_type, + std::vector multi_indices, int n_loops) + : buffer(buffer), + access_type(access_type), + multi_indices(multi_indices), + loop_accessed_numel(n_loops) {} + }; + + void Export(std::vector* v, int buffers_per_store) const { + int n = sub_features.size(); + for (int i = 0; i < buffers_per_store; ++i) { + if (i < n) { + sub_features[i].Export(v); + } else { + SubFeature::Pad(v); + } + } + for (const SubFeature& sub_feature : sub_features) { + sub_feature.Export(v); + } + } + + explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest, + int64_t cache_line_bytes, IntVec* for_touched_bytes, arith::Analyzer* analyzer); + + void Init(const BufferStoreNode* store, int n_loops); + + void SetRegion(const LoopNest& loop_nest, // + IntVec* for_touched_bytes, // + ForBufferMap* buffer_touched_under_loop, // + arith::Analyzer* analyzer); + + std::vector sub_features; +}; + +void Feature::Init(const BufferStoreNode* store, int n_loops) { + struct Info { + AccessType access_type = AccessType::kUnknownRW; + std::vector multi_indices; + }; + std::unordered_map buffer_info; + buffer_info[store->buffer.get()].access_type = AccessType::kWrite; + PostOrderVisit(store->value, [&buffer_info](const ObjectRef& obj) -> void { + if (const BufferLoadNode* load = obj.as()) { + const BufferNode* buffer = load->buffer.get(); + Info& info = buffer_info[buffer]; + switch (info.access_type) { + case AccessType::kRead: + break; + case 
AccessType::kWrite: + info.access_type = AccessType::kReadWrite; + break; + case AccessType::kReadWrite: + break; + case AccessType::kUnknownRW: + default: + info.access_type = AccessType::kRead; + break; + } + if (info.access_type != AccessType::kReadWrite) { + info.multi_indices.push_back({load->indices.begin(), load->indices.end()}); + } + } + }); + this->sub_features.reserve(buffer_info.size()); + for (const auto& kv : buffer_info) { + this->sub_features.emplace_back(kv.first, kv.second.access_type, + std::move(kv.second.multi_indices), n_loops); + } +} + +void Feature::SetRegion(const LoopNest& loop_nest, IntVec* for_touched_bytes, + ForBufferMap* buffer_touched_under_loop, + arith::Analyzer* analyzer) { + int n_loops = loop_nest.loops.size(); + const std::vector& loops = loop_nest.loops; + // Step 1. Initialize and bind all the loop variables to a constant + *for_touched_bytes = IntVec(n_loops, 0); + for (int i = 0; i < n_loops; ++i) { + const ForNode* loop = loops[i]; + analyzer->Bind(loop->loop_var, loop->min, /*allow_override=*/true); + } + // Step 2. Corner case: no loops + if (n_loops == 0) { + // In this case, the `access_shape` is not calculated + for (SubFeature& feature : sub_features) { + feature.access_shape = IntVec(feature.buffer->shape.size(), 1); + } + return; + } + // Step 3. Gradually bind the loops from inner to outer, + // calculate the area the loops touch on each buffer + for (int i = 0; i < n_loops; ++i) { + const ForNode* loop = loops[i]; + analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent), + /*allow_override=*/true); + int64_t& touched_bytes = (*for_touched_bytes)[i] = 0; + for (SubFeature& feature : sub_features) { + const BufferNode* buffer = feature.buffer; + // Note: `feature.access_shape` for `i == n_loops - 1` is the only one preserved, + // while others are discarded + int64_t numel = 1; + feature.access_shape = utils::UnionAndGetRelaxedSize(feature.multi_indices, &numel, analyzer); + feature.loop_accessed_numel[i][buffer] = numel; + touched_bytes += numel * buffer->dtype.bytes(); + (*buffer_touched_under_loop)[loop][buffer].push_back(numel); + } + } +} + +void Feature::SubFeature::SetStride(const LoopNest& loop_nest) { + int n_loops = loop_nest.loops.size(); + const std::vector& loops = loop_nest.loops; + // For each buffer, we find the loop stride on it + const BufferNode* buffer = this->buffer; + int ndim = this->buffer->shape.size(); + IntVec buffer_shape = support::AsVector(buffer->shape); + // Calculate the buffer's stride from its shape + IntVec buffer_stride(ndim); + if (ndim >= 1) { + buffer_stride[ndim - 1] = 1; + for (int i = ndim - 2; i >= 0; --i) { + buffer_stride[i] = buffer_stride[i + 1] * buffer_shape[i + 1]; + } + } + // Calculate `num_continuous_bytes` + { + int64_t& num_continuous_bytes = this->num_continuous_bytes = 1; + const IntVec& access_shape = this->access_shape; + ICHECK_EQ(access_shape.size(), buffer_shape.size()); + for (int i = ndim - 1; i >= 0; --i) { + if (access_shape[i] == buffer_shape[i]) { + // TODO + num_continuous_bytes = buffer_shape[i] * buffer->dtype.bytes(); + break; + } + } + } + // Enumerate loops from inner to outer + int i = 0; + // Calculate this->min_stride + int64_t& stride = this->min_stride = 0; + for (; i < n_loops; ++i) { + stride = utils::GetVarStride(this->multi_indices, buffer_stride, loops[i]->loop_var); + if (stride != 0) { + break; + } + } + // Calculate this->innermost_stride + this->innermost_stride = (i == 0) ? 
stride : 0; + // Calculate this->prod + int64_t& prod = this->prod_non_strided_loop_extent = 1; + for (int j = 0; j < i; ++j) { + if (const int64_t* extent = GetLoopIntExtent(loops[j])) { + prod *= *extent; + } + } +} + +void Feature::SubFeature::SetReuse(const LoopNest& loop_nest, int64_t top_loop_touch_bytes, + const ForBufferMap& buffer_touched_under_loop) { + const BufferNode* buffer = this->buffer; + // Step 0. Collect all `Var`s that appears in the buffer region + std::unordered_set region_vars; + for (const MultiIndex& multi_index : this->multi_indices) { + for (const PrimExpr& index : multi_index) { + PostOrderVisit(index, [®ion_vars](const ObjectRef& obj) -> void { + if (const auto* var = obj.as()) { + region_vars.insert(var); + } + }); + } + } + // Default case: no reuse + ReuseType& reuse_type = this->reuse_type = ReuseType::kNoReuse; + double& reuse_dis_iter = this->reuse_dis_iter = 0; + double& reuse_dis_bytes = this->reuse_dis_bytes = 0; + int64_t& reuse_ct = this->reuse_ct = 0; + + // Step 3.2. Enumerate loops from inner to outer, find the first loop with reuse + int n_loops = loop_nest.loops.size(); + const std::vector& loops = loop_nest.loops; + for (int i = 0; i < n_loops; ++i) { + const ForNode* loop = loops[i]; + // Case 1. Find an invariant loop, i.e. reuse with kLoopMultipleRead + if (!region_vars.count(loop->loop_var.get())) { + reuse_type = ReuseType::kLoopMultipleRead; + if (const int64_t* extent = GetLoopIntExtent(loop)) { + reuse_ct = *extent; + } else { + reuse_ct = 1; + } + reuse_dis_iter = 1; + for (int j = 0; j < i; ++j) { + if (const int64_t* extent = GetLoopIntExtent(loops[j])) { + reuse_dis_iter *= *extent; + } + } + reuse_dis_bytes = 0.0; + if (i == 0) { + reuse_dis_bytes = top_loop_touch_bytes; + } else { + for (const auto& iter : buffer_touched_under_loop.at(loops[i - 1])) { + const BufferNode* buffer = iter.first; + const IntVec& numels = iter.second; + int64_t numel = std::accumulate(numels.begin(), numels.end(), int64_t(0)); + reuse_dis_bytes += numel * buffer->dtype.bytes(); + } + } + break; + } + // Case 2. Find serial reuse, i.e. 
reuse with kSerialMultipleReadWrite + const IntVec& touched = buffer_touched_under_loop.at(loop).at(buffer); + if (touched.size() >= 2) { + int64_t extent = 1; + if (const int64_t* ext = GetLoopIntExtent(loop)) { + extent = *ext; + } + reuse_type = ReuseType::kSerialMultipleReadWrite; + reuse_ct = touched.size() - 1; + reuse_dis_iter = *std::min_element(touched.begin(), touched.end()); + reuse_dis_bytes = 0.0; + for (const auto& iter : buffer_touched_under_loop.at(loop)) { + const BufferNode* buffer = iter.first; + const IntVec& numels = iter.second; + int64_t numel = std::accumulate(numels.begin(), numels.end(), int64_t(0)); + reuse_dis_bytes += numel * buffer->dtype.bytes(); + } + reuse_dis_iter /= extent; + reuse_dis_bytes /= extent; + break; + } + } +} + +void Feature::SubFeature::SetFeature(const LoopNest& loop_nest, int64_t cache_line_bytes) { + int64_t dtype_bytes = this->buffer->dtype.bytes(); + this->stride = this->innermost_stride; + this->bytes = dtype_bytes * loop_nest.prod; + if (loop_nest.loops.empty()) { + this->unique_bytes = 1; + this->lines = 1; + this->unique_lines = 1; + } else { + this->unique_bytes = this->loop_accessed_numel.back().at(buffer) * dtype_bytes; + double m = static_cast(this->min_stride) * dtype_bytes / cache_line_bytes; + this->lines = + static_cast(loop_nest.prod) / this->prod_non_strided_loop_extent * std::min(1.0, m); + this->lines = std::max(1.0, this->lines); + this->unique_lines = static_cast(this->unique_bytes) / + std::min(cache_line_bytes, this->num_continuous_bytes); + this->unique_lines = std::max(1.0, this->unique_lines); + } + double proxy_reuse_ct = this->reuse_ct > 0 ? this->reuse_ct : 0.5; + this->bytes_d_reuse_ct = this->bytes / proxy_reuse_ct; + this->unique_bytes_d_reuse_ct = this->unique_bytes / proxy_reuse_ct; + this->lines_d_reuse_ct = this->lines / proxy_reuse_ct; + this->unique_lines_d_reuse_ct = this->unique_lines / proxy_reuse_ct; +} + +Feature::Feature(const BufferStoreNode* store, const LoopNest& loop_nest, int64_t cache_line_bytes, + IntVec* for_touched_bytes, arith::Analyzer* analyzer) { + int n_loops = loop_nest.loops.size(); + // Step 0. Initialize data structures + this->Init(store, n_loops); + // Step 1. Calculate region-related feature + ForBufferMap buffer_touched_under_loop; + this->SetRegion(loop_nest, for_touched_bytes, &buffer_touched_under_loop, analyzer); + // Step 2. Calculate stride-related feature + for (auto& feature : sub_features) { + feature.SetStride(loop_nest); + } + // Step 3. Calculate reuse-related feature + int64_t top_loop_touch_bytes = 0.0; + if (n_loops > 0) { + for (const SubFeature& feature : sub_features) { + int64_t bytes = feature.buffer->dtype.bytes(); + int64_t n_buffer = feature.loop_accessed_numel[0].size(); + top_loop_touch_bytes += bytes * n_buffer; + } + } + for (auto& feature : sub_features) { + feature.SetReuse(loop_nest, top_loop_touch_bytes, buffer_touched_under_loop); + } + // Step 4. Calculate rest of the features + for (auto& feature : sub_features) { + feature.SetFeature(loop_nest, cache_line_bytes); + } + // Step 5. 
Sort the features
+  std::sort(sub_features.begin(), sub_features.end(),
+            [](const SubFeature& a, const SubFeature& b) {
+              if (a.lines != b.lines) {
+                return a.lines > b.lines;
+              }
+              if (a.bytes != b.bytes) {
+                return a.bytes > b.bytes;
+              }
+              return a.buffer->name < b.buffer->name;
+            });
+}
+
+} // namespace group2
+
+namespace group3 {
+
+struct Feature {
+  std::vector<double> arith_intensity_curve;
+
+  void Export(std::vector<double>* v) const {
+    v->insert(v->end(), arith_intensity_curve.begin(), arith_intensity_curve.end());
+  }
+
+  explicit Feature(int n_samples, const LoopNest& loop_nest, const IntVec& for_touched_bytes,
+                   const group1::Feature::ArithOps& arith_ops)
+      : arith_intensity_curve(n_samples, 0.0) {
+    const std::vector<const ForNode*>& loops = loop_nest.loops;
+    ICHECK_EQ(loops.size(), for_touched_bytes.size());
+    int n_loops = loops.size();
+    // Calculate `memory_bytes`
+    std::vector<double> memory_bytes;
+    for (int i = 0; i < n_loops; ++i) {
+      memory_bytes.push_back(std::log2(for_touched_bytes[i]));
+    }
+    // Calculate `compute_ops` and `cur_compute_ops`
+    std::vector<double> compute_ops;
+    double total_compute_ops = arith_ops.float_mad + arith_ops.float_add_sub + arith_ops.float_mul +
+                               arith_ops.float_div_mod + arith_ops.float_cmp +
+                               arith_ops.float_math_func + arith_ops.float_other_func;
+    total_compute_ops /= loop_nest.prod;
+    for (int i = 0; i < n_loops; ++i) {
+      if (const int64_t* extent = GetLoopIntExtent(loops[i])) {
+        total_compute_ops *= *extent;
+      }
+      compute_ops.push_back(std::log2(total_compute_ops));
+    }
+    // Fill the feature set
+    if (total_compute_ops <= 0 || compute_ops.empty()) {
+      for (int i = 0; i < n_samples; ++i) {
+        arith_intensity_curve[i] = 0.0;
+      }
+      return;
+    }
+    total_compute_ops = compute_ops.back();  // i.e. total_compute_ops = log2(total_compute_ops)
+    int p = 0;
+    for (int i = 0; i < n_samples; ++i) {
+      double& result = arith_intensity_curve[i];
+      double cur_compute_ops = static_cast<double>(i + 1) / n_samples * total_compute_ops;
+      // Find the first `p` that `compute[p] >= total * (i + 1) / N`
+      for (; p < n_loops; ++p) {
+        if (compute_ops[p] >= cur_compute_ops - 1e-4) {
+          break;
+        }
+      }
+      CHECK_LT(p, n_loops);
+      if (p == 0) {
+        result = compute_ops[p] / memory_bytes[p];
+      } else {
+        double base = compute_ops[p - 1] / memory_bytes[p - 1];
+        double slope =
+            (compute_ops[p] / memory_bytes[p] - compute_ops[p - 1] / memory_bytes[p - 1]) /
+            (compute_ops[p] - compute_ops[p - 1]);
+        result = base + slope * (cur_compute_ops - compute_ops[p - 1]);
+      }
+    }
+  }
+};
+
+} // namespace group3
+
+namespace group4 {
+
+struct Feature {
+  int64_t alloc_size = 0;        // The size of allocated buffer in bytes
+  int64_t alloc_prod = 0;        // alloc_outer_prod * alloc_inner_prod
+  int64_t alloc_outer_prod = 1;  // The product of lengths of loops outside the scope of the alloc
+
+  static constexpr int64_t kCount = 4;
+
+  void Export(std::vector<double>* v, int64_t outer_prod) const {
+    double vs[] = {
+        slog(alloc_size),
+        slog(alloc_prod),
+        slog(alloc_outer_prod),
+        slog(static_cast<double>(outer_prod) / alloc_outer_prod),
+    };
+    v->insert(v->end(), std::begin(vs), std::end(vs));
+  }
+
+  Feature() = default;
+
+  explicit Feature(const LoopNest& loop_nest, const Buffer& buffer) {
+    int64_t numel = 1;
+    for (int64_t x : support::AsVector<PrimExpr, int64_t>(buffer->shape)) {
+      numel *= x;
+    }
+    alloc_size = numel * buffer->dtype.bytes();
+    alloc_prod = numel * loop_nest.prod;
+    alloc_outer_prod = loop_nest.prod;
+  }
+};
+
+} // namespace group4
+
+namespace group5 {
+
+struct Feature {
+  int64_t outer_prod;        // The product of lengths of outer loops
+  int num_loops;             // The number of outer loops
+  int auto_unroll_max_step;  // The value of pragma "auto_unroll_max_step"
+
+  static constexpr int64_t kCount = 3;
+
+  void Export(std::vector<double>* v) const {
+    double vs[] = {
+        slog(outer_prod),
+        slog(num_loops),
+        slog(auto_unroll_max_step),
+    };
+    v->insert(v->end(), std::begin(vs), std::end(vs));
+  }
+
+  explicit Feature(const LoopNest& loop_nest) {
+    this->outer_prod = loop_nest.prod;
+    this->num_loops = loop_nest.loops.size();
+    this->auto_unroll_max_step = loop_nest.auto_unroll.empty() ? 0 : loop_nest.auto_unroll.back();
+  }
+};
+
+} // namespace group5
+
+struct Feature {
+  const BufferNode* buffer = nullptr;
+  int buffer_order = -1;
+  std::unique_ptr<group1::Feature> group1 = nullptr;
+  std::unique_ptr<group2::Feature> group2 = nullptr;
+  std::unique_ptr<group3::Feature> group3 = nullptr;
+  std::unique_ptr<group4::Feature> group4 = nullptr;
+  std::unique_ptr<group5::Feature> group5 = nullptr;
+
+  bool operator<(const Feature& other) const { return buffer_order < other.buffer_order; }
+};
+
+class PerStoreFeatureCollector : private StmtVisitor {
+ public:
+  static std::vector<Feature> Collect(bool is_gpu, int64_t cache_line_bytes,
+                                      int64_t arith_intensity_curve_num_samples,
+                                      const IRModule& mod) {
+    PerStoreFeatureCollector collector(is_gpu, cache_line_bytes,
+                                       arith_intensity_curve_num_samples);
+    for (const auto& kv : mod->functions) {
+      if (const PrimFuncNode* func = kv.second.as<PrimFuncNode>()) {
+        collector(func->body);
+      }
+    }
+    std::vector<Feature> result;
+    result.reserve(collector.buffer_features_.size());
+    for (auto& it : collector.buffer_features_) {
+      Feature& feature = it.second;
+      if (feature.buffer != nullptr) {
+        ICHECK(feature.group1);
+        ICHECK(feature.group2);
+        ICHECK(feature.group3);
+        ICHECK(feature.group5);
+        if (feature.group4 == nullptr) {
+          feature.group4 = std::make_unique<group4::Feature>();
+        }
+        result.push_back(std::move(feature));
+      }
+    }
+    std::sort(result.begin(), result.end());
+    return result;
+  }
+
+ private:
+  void VisitStmt_(const ForNode* loop) final {
+    int64_t auto_unroll = utils::GetPragmaAutoUnroll(loop);
+    ForVec* for_vec = loop_nest_.Push(loop, auto_unroll);
+    StmtVisitor::VisitStmt_(loop);
+    loop_nest_.Pop(loop, for_vec, auto_unroll);
+  }
+
+  void VisitStmt_(const BufferStoreNode* store) final {
+    const BufferNode* buffer = store->buffer.get();
+    Feature& feature = buffer_features_[buffer];
+    if (feature.buffer == nullptr) {
+      feature.buffer = buffer;
+      feature.buffer_order = buffer_features_.size();
+    }
+    feature.group1 = std::make_unique<group1::Feature>(store, loop_nest_, is_gpu_);
+    feature.group2 = std::make_unique<group2::Feature>(store, loop_nest_, cache_line_bytes_,
+                                                       &for_touched_bytes_, &analyzer_);
+    feature.group3 = std::make_unique<group3::Feature>(arith_intensity_curve_num_samples_,  //
+                                                       loop_nest_, for_touched_bytes_,
+                                                       feature.group1->arith_ops);
+    feature.group5 = std::make_unique<group5::Feature>(loop_nest_);
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    StmtVisitor::VisitStmt_(block);
+    for (const Buffer& buffer : block->alloc_buffers) {
+      Feature& feature = buffer_features_[buffer.get()];
+      feature.group4 = std::make_unique<group4::Feature>(loop_nest_, buffer);
+    }
+  }
+
+  explicit PerStoreFeatureCollector(bool is_gpu, int64_t cache_line_bytes,
+                                    int64_t arith_intensity_curve_num_samples)
+      : is_gpu_(is_gpu),
+        cache_line_bytes_(cache_line_bytes),
+        arith_intensity_curve_num_samples_(arith_intensity_curve_num_samples) {}
+
+  bool is_gpu_;
+  int64_t cache_line_bytes_;
+  int64_t arith_intensity_curve_num_samples_;
+  arith::Analyzer analyzer_;
+  LoopNest loop_nest_ = {};
+  IntVec for_touched_bytes_ = {};
+  std::unordered_map<const BufferNode*, Feature> buffer_features_ = {};
+};
+
+} // namespace tir
+} // namespace tvm
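// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch, not part of the patch. The reuse
// analysis in `SetReuse` above scans loops from the innermost outward (the
// traversal that PATCH 3/5 below settles on) and reports kLoopMultipleRead at
// the first loop whose variable does not appear in the buffer's indices. The
// toy program below replays that rule for C[i, j] += A[i, k] * B[k, j]:
// buffer A is indexed by {i, k}, so the first invariant loop (inner to outer)
// is j, giving reuse_ct = extent(j) and reuse_dis_iter = extent(k). All names
// here (ToyLoop, FirstInvariantLoopReuse) are hypothetical; the code is plain
// C++ with no TVM dependency.
#include <cassert>
#include <cstdint>
#include <set>
#include <string>
#include <utility>
#include <vector>

struct ToyLoop {
  std::string var;  // loop variable name
  int64_t extent;   // constant trip count
};

// Returns {reuse_ct, reuse_dis_iter} for the first loop (scanning from the
// innermost) whose variable is absent from `buffer_index_vars`, or {0, 0}
// when every loop variable appears in the buffer's indices (no such reuse).
static std::pair<int64_t, int64_t> FirstInvariantLoopReuse(
    const std::vector<ToyLoop>& loops_outer_to_inner,
    const std::set<std::string>& buffer_index_vars) {
  int n = static_cast<int>(loops_outer_to_inner.size());
  for (int i = n - 1; i >= 0; --i) {  // inner to outer, as in SetReuse
    if (!buffer_index_vars.count(loops_outer_to_inner[i].var)) {
      int64_t dis_iter = 1;  // iterations between two reads of the same element
      for (int j = n - 1; j > i; --j) {
        dis_iter *= loops_outer_to_inner[j].extent;
      }
      return {loops_outer_to_inner[i].extent, dis_iter};
    }
  }
  return {0, 0};
}

int main() {
  // A[i, k] under loops (i, j, k): every j iteration re-reads the same slice.
  auto r = FirstInvariantLoopReuse({{"i", 128}, {"j", 64}, {"k", 32}}, {"i", "k"});
  assert(r.first == 64 && r.second == 32);
  return 0;
}
// ---------------------------------------------------------------------------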
+
+namespace tvm {
+namespace meta_schedule {
+
+class PerStoreFeatureNode : public FeatureExtractorNode {
+ public:
+  int buffers_per_store;
+  int arith_intensity_curve_num_samples;
+  int cache_line_bytes;
+  int feature_vector_length;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("buffers_per_store", &buffers_per_store);
+    v->Visit("arith_intensity_curve_num_samples", &arith_intensity_curve_num_samples);
+    v->Visit("cache_line_bytes", &cache_line_bytes);
+    v->Visit("feature_vector_length", &feature_vector_length);
+  }
+
+  void ExtractSingle(IRModule mod, bool is_gpu, std::vector<std::vector<double>>* results) {
+    std::vector<tir::Feature> features = tir::PerStoreFeatureCollector::Collect(
+        is_gpu, this->cache_line_bytes, this->arith_intensity_curve_num_samples, mod);
+    int n_features = features.size();
+    results->resize(n_features);
+    for (int i = 0; i < n_features; ++i) {
+      const tir::Feature& feature = features[i];
+      std::vector<double>& result = (*results)[i];
+      result.reserve(feature_vector_length);
+      feature.group1->Export(&result);
+      feature.group2->Export(&result, this->buffers_per_store);
+      feature.group3->Export(&result);
+      feature.group4->Export(&result, feature.group5->outer_prod);
+      feature.group5->Export(&result);
+      ICHECK_EQ(static_cast<int>(result.size()), feature_vector_length);
+    }
+  }
+
+  Array<runtime::NDArray> ExtractFrom(const TuneContext& tune_context,
+                                      const Array<MeasureCandidate>& candidates) {
+    bool is_gpu = tune_context->target.value()->kind->name == "cuda";
+    std::vector<runtime::NDArray> results;
+    results.resize(candidates.size());
+    auto f = [this, is_gpu, &candidates, &results](int, int task_id) -> void {
+      const auto& candidate = candidates[task_id];
+      std::vector<std::vector<double>> features;
+      ExtractSingle(candidate->sch->mod(), is_gpu, &features);
+      results[task_id] = tir::utils::AsNDArray(features);
+    };
+    support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f);
+    return results;
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.PerStoreFeature";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PerStoreFeatureNode, FeatureExtractorNode);
+};
+
+FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store,
+                                                   int arith_intensity_curve_num_samples,
+                                                   int cache_line_bytes) {
+  ObjectPtr<PerStoreFeatureNode> n = make_object<PerStoreFeatureNode>();
+  n->buffers_per_store = buffers_per_store;
+  n->arith_intensity_curve_num_samples = arith_intensity_curve_num_samples;
+  n->cache_line_bytes = cache_line_bytes;
+  n->feature_vector_length = tir::group1::Feature::kCount +                                  //
+                             tir::group2::Feature::SubFeature::kCount * buffers_per_store +  //
+                             arith_intensity_curve_num_samples +                             //
+                             tir::group4::Feature::kCount +                                  //
+                             tir::group5::Feature::kCount;
+  return FeatureExtractor(n);
+}
+
+TVM_REGISTER_NODE_TYPE(PerStoreFeatureNode);
+TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPerStoreFeature")
+    .set_body_typed(FeatureExtractor::PerStoreFeature);
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 3a41062be2..8164c01ca7 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -43,6 +43,7 @@
 #include "../printer/text_printer.h"
 #include "../support/array.h"
 #include "../support/base64.h"
+#include "../support/utils.h"
 #include "../tir/schedule/utils.h"
 
 namespace tvm {
diff --git a/src/support/nd_int_set.h b/src/support/nd_int_set.h
index ae4a0386d4..46bbd2bceb 100644
--- a/src/support/nd_int_set.h
+++ b/src/support/nd_int_set.h
@@ -144,6 +144,29 @@ inline NDIntSet NDIntSetEval(
   return ret;
 }
 
+/*!
+ * \brief Output the N-dimensional integer set to a stream.
+ * \param os The output stream.
+ * \param nd_int_set The N-dimensional integer set to be output.
+ * \return The output stream.
+ */
+inline std::ostream& operator<<(std::ostream& os, const NDIntSet& nd_int_set) {
+  os << '[';
+  bool is_first = true;
+  for (const arith::IntSet& int_set : nd_int_set) {
+    if (is_first) {
+      is_first = false;
+    } else {
+      os << ", ";
+    }
+    PrimExpr min = int_set.min();
+    PrimExpr max = int_set.max();
+    os << min << ":" << max;
+  }
+  os << ']';
+  return os;
+}
+
 }  // namespace support
 }  // namespace tvm
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index c955e4f62d..091266ee38 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -300,6 +300,24 @@ inline Optional AnalyzeVarWithShift(const PrimExpr& expr, Optional*
 
 /******** Annotation ********/
 
+/*!
+ * \brief Get the annotation on a Block/For
+ * \tparam TObjectRef The type of the annotation value
+ * \param stmt The block or the for loop to be queried
+ * \param ann_key The annotation key to be looked up
+ * \return NullOpt if not found; otherwise the annotation value
+ */
+template <class TObjectRef, class TStmtNode>
+inline Optional<TObjectRef> GetAnn(const TStmtNode* stmt, const String& ann_key) {
+  const Map<String, ObjectRef>* annotations = &stmt->annotations;
+  for (const auto& ann : *annotations) {
+    if (ann.first == ann_key) {
+      return Downcast<TObjectRef>(ann.second);
+    }
+  }
+  return NullOpt;
+}
+
 /*!
  * \brief Get the annotation on a Block/For
  * \tparam TObjectRef The type of the annotation value
@@ -309,20 +327,14 @@
  */
 template <class TObjectRef>
 inline Optional<TObjectRef> GetAnn(const StmtSRef& sref, const String& ann_key) {
-  const Map<String, ObjectRef>* annotations = nullptr;
   if (const auto* loop = sref->StmtAs<ForNode>()) {
-    annotations = &loop->annotations;
+    return GetAnn<TObjectRef>(loop, ann_key);
   } else if (const auto* block = sref->StmtAs<BlockNode>()) {
-    annotations = &block->annotations;
+    return GetAnn<TObjectRef>(block, ann_key);
   } else {
     LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey();
+    throw;
   }
-  for (const auto& ann : *annotations) {
-    if (ann.first == ann_key) {
-      return Downcast<TObjectRef>(ann.second);
-    }
-  }
-  return NullOpt;
 }
 
 /*!
diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc
index f7629d1006..ddc2e17569 100644
--- a/src/tir/transforms/convert_blocks_to_opaque.cc
+++ b/src/tir/transforms/convert_blocks_to_opaque.cc
@@ -80,7 +80,7 @@ class OpaqueBlockConverter : public StmtExprMutator {
     return std::move(new_realize);
   }
 
-  /*! \brief The map from block vars to thier binding values. */
+  /*! \brief The map from block vars to their binding values. */
   std::unordered_map<const VarNode*, PrimExpr> var_substitutes_;
 };
 
From 542dd35ff1c48dda5454b9235fc24ba4a5014f94 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Fri, 26 Nov 2021 14:08:29 -0800
Subject: [PATCH 2/5] ...

---
 python/tvm/meta_schedule/__init__.py          |  1 +
 .../feature_extractor/__init__.py             |  1 +
 .../feature_extractor/per_store_feature.cc    | 36 ++++++++++++++++---
 src/tir/schedule/utils.h                      | 13 +++++++
 ...ule_feature_extractor_per_store_feature.py | 20 +++++++++++
 5 files changed, 66 insertions(+), 5 deletions(-)
 create mode 100644 tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py

diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py
index 37e8ffa9d8..cec0ba3fe9 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -28,3 +28,4 @@
 from . import feature_extractor
 from . 
import cost_model from .tune_context import TuneContext +from .search_strategy import MeasureCandidate diff --git a/python/tvm/meta_schedule/feature_extractor/__init__.py b/python/tvm/meta_schedule/feature_extractor/__init__.py index 49310decf3..ffe7655a51 100644 --- a/python/tvm/meta_schedule/feature_extractor/__init__.py +++ b/python/tvm/meta_schedule/feature_extractor/__init__.py @@ -20,3 +20,4 @@ measure candidates for use in cost model. """ from .feature_extractor import FeatureExtractor, PyFeatureExtractor +from .per_store_feature import PerStoreFeature diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index cc7a1a128d..9b38686c27 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -25,6 +25,25 @@ #include "../utils.h" +namespace std { + +std::ostream& operator<<(std::ostream& os, const std::vector& vec) { + tvm::tir::PrintVector(vec, os, [&os](int64_t i) { os << i; }); + return os; +} + +std::ostream& operator<<(std::ostream& os, const std::vector& vec) { + tvm::tir::PrintVector(vec, os, [&os](const tvm::PrimExpr& i) { os << i; }); + return os; +} + +std::ostream& operator<<(std::ostream& os, const std::vector>& vec) { + tvm::tir::PrintVector(vec, os, [&os](const std::vector& i) { os << i; }); + return os; +} + +} // namespace std + namespace tvm { namespace tir { @@ -379,7 +398,7 @@ struct Feature { int64_t threadIdx_z_len = 1; // The length of threadIdx.z int64_t vthread_len = 1; // The length of virtual thread - static constexpr int64_t kCount = ArithOps::kCount + ForKindFeature::kCount * 3 + 7; + static constexpr int64_t kCount = ArithOps::kCount + ForKindFeature::kCount * 3 + 8; explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest, bool is_gpu) : arith_ops(store, loop_nest.prod), @@ -399,6 +418,10 @@ struct Feature { } void Export(std::vector* v) const { + this->arith_ops.Export(v); + this->vectorize.Export(v); + this->unroll.Export(v); + this->parallel.Export(v); double vs[] = { static_cast(is_gpu), // slog(blockIdx_x_len), slog(blockIdx_y_len), slog(blockIdx_z_len), @@ -599,9 +622,6 @@ struct Feature { SubFeature::Pad(v); } } - for (const SubFeature& sub_feature : sub_features) { - sub_feature.Export(v); - } } explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest, @@ -623,7 +643,11 @@ void Feature::Init(const BufferStoreNode* store, int n_loops) { std::vector multi_indices; }; std::unordered_map buffer_info; - buffer_info[store->buffer.get()].access_type = AccessType::kWrite; + { + Info& info = buffer_info[store->buffer.get()]; + info.access_type = AccessType::kWrite; + info.multi_indices.push_back({store->indices.begin(), store->indices.end()}); + } PostOrderVisit(store->value, [&buffer_info](const ObjectRef& obj) -> void { if (const BufferLoadNode* load = obj.as()) { const BufferNode* buffer = load->buffer.get(); @@ -1121,6 +1145,8 @@ class PerStoreFeatureNode : public FeatureExtractorNode { } void ExtractSingle(IRModule mod, bool is_gpu, std::vector>* results) { + static transform::Sequential passes = tir::transform::PassListForPerStoreFeature(); + mod = passes(std::move(mod)); std::vector features = tir::PerStoreFeatureCollector::Collect( is_gpu, this->cache_line_bytes, this->arith_intensity_curve_num_samples, mod); int n_features = features.size(); diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 091266ee38..2c2a7666f5 100644 --- 
a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -105,6 +105,19 @@ namespace tir { << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \ << "`, but gets: " << (From.defined() ? From->GetTypeKey() : "None") +template +void PrintVector(const std::vector& vec, std::ostream& os, FPrint print) { + int n = vec.size(); + os << "["; + for (int i = 0; i < n; ++i) { + if (i != 0) { + os << ", "; + } + print(vec[i]); + } + os << "]"; +} + /*! * \brief Convert an array of loop StmtSRefs to an array of loops * \param loop_srefs The loop StmtSRefs to be converted diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py new file mode 100644 index 0000000000..b9406abbba --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +if __name__ == "__main__": + pass From 06aa1943098161174dff7dc63374fe40e7061372 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 26 Nov 2021 21:23:44 -0800 Subject: [PATCH 3/5] pass the first test --- .../feature_extractor/per_store_feature.cc | 64 ++-- ...ule_feature_extractor_per_store_feature.py | 335 +++++++++++++++++- 2 files changed, 377 insertions(+), 22 deletions(-) diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index 9b38686c27..b69068f5c7 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -32,6 +32,11 @@ std::ostream& operator<<(std::ostream& os, const std::vector& vec) { return os; } +std::ostream& operator<<(std::ostream& os, const std::vector& vec) { + tvm::tir::PrintVector(vec, os, [&os](double i) { os << i; }); + return os; +} + std::ostream& operator<<(std::ostream& os, const std::vector& vec) { tvm::tir::PrintVector(vec, os, [&os](const tvm::PrimExpr& i) { os << i; }); return os; @@ -498,9 +503,6 @@ Feature::ArithOps::ArithOps(const BufferStoreNode* store, int64_t prod_loop_exte ArithOpCounter counter; counter.prod_loop_extent_ = prod_loop_extent; counter(store->value); - for (const PrimExpr& expr : store->indices) { - counter(expr); - } *this = counter.result_; } @@ -698,14 +700,14 @@ void Feature::SetRegion(const LoopNest& loop_nest, IntVec* for_touched_bytes, } // Step 3. 
Gradually bind the loops from inner to outer, // calculate the area the loops touch on each buffer - for (int i = 0; i < n_loops; ++i) { + for (int i = n_loops - 1; i >= 0; --i) { const ForNode* loop = loops[i]; analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent), /*allow_override=*/true); int64_t& touched_bytes = (*for_touched_bytes)[i] = 0; for (SubFeature& feature : sub_features) { const BufferNode* buffer = feature.buffer; - // Note: `feature.access_shape` for `i == n_loops - 1` is the only one preserved, + // Note: `feature.access_shape` for `i == 0` is the only one preserved, // while others are discarded int64_t numel = 1; feature.access_shape = utils::UnionAndGetRelaxedSize(feature.multi_indices, &numel, analyzer); @@ -748,18 +750,18 @@ void Feature::SubFeature::SetStride(const LoopNest& loop_nest) { int i = 0; // Calculate this->min_stride int64_t& stride = this->min_stride = 0; - for (; i < n_loops; ++i) { + for (i = n_loops - 1; i >= 0; --i) { stride = utils::GetVarStride(this->multi_indices, buffer_stride, loops[i]->loop_var); if (stride != 0) { break; } } // Calculate this->innermost_stride - this->innermost_stride = (i == 0) ? stride : 0; + this->innermost_stride = (i == n_loops - 1) ? stride : 0; // Calculate this->prod int64_t& prod = this->prod_non_strided_loop_extent = 1; - for (int j = 0; j < i; ++j) { - if (const int64_t* extent = GetLoopIntExtent(loops[j])) { + for (int j = n_loops - 1; j > i; --j) { + if (const int64_t* extent = GetLoopIntExtent(loops[n_loops - 1])) { // TODO prod *= *extent; } } @@ -788,7 +790,7 @@ void Feature::SubFeature::SetReuse(const LoopNest& loop_nest, int64_t top_loop_t // Step 3.2. Enumerate loops from inner to outer, find the first loop with reuse int n_loops = loop_nest.loops.size(); const std::vector& loops = loop_nest.loops; - for (int i = 0; i < n_loops; ++i) { + for (int i = n_loops - 1; i >= 0; --i) { const ForNode* loop = loops[i]; // Case 1. Find an invariant loop, i.e. 
reuse with kLoopMultipleRead if (!region_vars.count(loop->loop_var.get())) { @@ -799,16 +801,16 @@ void Feature::SubFeature::SetReuse(const LoopNest& loop_nest, int64_t top_loop_t reuse_ct = 1; } reuse_dis_iter = 1; - for (int j = 0; j < i; ++j) { + for (int j = n_loops - 1; j > i; --j) { if (const int64_t* extent = GetLoopIntExtent(loops[j])) { reuse_dis_iter *= *extent; } } reuse_dis_bytes = 0.0; - if (i == 0) { + if (i == n_loops - 1) { reuse_dis_bytes = top_loop_touch_bytes; } else { - for (const auto& iter : buffer_touched_under_loop.at(loops[i - 1])) { + for (const auto& iter : buffer_touched_under_loop.at(loops[i + 1])) { const BufferNode* buffer = iter.first; const IntVec& numels = iter.second; int64_t numel = std::accumulate(numels.begin(), numels.end(), int64_t(0)); @@ -850,10 +852,9 @@ void Feature::SubFeature::SetFeature(const LoopNest& loop_nest, int64_t cache_li this->lines = 1; this->unique_lines = 1; } else { - this->unique_bytes = this->loop_accessed_numel.back().at(buffer) * dtype_bytes; - double m = static_cast(this->min_stride) * dtype_bytes / cache_line_bytes; - this->lines = - static_cast(loop_nest.prod) / this->prod_non_strided_loop_extent * std::min(1.0, m); + this->unique_bytes = this->loop_accessed_numel.front().at(buffer) * dtype_bytes; + this->lines = static_cast(loop_nest.prod) / this->prod_non_strided_loop_extent * + std::min(1.0, 1.0 * this->min_stride * dtype_bytes / cache_line_bytes); this->lines = std::max(1.0, this->lines); this->unique_lines = static_cast(this->unique_bytes) / std::min(cache_line_bytes, this->num_continuous_bytes); @@ -925,8 +926,9 @@ struct Feature { int n_loops = loops.size(); // Calculate `memory_bytes` std::vector memory_bytes; + memory_bytes.resize(n_loops); for (int i = 0; i < n_loops; ++i) { - memory_bytes.push_back(std::log2(for_touched_bytes[i])); + memory_bytes[n_loops - 1 - i] = std::log2(for_touched_bytes[i]); } // Calculate `compute_ops` and `cur_compute_ops` std::vector compute_ops; @@ -934,7 +936,7 @@ struct Feature { arith_ops.float_div_mod + arith_ops.float_cmp + arith_ops.float_math_func + arith_ops.float_other_func; total_compute_ops /= loop_nest.prod; - for (int i = 0; i < n_loops; ++i) { + for (int i = n_loops - 1; i >= 0; --i) { if (const int64_t* extent = GetLoopIntExtent(loops[i])) { total_compute_ops *= *extent; } @@ -1056,6 +1058,9 @@ class PerStoreFeatureCollector : private StmtVisitor { for (const auto& kv : mod->functions) { if (const PrimFuncNode* func = kv.second.as()) { collector(func->body); + for (const auto& it : func->buffer_map) { + collector.HandleBufferAlloc(it.second); + } } } std::vector result; @@ -1086,6 +1091,7 @@ class PerStoreFeatureCollector : private StmtVisitor { } void VisitStmt_(const BufferStoreNode* store) final { + LOG(INFO) << "Visit BufferStore:\n" << GetRef(store); const BufferNode* buffer = store->buffer.get(); Feature& feature = buffer_features_[buffer]; if (feature.buffer == nullptr) { @@ -1104,11 +1110,15 @@ class PerStoreFeatureCollector : private StmtVisitor { void VisitStmt_(const BlockNode* block) final { StmtVisitor::VisitStmt_(block); for (const Buffer& buffer : block->alloc_buffers) { - Feature& feature = buffer_features_[buffer.get()]; - feature.group4 = std::make_unique(loop_nest_, buffer); + HandleBufferAlloc(buffer); } } + void HandleBufferAlloc(const Buffer& buffer) { + Feature& feature = buffer_features_[buffer.get()]; + feature.group4 = std::make_unique(loop_nest_, buffer); + } + explicit PerStoreFeatureCollector(bool is_gpu, int64_t cache_line_bytes, int64_t 
arith_intensity_curve_num_samples) : is_gpu_(is_gpu), @@ -1147,6 +1157,7 @@ class PerStoreFeatureNode : public FeatureExtractorNode { void ExtractSingle(IRModule mod, bool is_gpu, std::vector>* results) { static transform::Sequential passes = tir::transform::PassListForPerStoreFeature(); mod = passes(std::move(mod)); + LOG(INFO) << "mod =\n" << tir::AsTVMScript(mod, "T"); std::vector features = tir::PerStoreFeatureCollector::Collect( is_gpu, this->cache_line_bytes, this->arith_intensity_curve_num_samples, mod); int n_features = features.size(); @@ -1161,7 +1172,18 @@ class PerStoreFeatureNode : public FeatureExtractorNode { feature.group4->Export(&result, feature.group5->outer_prod); feature.group5->Export(&result); ICHECK_EQ(static_cast(result.size()), feature_vector_length); + DebugFeature(feature); + } + } + + void DebugFeature(const tir::Feature& feature) { + const tir::group2::Feature& f2 = *feature.group2; + std::ostringstream os; + os << "Feature(" << feature.buffer->name << "):"; + for (const tir::group2::Feature::SubFeature& sub_f : f2.sub_features) { + os << " " << sub_f.buffer->name; } + LOG(INFO) << os.str(); } Array ExtractFrom(const TuneContext& tune_context, diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py index b9406abbba..163fd81326 100644 --- a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py +++ b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py @@ -15,6 +15,339 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import Callable, List + +from numpy.testing import assert_allclose +import tvm +from tvm import meta_schedule as ms, te, tir +from tvm.meta_schedule.testing import te_workload + +N_FEATURES = 164 + + +def _make_context(target) -> ms.TuneContext: + return ms.TuneContext( + target=target, + num_threads=1, + ) + + +def _make_candidate(f_sch: Callable[[], tir.Schedule]) -> ms.MeasureCandidate: + return ms.MeasureCandidate(sch=f_sch(), args_info=[]) + + +def _feature_names( # pylint: disable=invalid-name + buffers_per_store: int = 5, + arith_intensity_curve_num_samples: int = 10, +) -> List[str]: + result = [ + "float_mad", + "float_addsub", + "float_mul", + "float_divmod", + "float_cmp", + "float_mathfunc", + "float_otherfunc", + "int_mad", + "int_addsub", + "int_mul", + "int_divmod", + "int_cmp", + "int_mathfunc", + "int_otherfunc", + "bool_op", + "select_op", + "vec_num", + "vec_prod", + "vec_len", + "vec_type.kPosNone", + "vec_type.kPosInnerSpatial", + "vec_type.kPosMiddleSpatial", + "vec_type.kPosOuterSpatial", + "vec_type.kPosInnerReduce", + "vec_type.kPosMiddleReduce", + "vec_type.kPosOuterReduce", + "vec_type.kPosMixed", + "unroll_num", + "unroll_prod", + "unroll_len", + "unroll_type.kPosNone", + "unroll_type.kPosInnerSpatial", + "unroll_type.kPosMiddleSpatial", + "unroll_type.kPosOuterSpatial", + "unroll_type.kPosInnerReduce", + "unroll_type.kPosMiddleReduce", + "unroll_type.kPosOuterReduce", + "unroll_type.kPosMixed", + "parallel_num", + "parallel_prod", + "parallel_len", + "parallel_type.kPosNone", + "parallel_type.kPosInnerSpatial", + "parallel_type.kPosMiddleSpatial", + "parallel_type.kPosOuterSpatial", + "parallel_type.kPosInnerReduce", + "parallel_type.kPosMiddleReduce", + "parallel_type.kPosOuterReduce", + "parallel_type.kPosMixed", 
+ "is_gpu", + "blockIdx_x_len", + "blockIdx_y_len", + "blockIdx_z_len", + "threadIdx_x_len", + "threadIdx_y_len", + "threadIdx_z_len", + "vthread_len", + ] + for i in range(buffers_per_store): + result.extend( + f"B{i}.{s}" + for s in [ + "acc_type.kRead", + "acc_type.kWrite", + "acc_type.kReadWrite", + "bytes", + "unique_bytes", + "lines", + "unique_lines", + "reuse_type.kLoopMultipleRead", + "reuse_type.kSerialMultipleReadWrite", + "reuse_type.kNoReuse", + "reuse_dis_iter", + "reuse_dis_bytes", + "reuse_ct", + "bytes_d_reuse_ct", + "unique_bytes_d_reuse_ct", + "lines_d_reuse_ct", + "unique_lines_d_reuse_ct", + "stride", + ] + ) + result.extend(f"arith_intensity_curve_{i}" for i in range(arith_intensity_curve_num_samples)) + result.extend( + [ + "alloc_size", + "alloc_prod", + "alloc_outer_prod", + "alloc_inner_prod", + "outer_prod", + "num_loops", + "auto_unroll_max_step", + ] + ) + # 57 + 18 * 5 + 10 + 4 + 3 + assert len(result) == N_FEATURES + return result + + +def _zip_feature(feature, names): + assert feature.ndim == 1 + assert feature.shape[0] == N_FEATURES + assert len(names) == N_FEATURES + return list(zip(names, feature)) + + +def _print_feature(feature, st, ed): # pylint: disable=invalid-name + named_feature = _zip_feature(feature, _feature_names()) + for k, v in named_feature[st:ed]: + print("\t", k, v) + + +def test_cpu_matmul(): + def _create_schedule(): + func = te.create_prim_func(te_workload.matmul(n=512, m=512, k=512)) + sch = tir.Schedule(func, debug_mask="all") + block = sch.get_block("C") + i, j, k = sch.get_loops(block) + i_o, i_i = sch.split(i, factors=[None, 16]) # outer: 32 + j_o, j_i = sch.split(j, factors=[None, 8]) # outer: 64 + sch.reorder(i_o, j_o, k, j_i, i_i) + sch.vectorize(j_i) + sch.parallel(i_o) + sch.parallel(j_o) + sch.unroll(k) + return sch + + extractor = ms.feature_extractor.PerStoreFeature() + (feature,) = extractor.extract_from( + _make_context(tvm.target.Target("llvm")), + candidates=[_make_candidate(_create_schedule)], + ) + feature = feature.numpy() + assert feature.shape == (1, N_FEATURES) + f = feature[0] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + # fmt: off + desired=[ + # float math ops + 0, 27, 27, 0, 0, 0, 0, + # int math ops + 0, 29, 29, 0, 0, 0, 0, + # bool/select ops + 0, 0, + ], + # fmt: on + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[1.0, 3.169924, 3.169924, 0, 0, 0, 0, 0, 0, 0, 1], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[1.0, 9.002815, 9.002815, 0, 0, 0, 0, 0, 0, 0, 1], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[1.58496, 11.0007, 6.022368, 0, 0, 0, 0, 0, 0, 0, 1], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer A + assert_allclose( + actual=f[57:75], + desired=[ + 1, + 0, + 0, + 29, + 20, + 27, + 14, + 1, + 0, + 0, + 4.087463, + 7.0552826, + 3.169925, + 26, + 17, + 24, + 11.0007038, + 9.002815, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer C + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 0.0, + 1.0, + 29.0, + 20.000001907348633, + 27.0, + 14.00008773803711, + 1.0, + 0.0, + 0.0, + 7.011227130889893, + 9.250298500061035, + 9.002815246582031, + 20.000001907348633, + 11.000703811645508, + 18.0000057220459, + 5.044394016265869, + 
9.002815246582031, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Buffer B + assert_allclose( + actual=f[93:111], + desired=[ + 1.0, + 0.0, + 0.0, + 29.0, + 20.000001907348633, + 19.000001907348633, + 14.00008773803711, + 1.0, + 0.0, + 0.0, + 1.0, + 3.700439691543579, + 4.087462902069092, + 25.0, + 16.000022888183594, + 15.000043869018555, + 10.001408576965332, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[ + 0.7097842693328857, + 0.7408391237258911, + 0.8750449419021606, + 0.9449487924575806, + 1.0148526430130005, + 1.0847564935684204, + 1.113688349723816, + 1.1394684314727783, + 1.2119636535644531, + 1.2971993684768677, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 20.000001907348633, + 18.0000057220459, + 1.0, + 27.0, + 27.0, + 2.5849626064300537, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + if __name__ == "__main__": - pass + test_cpu_matmul() + # test_cpu_fusion() + # test_gpu() From 4f782cf8f2c1e944a002eafc5c177bc30fd5e438 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 28 Nov 2021 18:01:05 -0800 Subject: [PATCH 4/5] bugfixes --- .../feature_extractor/per_store_feature.cc | 61 +- ...ule_feature_extractor_per_store_feature.py | 1187 ++++++++++++++++- 2 files changed, 1223 insertions(+), 25 deletions(-) diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index b69068f5c7..5c0b44f553 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "../utils.h" @@ -208,13 +209,32 @@ runtime::NDArray AsNDArray(const std::vector>& src) { namespace transform { -Pass SimplifyConstMatrix() { +Pass SimplifyForFeatureExtraction() { class Simplifier : private StmtExprMutator { public: static Stmt Run(Stmt stmt) { return Simplifier()(std::move(stmt)); } private: - PrimExpr VisitExpr_(const SelectNode* node) { return make_const(node->dtype, 1.0); } + PrimExpr VisitExpr_(const SelectNode* node) final { return make_const(node->dtype, 1.0); } + + PrimExpr VisitExpr_(const VarNode* var) final { + if (unit_vars_.count(GetRef(var))) { + return make_const(var->dtype, 0.0); + } + return GetRef(var); + } + + Stmt VisitStmt_(const ForNode* loop) final { + if (is_zero(loop->min) && is_one(loop->extent) && loop->kind == ForKind::kSerial && + loop->annotations.empty()) { + unit_vars_.insert(loop->loop_var); + return VisitStmt(loop->body); + } else { + return StmtExprMutator::VisitStmt_(loop); + } + } + + std::unordered_set unit_vars_; }; auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { PrimFuncNode* n = f.CopyOnWrite(); @@ -226,7 +246,7 @@ Pass SimplifyConstMatrix() { Sequential PassListForPerStoreFeature() { return Sequential({ - tir::transform::SimplifyConstMatrix(), + tir::transform::SimplifyForFeatureExtraction(), tir::transform::LowerCrossThreadReduction(), tir::transform::LowerInitBlock(), tir::transform::PlanAndUpdateBufferAllocationLocation(), @@ -234,6 +254,7 @@ Sequential PassListForPerStoreFeature() { tir::transform::UnifyThreadBinding(), tir::transform::CompactBufferAllocation(), 
        tir::transform::LowerMatchBuffer(),
+       tir::transform::Simplify(),
   });
 }
@@ -305,6 +326,7 @@ struct LoopNest {
     if (const int64_t* extent = GetLoopIntExtent(loop)) {
       this->prod /= *extent;
     }
+    this->loops.pop_back();
   }
 };
@@ -509,7 +531,7 @@ Feature::ArithOps::ArithOps(const BufferStoreNode* store, int64_t prod_loop_exte
 Feature::ForKindFeature::ForKindFeature(const ForVec& loops) {
   if (loops.empty()) {
     this->num = 0;
-    this->prod = 1;
+    this->prod = 0;
     this->len = 0;
     this->pos = ForKindFeature::Pos::kPosNone;
   } else {
@@ -627,7 +649,8 @@ struct Feature {
   }
 
   explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest,
-                   int64_t cache_line_bytes, IntVec* for_touched_bytes, arith::Analyzer* analyzer);
+                   int64_t cache_line_bytes, IntVec* for_touched_bytes,
+                   ForBufferMap<IntVec>* buffer_touched_under_loop, arith::Analyzer* analyzer);
 
   void Init(const BufferStoreNode* store, int n_loops);
@@ -868,13 +891,13 @@ void Feature::SubFeature::SetFeature(const LoopNest& loop_nest, int64_t cache_li
 }
 
 Feature::Feature(const BufferStoreNode* store, const LoopNest& loop_nest, int64_t cache_line_bytes,
-                 IntVec* for_touched_bytes, arith::Analyzer* analyzer) {
+                 IntVec* for_touched_bytes, ForBufferMap<IntVec>* buffer_touched_under_loop,
+                 arith::Analyzer* analyzer) {
   int n_loops = loop_nest.loops.size();
   // Step 0. Initialize data structures
   this->Init(store, n_loops);
   // Step 1. Calculate region-related feature
-  ForBufferMap<IntVec> buffer_touched_under_loop;
-  this->SetRegion(loop_nest, for_touched_bytes, &buffer_touched_under_loop, analyzer);
+  this->SetRegion(loop_nest, for_touched_bytes, buffer_touched_under_loop, analyzer);
   // Step 2. Calculate stride-related feature
   for (auto& feature : sub_features) {
     feature.SetStride(loop_nest);
@@ -889,7 +912,7 @@ Feature::Feature(const BufferStoreNode* store, const LoopNest& loop_nest, int64_
     }
   }
   for (auto& feature : sub_features) {
-    feature.SetReuse(loop_nest, top_loop_touch_bytes, buffer_touched_under_loop);
+    feature.SetReuse(loop_nest, top_loop_touch_bytes, *buffer_touched_under_loop);
   }
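  // Shape sketch (illustrative): `buffer_touched_under_loop` is a ForBufferMap<IntVec>,
  // i.e. loop -> buffer -> [touched numel per access]. SetRegion fills it in Step 1 and
  // SetReuse consumes it in Step 3 to turn element counts into reuse distances in
  // iterations and bytes; after this patch the map lives in PerStoreFeatureCollector
  // and is passed down, rather than being rebuilt as a local in Feature::Feature.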
  // Step 4. Calculate rest of the features
   for (auto& feature : sub_features) {
@@ -1091,7 +1114,9 @@ class PerStoreFeatureCollector : private StmtVisitor {
   }
 
   void VisitStmt_(const BufferStoreNode* store) final {
-    LOG(INFO) << "Visit BufferStore:\n" << GetRef<BufferStore>(store);
+    if (store->value->IsInstance<IntImmNode>() || store->value->IsInstance<FloatImmNode>()) {
+      return;
+    }
     const BufferNode* buffer = store->buffer.get();
     Feature& feature = buffer_features_[buffer];
     if (feature.buffer == nullptr) {
@@ -1100,7 +1125,8 @@
     }
     feature.group1 = std::make_unique<group1::Feature>(store, loop_nest_, is_gpu_);
     feature.group2 = std::make_unique<group2::Feature>(store, loop_nest_, cache_line_bytes_,
-                                                       &for_touched_bytes_, &analyzer_);
+                                                       &for_touched_bytes_,  //
+                                                       &buffer_touched_under_loop_, &analyzer_);
     feature.group3 = std::make_unique<group3::Feature>(arith_intensity_curve_num_samples_,  //
                                                        loop_nest_, for_touched_bytes_,
                                                        feature.group1->arith_ops);
@@ -1131,6 +1157,7 @@
   arith::Analyzer analyzer_;
   LoopNest loop_nest_ = {};
   IntVec for_touched_bytes_ = {};
+  ForBufferMap<IntVec> buffer_touched_under_loop_ = {};
   std::unordered_map<const BufferNode*, Feature> buffer_features_ = {};
 };
@@ -1157,7 +1184,6 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
   void ExtractSingle(IRModule mod, bool is_gpu, std::vector<std::vector<double>>* results) {
     static transform::Sequential passes = tir::transform::PassListForPerStoreFeature();
     mod = passes(std::move(mod));
-    LOG(INFO) << "mod =\n" << tir::AsTVMScript(mod, "T");
     std::vector<tir::Feature> features = tir::PerStoreFeatureCollector::Collect(
         is_gpu, this->cache_line_bytes, this->arith_intensity_curve_num_samples, mod);
     int n_features = features.size();
@@ -1172,18 +1198,7 @@
       feature.group4->Export(&result, feature.group5->outer_prod);
       feature.group5->Export(&result);
       ICHECK_EQ(static_cast<int>(result.size()), feature_vector_length);
-      DebugFeature(feature);
-    }
-  }
-
-  void DebugFeature(const tir::Feature& feature) {
-    const tir::group2::Feature& f2 = *feature.group2;
-    std::ostringstream os;
-    os << "Feature(" << feature.buffer->name << "):";
-    for (const tir::group2::Feature::SubFeature& sub_f : f2.sub_features) {
-      os << " " << sub_f.buffer->name;
     }
-    LOG(INFO) << os.str();
   }
 
   Array<runtime::NDArray> ExtractFrom(const TuneContext& tune_context,
diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py
index 163fd81326..210bc01499 100644
--- a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py
+++ b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py
@@ -21,6 +21,7 @@
 import tvm
 from tvm import meta_schedule as ms, te, tir
 from tvm.meta_schedule.testing import te_workload
+from tvm.script import tir as T
 
 N_FEATURES = 164
 
@@ -347,7 +348,1189 @@ def _create_schedule():
     )
 
 
+def test_cpu_fusion():
+    # pylint: disable=all
+    @T.prim_func
+    def func(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, [64, 32], dtype="float32")
+        B = T.match_buffer(b, [64, 32], dtype="float32")
+        C = T.match_buffer(c, [64, 32], dtype="float32")
+        for i, j in T.grid(64, 32):  # type: ignore
+            with T.block():
+                T.reads([A[i, j], B[i, j]])  # type: ignore
+                T.writes([B[i, j], C[i, j]])  # type: ignore
+                with T.block("B"):
+                    T.reads([A[i, j]])  # type: ignore
+                    T.writes([B[i, j]])  # type: ignore
+                    B[i, j] = A[i, j]  # type: ignore
+                with T.block("C"):
+                    T.reads([B[i, j]])  # type: ignore
+                    T.writes([C[i, j]])  # 
type: ignore + C[i, j] = B[i, j] # type: ignore + + # pylint: enable=all + + def _create_schedule(): + return tir.Schedule(func, debug_mask="all") + + extractor = ms.feature_extractor.PerStoreFeature() + (feature,) = extractor.extract_from( + _make_context(tvm.target.Target("llvm")), + candidates=[_make_candidate(_create_schedule)], + ) + feature = feature.numpy() + assert feature.shape == (2, N_FEATURES) + ## Features for BufferStore(B) + f = feature[0] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + # fmt: off + desired=[0.0] * 16, + # fmt: on + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer A + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 14.00008773803711, + 14.00008773803711, + 8.005624771118164, + 8.005624771118164, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer B + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 1.0, + 0.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 14.00008773803711, + 14.00008773803711, + 8.005624771118164, + 8.005624771118164, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 13.000176, + 11.000703811645508, + 1.0, + 11.000703811645508, + 11.000703811645508, + 1.5849624872207642, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + ## Features for BufferStore(C) + f = feature[1] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + # fmt: off + desired=[0.0] * 16, + # fmt: on + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 
2.1: Buffer B + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 0.0, + 1.0, + 0.0, + 1.0, + 4.087462902069092, + 1.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer C + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 1.0, + 0.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 14.00008773803711, + 14.00008773803711, + 8.005624771118164, + 8.005624771118164, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 13.000176429748535, + 11.000703811645508, + 1.0, + 11.000703811645508, + 11.000703811645508, + 1.5849624872207642, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + + +def test_gpu(): + def _create_schedule(): + n = m = k = 512 + func = te.create_prim_func(te_workload.matmul(n=n, m=m, k=k)) + sch = tir.Schedule(func, debug_mask="all") + c = sch.get_block("C") + c_local = sch.cache_write(c, 0, "local") + i, j, k = sch.get_loops(c) + # pylint: disable=invalid-name + i0, i1, i2, i3, i4 = sch.split(i, factors=[None, 1, 16, 32, 1]) # outer: 1 + j0, j1, j2, j3, j4 = sch.split(j, factors=[None, 4, 1, 1, 16]) # outer: 8 + k0, k1, k2 = sch.split(k, factors=[None, 1, 2]) # outer: 256 + # pylint: enable=invalid-name + # fmt: off + sch.reorder( + i0, j0, # S + i1, j1, # S + i2, j2, # S + k0, # R + k1, # R + i3, j3, # S + k2, # R + i4, j4, # S + ) + # fmt: on + # thread binding + i0_j0 = sch.fuse(i0, j0) + i1_j1 = sch.fuse(i1, j1) + i2_j2 = sch.fuse(i2, j2) + sch.bind(i0_j0, "blockIdx.x") + sch.bind(i1_j1, "vthread.x") + sch.bind(i2_j2, "threadIdx.x") + # fusion + sch.reverse_compute_at(c_local, i2_j2) + # cache read 'A' + a_shared = sch.cache_read(c, 1, "shared") + sch.compute_at(a_shared, k0) + _, _, _, _, a_i, a_j = sch.get_loops(a_shared) + a_ij = sch.fuse(a_i, a_j) + _, a_j = sch.split(a_ij, factors=[None, 16]) # outer: 64 + sch.bind(a_j, "threadIdx.x") + # cache read 'B' + b_shared = sch.cache_read(c, 2, "shared") + sch.compute_at(b_shared, k0) + _, _, _, _, b_i, b_j = sch.get_loops(b_shared) + b_ij = sch.fuse(b_i, b_j) + _, b_j = sch.split(b_ij, factors=[None, 16]) # outer: 8 + sch.bind(b_j, "threadIdx.x") + # auto unroll + sch.annotate(i0_j0, "pragma_auto_unroll_max_step", tir.IntImm("int32", 1024)) + sch.annotate(i0_j0, "pragma_unroll_explicit", tir.IntImm("int32", 1)) + return sch + + extractor = ms.feature_extractor.PerStoreFeature() + (feature,) = extractor.extract_from( + _make_context(tvm.target.Target("cuda")), + candidates=[_make_candidate(_create_schedule)], + ) + feature = feature.numpy() + assert feature.shape == (4, N_FEATURES) + ### Check feature[0]: BufferStore(A_shared) <= A[...] 
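+    # How to read the desired values below (an illustrative decode, not an extra
+    # assertion): most features are stored as slog(x) = log2(|x| + 1), so a raw
+    # count c shows up as log2(c + 1); e.g. 24.000000085991324 below decodes to
+    # 2**24 ops, since log2(2**24 + 1) ~= 24.000000085991324.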
+ f = feature[0] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 24.000000085991324, + 24.000000085991324, + 24.000000085991324, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer A + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 25.000000042995662, + 20.000001375860553, + 23.00000017198264, + 14.000088052430122, + 1.0, + 0.0, + 0.0, + 18.00000550343433, + 20.00562591970089, + 2.321928094887362, + 23.00000017198264, + 18.00000550343433, + 21.000000687930438, + 12.0003521774803, + 12.0003521774803, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer A.shared + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 1.0, + 0.0, + 25.000000042995662, + 12.0003521774803, + 23.00000017198264, + 9.002815015607053, + 1.0, + 0.0, + 0.0, + 6.022367813028454, + 11.98049663618346, + 8.005624549193879, + 17.000011006847668, + 4.087462841250339, + 15.000044026886828, + 1.584962500721156, + 4.087462841250339, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 12.0003521774803, + 27.000000010748916, + 17.000011006847668, + 6.022367813028454, + 23.00000017198264, + 2.584962500721156, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + ### Check feature[1]: BufferStore(B_shared) <= B[...] 
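+    # The fractional 0.044394119358453436 below is a "divided" feature: it decodes
+    # as slog(unique_lines / reuse_ct) = slog(8 / 256) = log2(1 + 1/32)
+    # ~= 0.0443941 (an illustrative decode of unique_lines_d_reuse_ct).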
+ f = feature[1] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 22.00000034396526, + 22.00000034396526, + 21.000000687930438, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer B + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 22.00000034396526, + 20.000001375860553, + 20.000001375860553, + 14.000088052430122, + 1.0, + 0.0, + 0.0, + 15.000044026886828, + 20.17555076886471, + 2.321928094887362, + 20.000001375860553, + 18.00000550343433, + 18.00000550343433, + 12.0003521774803, + 4.087462841250339, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer B.shared + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 1.0, + 0.0, + 22.00000034396526, + 9.002815015607053, + 20.000001375860553, + 3.169925001442312, + 1.0, + 0.0, + 0.0, + 3.169925001442312, + 10.001408194392809, + 8.005624549193879, + 14.000088052430122, + 1.584962500721156, + 12.0003521774803, + 0.044394119358453436, + 4.087462841250339, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 9.002815015607053, + 24.000000085991324, + 17.000011006847668, + 3.169925001442312, + 20.000001375860553, + 2.584962500721156, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + ### Check feature[2]: BufferStore(C_local) <= C_local[...] + A_shared[...] * B_shared[...] 
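+    # Unlike the shared-memory copies above, this store carries the matmul FMAs, so
+    # its arith_intensity_curve below is nonzero: group3 samples the log-scaled
+    # compute-to-memory-traffic ratio at arith_intensity_curve_num_samples = 10
+    # evenly spaced points along the loop nest (an informal reading of the C++
+    # code above, not a spec).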
+ f = feature[2] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + desired=[ + 0.0, + 27.000000010748916, + 27.000000010748916, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 28.000000005374456, + 28.000000005374456, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer B.shared + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 29.00000000268723, + 9.002815015607053, + 23.00000017198264, + 3.169925001442312, + 1.0, + 0.0, + 0.0, + 5.044394119358453, + 7.651051691178929, + 5.044394119358453, + 24.000000085991324, + 4.087462841250339, + 18.00000550343433, + 0.32192809488736235, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer C.local + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 0.0, + 1.0, + 29.00000000268723, + 11.000704269011246, + 23.00000017198264, + 5.044394119358453, + 1.0, + 0.0, + 0.0, + 4.087462841250339, + 7.05528243550119, + 1.584962500721156, + 28.000000005374456, + 10.001408194392809, + 22.00000034396526, + 4.087462841250339, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Buffer A.shared + assert_allclose( + actual=f[93:111], + desired=[ + 1.0, + 0.0, + 0.0, + 29.00000000268723, + 12.0003521774803, + 19.00000275171979, + 9.002815015607053, + 1.0, + 0.0, + 0.0, + 1.0, + 3.700439718141092, + 4.087462841250339, + 25.000000042995662, + 8.005624549193879, + 15.000044026886828, + 5.044394119358453, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[ + 0.7097842504665767, + 0.7548801745187567, + 0.8775907547541741, + 0.9957389916154509, + 1.2446737395193135, + 1.493608487423176, + 1.7093103019954263, + 1.8031580276850985, + 1.9841832691827785, + 2.204648076869754, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 11.000704269011246, + 18.00000550343433, + 9.002815015607053, + 18.00000550343433, + 27.000000010748916, + 3.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + ### Check feature[3]: BufferStore(C) <= C_local[...] 
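+    # Group 4 & 5 decode for this store (illustrative): alloc_size
+    # 20.000001375860553 is slog(512 * 512 * 4), the bytes allocated for C;
+    # outer_prod 18.00000550343433 is slog(2**18), the product of the extents of
+    # the loops around the store; and num_loops 2.584962500721156 is exactly
+    # slog(5), i.e. five enclosing loops.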
+ f = feature[3] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer C + assert_allclose( + actual=f[57:75], + desired=[ + 0.0, + 1.0, + 0.0, + 20.000001375860553, + 20.000001375860553, + 14.000088052430122, + 14.000088052430122, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 21.000000687930438, + 21.000000687930438, + 15.000044026886828, + 15.000044026886828, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer C.local + assert_allclose( + actual=f[75:93], + desired=[ + 1.0, + 0.0, + 0.0, + 20.000001375860553, + 11.000704269011246, + 14.000088052430122, + 5.044394119358453, + 1.0, + 0.0, + 0.0, + 9.002815015607053, + 12.0003521774803, + 4.087462841250339, + 16.00002201361136, + 7.011227255423254, + 10.001408194392809, + 1.584962500721156, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 20.000001375860553, + 18.00000550343433, + 1.0, + 18.00000550343433, + 18.00000550343433, + 2.584962500721156, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + + if __name__ == "__main__": test_cpu_matmul() - # test_cpu_fusion() - # test_gpu() + test_cpu_fusion() + test_gpu() From f8f195d281877287c4e2c81d6cb30a47feec644e Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 28 Nov 2021 19:15:45 -0800 Subject: [PATCH 5/5] docs --- .../feature_extractor/per_store_feature.py | 27 ++- .../feature_extractor/per_store_feature.cc | 189 +++++++++++------- src/tir/schedule/utils.h | 13 -- 3 files changed, 140 insertions(+), 89 deletions(-) diff --git a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py index a5283d2fce..30572ed5b9 100644 --- a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py +++ 
b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
@@ -23,14 +23,39 @@
 from .feature_extractor import FeatureExtractor
 
+# /*!
+# * \brief Create a feature extractor that extracts features from each BufferStore
+# * \param buffers_per_store The number of buffers in each BufferStore; Pad or truncate if
+# * necessary.
+# * \param arith_intensity_curve_num_samples The number of samples used in the arithmetic intensity
+# * curve.
+# * \param cache_line_bytes The number of bytes in a cache line.
+# * \return The feature extractor created.
+# */
+
+
 @register_object("meta_schedule.PerStoreFeature")
 class PerStoreFeature(FeatureExtractor):
-    """PerStoreFeature extracts one feature vector per BufferStoreNode"""
+    """PerStoreFeature extracts one feature vector per BufferStoreNode
+
+    Parameters
+    ----------
+    buffers_per_store : int
+        The number of buffers in each BufferStore; Pad or truncate if necessary.
+    arith_intensity_curve_num_samples : int
+        The number of samples used in the arithmetic intensity curve.
+    cache_line_bytes : int
+        The number of bytes in a cache line.
+    """
 
     buffers_per_store: int
+    """The number of buffers in each BufferStore; Pad or truncate if necessary."""
     arith_intensity_curve_num_samples: int  # pylint: disable=invalid-name
+    """The number of samples used in the arithmetic intensity curve."""
     cache_line_bytes: int
+    """The number of bytes in a cache line."""
     feature_vector_length: int
+    """Length of the feature vector."""
 
     def __init__(
         self,
diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc
index 5c0b44f553..2081976d2c 100644
--- a/src/meta_schedule/feature_extractor/per_store_feature.cc
+++ b/src/meta_schedule/feature_extractor/per_store_feature.cc
@@ -18,6 +18,7 @@
  */
 #include
+#include
 #include
 #include
 #include
@@ -26,51 +27,35 @@
 #include "../utils.h"
 
-namespace std {
-
-std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& vec) {
-  tvm::tir::PrintVector(vec, os, [&os](int64_t i) { os << i; });
-  return os;
-}
-
-std::ostream& operator<<(std::ostream& os, const std::vector<double>& vec) {
-  tvm::tir::PrintVector(vec, os, [&os](double i) { os << i; });
-  return os;
-}
-
-std::ostream& operator<<(std::ostream& os, const std::vector<tvm::PrimExpr>& vec) {
-  tvm::tir::PrintVector(vec, os, [&os](const tvm::PrimExpr& i) { os << i; });
-  return os;
-}
-
-std::ostream& operator<<(std::ostream& os, const std::vector<std::vector<int64_t>>& vec) {
-  tvm::tir::PrintVector(vec, os, [&os](const std::vector<int64_t>& i) { os << i; });
-  return os;
-}
-
-}  // namespace std
-
 namespace tvm {
 namespace tir {
 
 using support::NDIntSet;
 
+/*! \brief Type for multi-dimensional index */
 using MultiIndex = std::vector<PrimExpr>;
+/*! \brief Vector of int64_t */
 using IntVec = std::vector<int64_t>;
+/*! \brief Vector of for loops */
 using ForVec = std::vector<const ForNode*>;
 
+/*!
+ * \brief An unordered_map for (for, buffer) => V
+ * \tparam V The value type
+ */
 template <class V>
 using ForBufferMap = std::unordered_map<const ForNode*, std::unordered_map<const BufferNode*, V>>;
 
-inline double slog(double x) {
-  if (x < 0) {
-    x = -x;
-  }
-  return std::log2(x + 1);
-}
+/*! \brief Given x, compute log2(|x| + 1) */
+inline double slog(double x) { return x >= 0 ? std::log2(x + 1) : std::log2(-x + 1); }
 
 namespace utils {
 
+/*!
+ * \brief Given a loop, return its `pragma_auto_unroll_max_step` annotation if it exists
+ * \param loop The loop to be checked
+ * \return The value of `pragma_auto_unroll_max_step` if it exists, or -1 if it does not exist
+ */
 int64_t GetPragmaAutoUnroll(const ForNode* loop) {
   if (Optional<IntImm> auto_unroll = GetAnn<IntImm>(loop, tir::attr::pragma_auto_unroll_max_step)) {
     return auto_unroll.value()->value;
@@ -78,16 +63,15 @@ int64_t GetPragmaAutoUnroll(const ForNode* loop) {
   return -1;
 }
 
-int64_t ProdLoopExtent(const ForVec& loops) {
-  int64_t prod = 1;
-  for (const ForNode* loop : loops) {
-    if (const int64_t* extent = GetLoopIntExtent(loop)) {
-      prod *= *extent;
-    }
-  }
-  return prod;
-}
-
+/*!
+ * \brief Given a list of loops, return the extent of the first loop if the list is non-empty
+ * and the first loop has a constant extent; otherwise return the given default value
+ * \param loops The list of loops to be checked
+ * \param default_value The default value to be returned if the list is empty or the first loop
+ * does not have a constant extent
+ * \return The extent of the first loop, or `default_value` if the list is empty or the first
+ * loop does not have a constant extent
+ */
 int64_t FirstLoopExtent(const ForVec& loops, int64_t default_value) {
   if (!loops.empty()) {
     if (const int64_t* extent = GetLoopIntExtent(loops[0])) {
@@ -97,8 +81,16 @@ int64_t FirstLoopExtent(const ForVec& loops, int64_t default_value) {
   return default_value;
 }
 
-IntVec UnionAndGetRelaxedSize(const std::vector<MultiIndex>& multi_indices, int64_t* numel,
-                              arith::Analyzer* analyzer) {
+/*!
+ * \brief Relax each multi-indexing pattern according to the domains bound in the analyzer,
+ * and then union them into a single region
+ * \param multi_indices The list of multi-index patterns to be relaxed
+ * \param numel The size of the single region after the union
+ * \param analyzer The analyzer that contains the domain information
+ * \return The relaxed and unioned region
+ */
+IntVec RelaxAndUnion(const std::vector<MultiIndex>& multi_indices, int64_t* numel,
+                     arith::Analyzer* analyzer) {
   if (multi_indices.empty()) {
     return {};
   }
@@ -119,6 +111,13 @@
   return access_shape;
 }
 
+/*!
+ * \brief Given a list of multi-index patterns, return the minimal stride of a variable on it
+ * \param multi_indices The list of multi-index patterns
+ * \param buffer_stride The stride of the buffer
+ * \param var The variable to be checked
+ * \return The minimal stride of the variable on the multi-index patterns
+ */
 int64_t GetVarStride(const std::vector<MultiIndex>& multi_indices, const IntVec& buffer_stride,
                      const Var& var) {
   class CoefficientExtractor : private ExprVisitor {
@@ -188,13 +187,18 @@ int64_t GetVarStride(const std::vector<MultiIndex>& multi_indices, const IntVec&
   return (result == kNotFound) ? 0 : result;
 }
 
+/*!
+ * \brief Converts a 2-dimensional STL vector to a TVM NDArray
+ * \param src The source 2-dimensional STL vector
+ * \return The converted TVM NDArray
+ */
 runtime::NDArray AsNDArray(const std::vector<std::vector<double>>& src) {
   ICHECK(!src.empty());
   int n = src.size();
   int m = src[0].size();
   runtime::NDArray tgt = runtime::NDArray::Empty(
       /*shape=*/{n, m},
-      /*dtype=*/DLDataType{kDLFloat, 64, 1},  //
+      /*dtype=*/DLDataType{kDLFloat, 64, 1},
       /*ctx=*/DLDevice{kDLCPU, 0});
   double* data = static_cast<double*>(tgt->data);
   for (const std::vector<double>& row : src) {
@@ -209,6 +213,10 @@ runtime::NDArray AsNDArray(const std::vector<std::vector<double>>& src) {
 
 namespace transform {
 
+/*!
+ * \brief Create a pass that simplifies the IR for feature extraction + * \return The pass created + */ Pass SimplifyForFeatureExtraction() { class Simplifier : private StmtExprMutator { public: @@ -244,6 +252,10 @@ Pass SimplifyForFeatureExtraction() { return CreatePrimFuncPass(pass_func, 0, "tir.SimplifyConstMatrix", {}); } +/*! + * \brief Create a list of passes that preprocesses the IR for feature extraction + * \return The list of passes created + */ Sequential PassListForPerStoreFeature() { return Sequential({ tir::transform::SimplifyForFeatureExtraction(), @@ -260,28 +272,35 @@ Sequential PassListForPerStoreFeature() { } // namespace transform +/*! \brief A data structure managing loop nests */ struct LoopNest { - int64_t prod = 1; - ForVec loops; - IntVec auto_unroll; - ForVec parallel; - ForVec vectorize; - ForVec unroll; - ForVec blockIdx_x; - ForVec blockIdx_y; - ForVec blockIdx_z; - ForVec threadIdx_x; - ForVec threadIdx_y; - ForVec threadIdx_z; - ForVec vthread; - - ForVec* Push(const ForNode* loop, int64_t auto_unroll_attr) { + int64_t prod = 1; // The product of the extents of all the loops + ForVec loops; // All the loops + IntVec auto_unroll; // The loops with auto unroll pragma + ForVec parallel; // The loops whose ForKind are kParallel + ForVec vectorize; // The loops whose ForKind are kVectorized + ForVec unroll; // The loops whose ForKind are kUnrolled + ForVec blockIdx_x; // The loops whose ForKind are kThreadBinding to blockIdx.x + ForVec blockIdx_y; // The loops whose ForKind are kThreadBinding to blockIdx.y + ForVec blockIdx_z; // The loops whose ForKind are kThreadBinding to blockIdx.z + ForVec threadIdx_x; // The loops whose ForKind are kThreadBinding to threadIdx.x + ForVec threadIdx_y; // The loops whose ForKind are kThreadBinding to threadIdx.y + ForVec threadIdx_z; // The loops whose ForKind are kThreadBinding to threadIdx.z + ForVec vthread; // The loops whose ForKind are kThreadBinding to vthread.* + + /*! + * \brief Push a new loop into the loop nest + * \param loop The loop to be pushed + * \param auto_unroll_attr The auto unroll attribute of the loop + * \return A list of for loops that the loop is bound to + */ + ForVec* Push(const ForNode* loop, int64_t* auto_unroll_attr) { if (const int64_t* extent = GetLoopIntExtent(loop)) { this->prod *= *extent; } this->loops.push_back(loop); - if (auto_unroll_attr > 0) { - this->auto_unroll.push_back(auto_unroll_attr); + if ((*auto_unroll_attr = utils::GetPragmaAutoUnroll(loop)) > 0) { + this->auto_unroll.push_back(*auto_unroll_attr); } ForVec* ref_loops = nullptr; if (loop->kind == ForKind::kParallel) { @@ -316,6 +335,12 @@ struct LoopNest { return ref_loops; } + /*! + * \brief Pop the last loop from the loop nest + * \param loop The loop to be popped + * \param ref_loops The list of for loops that the loop is bound to + * \param auto_unroll_attr The auto unroll attribute of the loop + */ void Pop(const ForNode* loop, ForVec* ref_loops, int auto_unroll_attr) { if (ref_loops) { ref_loops->pop_back(); @@ -334,7 +359,9 @@ struct LoopNest { namespace group1 { +/*! \brief Group 1 features */ struct Feature { + /*! \brief Arithmetic features */ struct ArithOps { // Float-point arithmetic features int64_t float_mad = 0; // The number of float MAD (Multiply–add) ops @@ -373,6 +400,7 @@ struct Feature { } }; + /*! 
\brief Loop binding features */ struct ForKindFeature { enum class Pos : int { kPosNone = 0, // Does not have this kind of annotation @@ -412,11 +440,11 @@ struct Feature { } }; - ArithOps arith_ops; - ForKindFeature vectorize; - ForKindFeature unroll; - ForKindFeature parallel; - bool is_gpu = false; + ArithOps arith_ops; // Arithmetic features + ForKindFeature vectorize; // Loop binding features: kVectorize + ForKindFeature unroll; // Loop binding features: kUnroll + ForKindFeature parallel; // Loop binding features: kParallel + bool is_gpu = false; // If the program is running on GPU int64_t blockIdx_x_len = 1; // The length of blockIdx.x int64_t blockIdx_y_len = 1; // The length of blockIdx.y int64_t blockIdx_z_len = 1; // The length of blockIdx.z @@ -537,9 +565,14 @@ Feature::ForKindFeature::ForKindFeature(const ForVec& loops) { } else { const int64_t* last_loop_extent = GetLoopIntExtent(loops.back()); this->num = loops.size(); - this->prod = utils::ProdLoopExtent(loops); this->len = last_loop_extent ? *last_loop_extent : 1; this->pos = ForKindFeature::Pos::kPosMixed; + int64_t& prod = this->prod = 1; + for (const ForNode* loop : loops) { + if (const int64_t* extent = GetLoopIntExtent(loop)) { + prod *= *extent; + } + } } } @@ -547,6 +580,7 @@ Feature::ForKindFeature::ForKindFeature(const ForVec& loops) { namespace group2 { +/*! \brief Group 2 features */ struct Feature { enum class AccessType : int { kRead = 0, // The buffer is read but not written @@ -733,7 +767,7 @@ void Feature::SetRegion(const LoopNest& loop_nest, IntVec* for_touched_bytes, // Note: `feature.access_shape` for `i == 0` is the only one preserved, // while others are discarded int64_t numel = 1; - feature.access_shape = utils::UnionAndGetRelaxedSize(feature.multi_indices, &numel, analyzer); + feature.access_shape = utils::RelaxAndUnion(feature.multi_indices, &numel, analyzer); feature.loop_accessed_numel[i][buffer] = numel; touched_bytes += numel * buffer->dtype.bytes(); (*buffer_touched_under_loop)[loop][buffer].push_back(numel); @@ -934,6 +968,7 @@ Feature::Feature(const BufferStoreNode* store, const LoopNest& loop_nest, int64_ namespace group3 { +/*! \brief Group 3 feature */ struct Feature { std::vector arith_intensity_curve; @@ -1001,6 +1036,7 @@ struct Feature { namespace group4 { +/*! \brief Group 4 feature */ struct Feature { int64_t alloc_size = 0; // The size of allocated buffer in bytes int64_t alloc_prod = 0; // alloc_outer_prod * alloc_inner_prod @@ -1035,6 +1071,7 @@ struct Feature { namespace group5 { +/*! \brief Group 5 feature */ struct Feature { int64_t outer_prod; // The product of lengths of outer loops int num_loops; // The number of outer loops @@ -1060,6 +1097,7 @@ struct Feature { } // namespace group5 +/*! \brief The feature extracted */ struct Feature { const BufferNode* buffer = nullptr; int buffer_order = -1; @@ -1072,6 +1110,7 @@ struct Feature { bool operator<(const Feature& other) const { return buffer_order < other.buffer_order; } }; +/*! 
\brief The main feature extractor */
 class PerStoreFeatureCollector : private StmtVisitor {
  public:
  static std::vector<Feature> Collect(bool is_gpu, int64_t cache_line_bytes,
@@ -1107,8 +1146,8 @@
  private:
   void VisitStmt_(const ForNode* loop) final {
-    int64_t auto_unroll = utils::GetPragmaAutoUnroll(loop);
-    ForVec* for_vec = loop_nest_.Push(loop, auto_unroll);
+    int64_t auto_unroll;
+    ForVec* for_vec = loop_nest_.Push(loop, &auto_unroll);
     StmtVisitor::VisitStmt_(loop);
     loop_nest_.Pop(loop, for_vec, auto_unroll);
   }
@@ -1124,12 +1163,12 @@
     feature.group1 = std::make_unique<group1::Feature>(store, loop_nest_, is_gpu_);
-    feature.group2 = std::make_unique<group2::Feature>(store, loop_nest_, cache_line_bytes_,
-                                                       &for_touched_bytes_,  //
-                                                       &buffer_touched_under_loop_, &analyzer_);
-    feature.group3 = std::make_unique<group3::Feature>(arith_intensity_curve_num_samples_,  //
-                                                       loop_nest_, for_touched_bytes_,
-                                                       feature.group1->arith_ops);
+    feature.group2 =
+        std::make_unique<group2::Feature>(store, loop_nest_, cache_line_bytes_, &for_touched_bytes_,
+                                          &buffer_touched_under_loop_, &analyzer_);
+    feature.group3 =
+        std::make_unique<group3::Feature>(arith_intensity_curve_num_samples_, loop_nest_,
+                                          for_touched_bytes_, feature.group1->arith_ops);
     feature.group5 = std::make_unique<group5::Feature>(loop_nest_);
   }
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index 2c2a7666f5..091266ee38 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -105,19 +105,6 @@ namespace tir {
       << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \
       << "`, but gets: " << (From.defined() ? From->GetTypeKey() : "None")
 
-template <class T, class FPrint>
-void PrintVector(const std::vector<T>& vec, std::ostream& os, FPrint print) {
-  int n = vec.size();
-  os << "[";
-  for (int i = 0; i < n; ++i) {
-    if (i != 0) {
-      os << ", ";
-    }
-    print(vec[i]);
-  }
-  os << "]";
-}
-
 /*!
  * \brief Convert an array of loop StmtSRefs to an array of loops
  * \param loop_srefs The loop StmtSRefs to be converted