diff --git a/gallery/how_to/extend_tvm/bring_your_own_datatypes.py b/gallery/how_to/extend_tvm/bring_your_own_datatypes.py index 1cf556ddd0..0182456099 100644 --- a/gallery/how_to/extend_tvm/bring_your_own_datatypes.py +++ b/gallery/how_to/extend_tvm/bring_your_own_datatypes.py @@ -313,7 +313,7 @@ def convert_ndarray(dst_dtype, array): print(str(e).split("\n")[-1]) ###################################################################### -# When we attempt to run the model, we get a familiar error telling us that more funcions need to be registerd for myfloat. +# When we attempt to run the model, we get a familiar error telling us that more functions need to be registerd for myfloat. # # Because this is a neural network, many more operations are required. # Here, we register all the needed functions: diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 6c72cbeafd..1671d8fdc8 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -346,6 +346,8 @@ Array> SubspaceDivide(const Array& bindings, const Array& sub_iters, const PrimExpr& predicate, bool require_bijective, arith::Analyzer* analyzer); +PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr); + } // namespace arith } // namespace tvm #endif // TVM_ARITH_ITER_AFFINE_MAP_H_ diff --git a/include/tvm/auto_scheduler/cost_model.h b/include/tvm/auto_scheduler/cost_model.h index f7a27895a7..a52c6797b6 100755 --- a/include/tvm/auto_scheduler/cost_model.h +++ b/include/tvm/auto_scheduler/cost_model.h @@ -122,11 +122,11 @@ class RandomModel : public CostModel { * This class will call functions defined in the python */ class PythonBasedModelNode : public CostModelNode { public: - /*! \brief Pointer to the update funcion in python */ + /*! \brief Pointer to the update function in python */ PackedFunc update_func; - /*! \brief Pointer to the predict funcion in python */ + /*! \brief Pointer to the predict function in python */ PackedFunc predict_func; - /*! \brief Pointer to the predict funcion in python */ + /*! \brief Pointer to the predict function in python */ PackedFunc predict_stage_func; void Update(const Array& inputs, const Array& results) final; diff --git a/include/tvm/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h index 841b6b9530..20a93e280b 100755 --- a/include/tvm/auto_scheduler/measure.h +++ b/include/tvm/auto_scheduler/measure.h @@ -236,7 +236,7 @@ class MeasureCallback : public ObjectRef { * This class will call functions defined in the python */ class PythonBasedMeasureCallbackNode : public MeasureCallbackNode { public: - /*! \brief Pointer to the callback funcion in python */ + /*! \brief Pointer to the callback function in python */ PackedFunc callback_func; void Callback(const SearchPolicy& policy, const Array& inputs, diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index b809843f41..2b80945915 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -32,10 +32,13 @@ class BuilderInputNode : public runtime::Object { IRModule mod; /*! \brief The target to be built for. */ Target target; + /*! \brief Parameters for Relay build module. */ + Optional> params; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("mod", &mod); v->Visit("target", &target); + v->Visit("params", ¶ms); } static constexpr const char* _type_key = "meta_schedule.BuilderInput"; @@ -52,8 +55,10 @@ class BuilderInput : public runtime::ObjectRef { * \brief Constructor of BuilderInput. 
* \param mod The IRModule to be built. * \param target The target to be built for. + * \param params Parameters for Relay build module. */ - TVM_DLL explicit BuilderInput(IRModule mod, Target target); + TVM_DLL explicit BuilderInput(IRModule mod, Target target, + Optional> params = NullOpt); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderInput, runtime::ObjectRef, BuilderInputNode); }; diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h new file mode 100644 index 0000000000..b05dc3c118 --- /dev/null +++ b/include/tvm/meta_schedule/cost_model.h @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_COST_MODEL_H_ +#define TVM_META_SCHEDULE_COST_MODEL_H_ + +#include + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Cost model. */ +class CostModelNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~CostModelNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Load the cost model from given file location. + * \param path The file path. + */ + virtual void Load(const String& path) = 0; + + /*! + * \brief Save the cost model to given file location. + * \param path The file path. + */ + virtual void Save(const String& path) = 0; + + /*! + * \brief Update the cost model given running results. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \param results The running results of the measure candidates. + */ + virtual void Update(const TuneContext& tune_context, const Array& candidates, + const Array& results) = 0; + + /*! + * \brief Predict the normalized score (the larger the better) of given measure candidates. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \return The predicted normalized score. + */ + virtual std::vector Predict(const TuneContext& tune_context, + const Array& candidates) = 0; + + static constexpr const char* _type_key = "meta_schedule.CostModel"; + TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); +}; + +/*! \brief The cost model with customized methods on the python-side. */ +class PyCostModelNode : public CostModelNode { + public: + /*! + * \brief Load the cost model from given file location. + * \param path The file path. + */ + using FLoad = runtime::TypedPackedFunc; + /*! + * \brief Save the cost model to given file location. + * \param path The file path. + */ + using FSave = runtime::TypedPackedFunc; + /*! + * \brief Update the cost model given running results. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \param results The running results of the measure candidates. 
+ * \return Whether cost model was updated successfully. + */ + using FUpdate = runtime::TypedPackedFunc&, + const Array&)>; + /*! + * \brief Predict the running results of given measure candidates. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \param p_addr The address to save the estimated running results. + */ + using FPredict = runtime::TypedPackedFunc&, + void* p_addr)>; + /*! + * \brief Get the cost model as string with name. + * \return The string representation of the cost model. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `Load` function. */ + FLoad f_load; + /*! \brief The packed function to the `Save` function. */ + FSave f_save; + /*! \brief The packed function to the `Update` function. */ + FUpdate f_update; + /*! \brief The packed function to the `Predict` function. */ + FPredict f_predict; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_load` is not visited + // `f_save` is not visited + // `f_update` is not visited + // `f_predict` is not visited + // `f_as_string` is not visited + } + + void Load(const String& path) { + ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!"; + f_load(path); + } + + void Save(const String& path) { + ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!"; + f_save(path); + } + void Update(const TuneContext& tune_context, const Array& candidates, + const Array& results) { + ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!"; + f_update(tune_context, candidates, results); + } + + std::vector Predict(const TuneContext& tune_context, + const Array& candidates) { + ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!"; + std::vector result(candidates.size(), 0.0); + f_predict(tune_context, candidates, result.data()); + return result; + } + + static constexpr const char* _type_key = "meta_schedule.PyCostModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyCostModelNode, CostModelNode); +}; + +/*! + * \brief Managed reference to CostModelNode + * \sa CostModelNode + */ +class CostModel : public runtime::ObjectRef { + public: + /*! + * \brief Create a cost model with customized methods on the python-side. + * \param f_load The packed function of `Load`. + * \param f_save The packed function of `Save`. + * \param f_update The packed function of `Update`. + * \param f_predict The packed function of `Predict`. + * \param f_as_string The packed function of `AsString`. + * \return The cost model created. + */ + TVM_DLL static CostModel PyCostModel(PyCostModelNode::FLoad f_load, // + PyCostModelNode::FSave f_save, // + PyCostModelNode::FUpdate f_update, // + PyCostModelNode::FPredict f_predict, // + PyCostModelNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CostModel, ObjectRef, CostModelNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_COST_MODEL_H_ diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 60c6898f00..307ec309c0 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -155,6 +155,12 @@ class DatabaseNode : public runtime::Object { public: /*! \brief Default destructor */ virtual ~DatabaseNode() = default; + /*! + * \brief Check if the database has the given workload. + * \param mod The IRModule to be searched for.
+ * \return Whether the database has the given workload. + */ + virtual bool HasWorkload(const IRModule& mod) = 0; /*! * \brief Look up or add workload to the database if missing. * \param mod The IRModule to be searched for or added. @@ -186,6 +192,12 @@ class DatabaseNode : public runtime::Object { /*! \brief The database with customized methods on the python-side. */ class PyDatabaseNode : public DatabaseNode { public: + /*! + * \brief The function type of `HasWorkload` method. + * \param mod The IRModule to be searched for. + * \return Whether the database has the given workload. + */ + using FHasWorkload = runtime::TypedPackedFunc; /*! * \brief The function type of `CommitWorkload` method. * \param mod The IRModule to be searched for or added. @@ -210,6 +222,8 @@ class PyDatabaseNode : public DatabaseNode { */ using FSize = runtime::TypedPackedFunc; + /*! \brief The packed function to the `HasWorkload` function. */ + FHasWorkload f_has_workload; /*! \brief The packed function to the `CommitWorkload` function. */ FCommitWorkload f_commit_workload; /*! \brief The packed function to the `CommitTuningRecord` function. */ @@ -224,12 +238,18 @@ class PyDatabaseNode : public DatabaseNode { // so it cannot be accessible on the python side. If there is such need from the future, // we can then add corresponding accessor methods to help access on python. // + // `f_has_workload` is not visited // `f_commit_workload` is not visited // `f_commit_tuning_record` is not visited // `f_get_top_k` is not visited // `f_size` is not visited } + bool HasWorkload(const IRModule& mod) final { + ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!"; + return f_has_workload(mod); + } + Workload CommitWorkload(const IRModule& mod) final { ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!"; return f_commit_workload(mod); @@ -271,13 +291,15 @@ class Database : public runtime::ObjectRef { bool allow_missing); /*! * \brief Create a database with customized methods on the python-side. + * \param f_has_workload The packed function of `HasWorkload`. * \param f_commit_workload The packed function of `CommitWorkload`. * \param f_commit_tuning_record The packed function of `CommitTuningRecord`. * \param f_get_top_k The packed function of `GetTopK`. * \param f_size The packed function of `Size`. * \return The created database. */ - TVM_DLL static Database PyDatabase(PyDatabaseNode::FCommitWorkload f_commit_workload, + TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, + PyDatabaseNode::FCommitWorkload f_commit_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FSize f_size); diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h new file mode 100644 index 0000000000..ee5d94c13c --- /dev/null +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ +#define TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Extractor for features from measure candidates for use in cost model. */ +class FeatureExtractorNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~FeatureExtractorNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Extract features from the given measure candidate. + * \param tune_context The tuning context for feature extraction. + * \param candidates The measure candidates to extract features from. + * \return The feature ndarray extracted. + */ + virtual Array ExtractFrom(const TuneContext& tune_context, + const Array& candidates) = 0; + + static constexpr const char* _type_key = "meta_schedule.FeatureExtractor"; + TVM_DECLARE_BASE_OBJECT_INFO(FeatureExtractorNode, Object); +}; + +/*! \brief The feature extractor with customized methods on the python-side. */ +class PyFeatureExtractorNode : public FeatureExtractorNode { + public: + /*! + * \brief Extract features from the given measure candidate. + * \param tune_context The tuning context for feature extraction. + * \param candidates The measure candidates to extract features from. + * \return The feature ndarray extracted. + */ + using FExtractFrom = runtime::TypedPackedFunc( + const TuneContext& tune_context, const Array& candidates)>; + /*! + * \brief Get the feature extractor as string with name. + * \return The string of the feature extractor. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `ExtractFrom` function. */ + FExtractFrom f_extract_from; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_extract_from` is not visited + // `f_as_string` is not visited + } + + Array ExtractFrom(const TuneContext& tune_context, + const Array& candidates) { + ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!"; + return f_extract_from(tune_context, candidates); + } + + static constexpr const char* _type_key = "meta_schedule.PyFeatureExtractor"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyFeatureExtractorNode, FeatureExtractorNode); +}; + +/*! + * \brief Managed reference to FeatureExtractorNode + * \sa FeatureExtractorNode + */ +class FeatureExtractor : public runtime::ObjectRef { + public: + /*! + * \brief Create a feature extractor that extracts features from each BufferStore + * \param buffers_per_store The number of buffers in each BufferStore; Pad or truncate if + * necessary. + * \param arith_intensity_curve_num_samples The number of samples used in the arithmetic intensity + * curve. + * \param cache_line_bytes The number of bytes in a cache line. + * \return The feature extractor created. + */ + TVM_DLL static FeatureExtractor PerStoreFeature(int buffers_per_store = 5, + int arith_intensity_curve_num_samples = 10, + int cache_line_bytes = 64); + /*! 
+ * \brief Create a feature extractor with customized methods on the python-side. + * \param f_extract_from The packed function of `ExtractFrom`. + * \param f_as_string The packed function of `AsString`. + * \return The feature extractor created. + */ + TVM_DLL static FeatureExtractor PyFeatureExtractor( + PyFeatureExtractorNode::FExtractFrom f_extract_from, + PyFeatureExtractorNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FeatureExtractor, ObjectRef, FeatureExtractorNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h new file mode 100644 index 0000000000..f0763a468c --- /dev/null +++ b/include/tvm/meta_schedule/measure_callback.h @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ +#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +class TaskScheduler; + +/*! \brief Rules to apply after measure results is available. */ +class MeasureCallbackNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~MeasureCallbackNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Apply a measure callback rule with given arguments. + * \param task_scheduler The task scheduler. + * \param tasks The list of tune context to process. + * \param measure_candidates The measure candidates. + * \param builder_results The builder results by building the measure candidates. + * \param runner_results The runner results by running the built measure candidates. + */ + virtual void Apply(const TaskScheduler& task_scheduler, // + int task_id, // + const Array& measure_candidates, // + const Array& builder_results, // + const Array& runner_results) = 0; + + static constexpr const char* _type_key = "meta_schedule.MeasureCallback"; + TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); +}; + +/*! \brief The measure callback with customized methods on the python-side. */ +class PyMeasureCallbackNode : public MeasureCallbackNode { + public: + /*! + * \brief Apply a measure callback to the given schedule. + * \param task_scheduler The task scheduler. + * \param tasks The list of tune context to process. + * \param measure_candidates The measure candidates. + * \param builds The builder results by building the measure candidates. + * \param results The runner results by running the built measure candidates. + * \return Whether the measure callback was successfully applied. 
+ */ + using FApply = + runtime::TypedPackedFunc& measure_candidates, // + const Array& builds, // + const Array& results)>; + /*! + * \brief Get the measure callback function as string with name. + * \return The string of the measure callback function. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void Apply(const TaskScheduler& task_scheduler, // + int task_id, // + const Array& measure_candidates, // + const Array& builds, // + const Array& results) final { + ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!"; + return this->f_apply(task_scheduler, task_id, measure_candidates, builds, results); + } + + static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode); +}; + +/*! + * \brief Managed reference to MeasureCallbackNode + * \sa MeasureCallbackNode + */ +class MeasureCallback : public runtime::ObjectRef { + public: + /*! + * \brief Create a measure callback that adds the measurement results into the database + * \return The measure callback created. + */ + TVM_DLL static MeasureCallback AddToDatabase(); + /*! + * \brief Create a measure callback that removes the build artifacts from the disk + * \return The measure callback created. + */ + TVM_DLL static MeasureCallback RemoveBuildArtifact(); + /*! + * \brief Create a measure callback that echos the statistics of the tuning process to the console + * \return The measure callback created. + */ + TVM_DLL static MeasureCallback EchoStatistics(); + /*! + * \brief Create a measure callback that updates the cost model with measurement result. + * \return The measure callback created. + */ + TVM_DLL static MeasureCallback UpdateCostModel(); + /*! + * \brief Create a measure callback with customized methods on the python-side. + * \param f_apply The packed function of `Apply`. + * \return The measure callback created. + */ + TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, + PyMeasureCallbackNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h new file mode 100644 index 0000000000..3595b4191f --- /dev/null +++ b/include/tvm/meta_schedule/mutator.h @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_MUTATOR_H_ +#define TVM_META_SCHEDULE_MUTATOR_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Mutator is designed to mutate the trace to explore the design space. */ +class MutatorNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~MutatorNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Initialize the design space generator with tuning context. + * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply the mutator function to the given trace. + * \param trace The given trace for mutation. + * \param rand_state The random state for mutation. + * \return None if mutator failed, otherwise return the mutated trace. + */ + virtual Optional Apply(const tir::Trace& trace, + support::LinearCongruentialEngine::TRandState* rand_state) = 0; + + static constexpr const char* _type_key = "meta_schedule.Mutator"; + TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object); +}; + +/*! \brief The mutator with customized methods on the python-side. */ +class PyMutatorNode : public MutatorNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief Apply the mutator function to the given trace. + * \param trace The given trace for mutation. + * \return None if mutator failed, otherwise return the mutated trace. + */ + using FApply = runtime::TypedPackedFunc( + const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>; + /*! + * \brief Get the mutator as string with name. + * \return The string of the mutator. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyMutator's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + Optional Apply(const tir::Trace& trace, + support::LinearCongruentialEngine::TRandState* rand_state) final { + ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!"; + return this->f_apply(trace, *rand_state); + } + + static constexpr const char* _type_key = "meta_schedule.PyMutator"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode); +}; + +/*! + * \brief Managed reference to MutatorNode + * \sa MutatorNode + */ +class Mutator : public runtime::ObjectRef { + public: + /*! \brief Create a Mutator that mutates the tile size. */ + TVM_DLL static Mutator MutateTileSize(); + /*! 
+ * \brief Create a Mutator that mutates the parallel extent + * \param max_jobs_per_core The maximum number of parallel jobs per core. + * \return The created mutator. + */ + TVM_DLL static Mutator MutateParallel(int64_t max_jobs_per_core); + /*! \brief Create a Mutator that mutates auto unroll step */ + TVM_DLL static Mutator MutateUnroll(); + /*! + * \brief Create a mutator with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. + * \return The mutator created. + */ + TVM_DLL static Mutator PyMutator( + PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyMutatorNode::FApply f_apply, // + PyMutatorNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_MUTATOR_H_ diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h new file mode 100644 index 0000000000..1a134a35ab --- /dev/null +++ b/include/tvm/meta_schedule/postproc.h @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_POSTPROC_H_ +#define TVM_META_SCHEDULE_POSTPROC_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! + * \brief Rules to apply a postprocessor to a schedule. + */ +class PostprocNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~PostprocNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Initialize the design space generator with tuning context. + * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply a postprocessor to the given schedule. + * \param sch The schedule to be post processed. + * \return Whether the postprocessor was successfully applied. + */ + virtual bool Apply(const tir::Schedule& sch) = 0; + + static constexpr const char* _type_key = "meta_schedule.Postproc"; + TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); +}; + +/*! \brief The postprocessor with customized methods on the python-side. */ +class PyPostprocNode : public PostprocNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief Apply a postprocessor to the given schedule. 
+ * \param sch The schedule to be post processed. + * \return Whether the postprocessor was successfully applied. + */ + using FApply = runtime::TypedPackedFunc; + /*! + * \brief Get the postprocessor function as string with name. + * \return The string of the postprocessor function. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyPostproc's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + bool Apply(const tir::Schedule& sch) final { + ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!"; + return this->f_apply(sch); + } + + static constexpr const char* _type_key = "meta_schedule.PyPostproc"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode); +}; + +/*! + * \brief Managed reference to PostprocNode + * \sa PostprocNode + */ +class Postproc : public runtime::ObjectRef { + public: + /*! + * \brief Create a postprocessor with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. + * \return The postprocessor created. + */ + TVM_DLL static Postproc PyPostproc( + PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyPostprocNode::FApply f_apply, // + PyPostprocNode::FAsString f_as_string); + /*! + * \brief Create a postprocessor that checks if all loops are static + * \return The postprocessor created + */ + TVM_DLL static Postproc DisallowDynamicLoop(); + /*! + * \brief Create a postprocessor that rewrites the cooperative fetch annotation to + * actual vectorized cooperative fetching in loop bindings. + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteCooperativeFetch(); + /*! + * \brief Creates a postprocessor that applies parallelization, vectorization and auto unrolling + * according to the annotation of each block + * \return The postprocessor created + */ + TVM_DLL static Postproc RewriteParallelVectorizeUnroll(); + /*! + * \brief Create a postprocessor that rewrites reduction block by moving the init block out. + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteReductionBlock(); + /*! + * \brief Create a postprocessor that adds thread binding to unbound blocks + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteUnboundBlock(); + /*! 
+ * \brief Creates a postprocessor that verifies if the GPU code is correct + * \return The postprocessor created + */ + TVM_DLL static Postproc VerifyGPUCode(); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_POSTPROC_H_ diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h new file mode 100644 index 0000000000..6dba90ed26 --- /dev/null +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_ +#define TVM_META_SCHEDULE_SCHEDULE_RULE_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Rules to modify a block in a schedule. */ +class ScheduleRuleNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~ScheduleRuleNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Initialize the design space generator with tuning context. + * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply a schedule rule to the specific block in the given schedule. + * \param sch The schedule to be modified. + * \param block The specific block to apply the schedule rule. + * \return The list of schedules generated by applying the schedule rule. + */ + virtual runtime::Array Apply(const tir::Schedule& sch, + const tir::BlockRV& block) = 0; + + static constexpr const char* _type_key = "meta_schedule.ScheduleRule"; + TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object); +}; + +/*! \brief The schedule rule with customized methods on the python-side. */ +class PyScheduleRuleNode : public ScheduleRuleNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief The function type of `Apply` method. + * \param sch The schedule to be modified. + * \param block The specific block to apply the schedule rule. + * \return The list of schedules generated by applying the schedule rule. + */ + using FApply = + runtime::TypedPackedFunc(const tir::Schedule&, const tir::BlockRV&)>; + /*! + * \brief Get the schedule rule as string with name. + * \return The string of the schedule rule. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` function. 
*/ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyScheduleRule's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final { + ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; + return this->f_apply(sch, block); + } + + static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode); +}; + +/*! + * \brief Managed reference to ScheduleRuleNode + * \sa ScheduleRuleNode + */ +class ScheduleRule : public runtime::ObjectRef { + public: + /*! + * \brief Create an auto-inline rule that inlines spatial blocks when they satisfy certain conditions + * \param into_producer Whether to allow inlining a block into its producer + * \param into_consumer Whether to allow inlining a block into its consumer + * \param into_cache_only Whether to allow inlining only into blocks generated by cache_read/write + * \param inline_const_tensor Always inline constant tensors + * \param disallow_if_then_else Always disallow if-then-else-like constructs + * \param require_ordered Always require the read-to-write mapping to be ordered + * \param require_injective Always require the read-to-write mapping to be injective + * \param disallow_op The operators that are disallowed in auto inline + * \return The schedule rule created + */ + TVM_DLL static ScheduleRule AutoInline(bool into_producer, // + bool into_consumer, // + bool into_cache_only, // + bool inline_const_tensor, // + bool disallow_if_then_else, // + bool require_injective, // + bool require_ordered, // + Optional> disallow_op); + /*! + * \brief Create a mega rule: multi-level tiling with data reuse + * \param structure The tiling structure. Recommended: + * - 'SSRSRS' on CPU + * - 'SSSRRSRS' on GPU + * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended: + * - NullOpt on CPU + * - [blockIdx.x, vthread.x, threadIdx.x] on GPU + * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit + * \param vector_load_max_len The maximum length of the vector lane in vectorized cooperative fetching. + * NullOpt means disable vectorization + * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse. + * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse. + * \return The schedule rule created + */ + TVM_DLL static ScheduleRule MultiLevelTiling(String structure, // + Optional> tile_binds, // + Optional max_innermost_factor, // + Optional vector_load_max_len, // + Optional> reuse_read, // + Optional> reuse_write); + /*! + * \brief A rule that randomly selects a compute-at location for a free block + * \return The rule created + */ + TVM_DLL static ScheduleRule RandomComputeLocation(); + /*! + * \brief Mark parallelize, vectorize and unroll on each block correspondingly + * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core.
It sets the + * upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable + * parallelism. + * \param max_vectorize_extent The maximum extent to be vectorized. + * It sets the upper limit of CPU vectorization. Use -1 to disable vectorization. + * \param unroll_max_steps The maximum number of unroll steps to be done. + * Use an empty array to disable unrolling. + * \param unroll_explicit Whether to explicitly unroll the loop, or just add an unroll pragma. + * \return The schedule rule created + */ + TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // + int max_vectorize_extent, // + Array unroll_max_steps, // + bool unroll_explicit); + /*! + * \brief Create a schedule rule with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. + * \return The schedule rule created. + */ + TVM_DLL static ScheduleRule PyScheduleRule( + PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyScheduleRuleNode::FApply f_apply, // + PyScheduleRuleNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_H_ diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 0f3e9298d1..e645f15ef1 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -28,6 +28,8 @@ namespace meta_schedule { // Forward declaration class TuneContext; +class CostModel; +class Database; /*! \brief The schedule (with input shapes) to be measured. */ class MeasureCandidateNode : public runtime::Object { @@ -247,6 +249,39 @@ class SearchStrategy : public runtime::ObjectRef { */ TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total); + /*! + * \brief Constructor of replay func search strategy. + * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. + * \param num_trials_total The total number of trials for func replaying. + */ + TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total); + + /*! + * \brief Constructor of evolutionary search strategy. + * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. + * \param num_trials_total The total number of trials for evolutionary search. + * \param population The initial sample population. + * \param max_replay_fail_cnt The maximum number of failures allowed when replaying a trace. + * \param init_measured_ratio The ratio of measured samples in the initial population. + * \param genetic_algo_iters The number of iterations to run the genetic algorithm. + * \param max_evolve_fail_cnt The maximum number of failed attempts allowed when evolving a given trace. + * \param p_mutate The probability of mutation. + * \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score. + * \param database The database to use. + * \param cost_model The cost model to use.
+ */ + TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, // + int num_trials_total, // + int population, // + int max_replay_fail_cnt, // + double init_measured_ratio, // + int genetic_algo_iters, // + int max_evolve_fail_cnt, // + double p_mutate, // + double eps_greedy, // + Database database, // + CostModel cost_model); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); }; diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index eadf5e9150..c0f37c6037 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -102,7 +102,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { */ using FGenerateDesignSpace = runtime::TypedPackedFunc(const IRModule&)>; - /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ FInitializeWithTuneContext f_initialize_with_tune_context; /*! \brief The packed function to the `GenerateDesignSpace` function. */ FGenerateDesignSpace f_generate_design_space; @@ -153,6 +153,14 @@ class SpaceGenerator : public ObjectRef { * \return The design space generator created. */ TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array space_generators); + /*! + * \brief Create a design space generator that generates design spaces by applying schedule rules + * to blocks in post-DFS order. + * \param initialize_with_tune_context_func The packed function of `InitializeWithTuneContext`. + * \param generate_design_space_func The packed function of `GenerateDesignSpace`. + * \return The design space generator created. + */ + TVM_DLL static SpaceGenerator PostOrderApply(); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); }; diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 64ba3ddeaf..bba78afc4b 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -73,8 +74,12 @@ class TaskSchedulerNode : public runtime::Object { Runner runner{nullptr}; /*! \brief The database of the scheduler. */ Database database{nullptr}; + /*! \brief The cost model of the scheduler. */ + Optional cost_model; + /*! \brief The list of measure callbacks of the scheduler. */ + Array measure_callbacks; - /*! \brief The default desctructor. */ + /*! \brief The default destructor. */ virtual ~TaskSchedulerNode() = default; void VisitAttrs(tvm::AttrVisitor* v) { @@ -82,6 +87,8 @@ class TaskSchedulerNode : public runtime::Object { v->Visit("builder", &builder); v->Visit("runner", &runner); v->Visit("database", &database); + v->Visit("cost_model", &cost_model); + v->Visit("measure_callbacks", &measure_callbacks); } /*! \brief Auto-tuning. */ @@ -158,9 +165,9 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { */ using FNextTaskId = runtime::TypedPackedFunc; - /*! \brief The packed function to the `Tune` funcion. */ + /*! \brief The packed function to the `Tune` function. */ FTune f_tune; - /*! \brief The packed function to the `InitializeTask` funcion. */ + /*! \brief The packed function to the `InitializeTask` function. */ FInitializeTask f_initialize_task; /*! \brief The packed function to the `SetTaskStopped` function. 
*/ FSetTaskStopped f_set_task_stopped; @@ -242,15 +249,19 @@ class TaskScheduler : public runtime::ObjectRef { * \param runner The runner of the scheduler. * \param database The database of the scheduler. */ - TVM_DLL static TaskScheduler RoundRobin(Array tasks, // - Builder builder, // - Runner runner, // - Database database); // + TVM_DLL static TaskScheduler RoundRobin(Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + Optional cost_model, // + Optional> measure_callbacks); TVM_DLL static TaskScheduler PyTaskScheduler( Array tasks, // Builder builder, // Runner runner, // Database database, // + Optional cost_model, // + Optional> measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index db72328c91..eef7ae2b8d 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -20,6 +20,12 @@ #define TVM_META_SCHEDULE_TUNE_CONTEXT_H_ #include +#include +#include +#include +#include +#include +#include #include #include #include @@ -38,8 +44,14 @@ class TuneContextNode : public runtime::Object { Optional space_generator; /*! \brief The search strategy. */ Optional search_strategy; + /*! \brief The schedule rules. */ + Array sch_rules; + /*! \brief The postprocessors. */ + Array postprocs; + /*! \brief The probability of using certain mutator. */ + Map mutator_probs; /*! \brief The name of the tuning task. */ - Optional task_name; + String task_name; /*! \brief The random state. */ support::LinearCongruentialEngine::TRandState rand_state; /*! \brief The number of threads to be used. */ @@ -47,24 +59,33 @@ class TuneContextNode : public runtime::Object { /*! \brief Whether the tuning task has been stopped or finished. */ bool is_stopped; - /*! \brief Packed functions to fetch the runner results asynchronously. */ - Optional> runner_futures; /*! \brief The measure candidates. */ Optional> measure_candidates; + /*! \brief The building results. */ + Optional> builder_results; + /*! \brief Packed functions to fetch the runner results asynchronously. */ + Optional> runner_futures; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("mod", &mod); v->Visit("target", &target); v->Visit("space_generator", &space_generator); v->Visit("search_strategy", &search_strategy); + v->Visit("sch_rules", &sch_rules); + v->Visit("postprocs", &postprocs); + v->Visit("mutator_probs", &mutator_probs); v->Visit("task_name", &task_name); v->Visit("rand_state", &rand_state); v->Visit("num_threads", &num_threads); v->Visit("is_stopped", &is_stopped); + v->Visit("builder_results", &builder_results); v->Visit("runner_futures", &runner_futures); v->Visit("measure_candidates", &measure_candidates); } + /*! \brief Initialize members that needs initialization with tune context. */ + void Initialize(); + static constexpr const char* _type_key = "meta_schedule.TuneContext"; TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object); }; @@ -81,6 +102,9 @@ class TuneContext : public runtime::ObjectRef { * \param target The target to be tuned for. * \param space_generator The design space generator. * \param search_strategy The search strategy. + * \param sch_rules The schedule rules. + * \param postprocs The postprocessors. + * \param mutator_probs The probability of using certain mutator. * \param task_name The name of the tuning task. 
* \param rand_state The random state. * \param num_threads The number of threads to be used. @@ -89,6 +113,9 @@ Optional target, // Optional space_generator, // Optional search_strategy, // + Optional> sch_rules, // + Optional> postprocs, // + Optional> mutator_probs, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads); diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index fcd2326050..89b1e9117f 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -29,6 +29,7 @@ #include #include // for uint64_t +#include namespace tvm { namespace support { @@ -73,6 +74,12 @@ class LinearCongruentialEngine { */ static constexpr result_type max() { return modulus - 1; } + /*! + * \brief Get a device random state + * \return The random state + */ + static TRandState DeviceRandom() { return (std::random_device()()) % modulus; } + /*! * \brief Operator to move the random state to the next and return the new random state. According * to definition of linear congruential engine, the new random state value is computed as @@ -93,6 +100,7 @@ * \param rand_state The random state given in result_type. */ void Seed(TRandState rand_state = 1) { + ICHECK(rand_state != -1) << "The seed must not be -1; use DeviceRandom() to generate a random seed instead."; rand_state %= modulus; // Make sure the seed is within the range of modulus. if (rand_state == 0) rand_state = 1; // Avoid getting all 0 given the current parameter set. diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index e482a18c4a..3efc419e4d 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -187,6 +187,82 @@ class LinkedParam : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode); }; +/*! \brief A mapping from multi-dimensional indices to another set of multi-dimensional indices */ +class IndexMapNode : public Object { + public: + /*! \brief The source indices */ + Array src_iters; + /*! \brief The target indices */ + Array tgt_iters; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("src_iters", &src_iters); + v->Visit("tgt_iters", &tgt_iters); + } + + /*! + * \brief Take `inputs` as the source indices and return the corresponding target indices. + * \param inputs The source indices. + * \return The target indices. + */ + Array Apply(const Array& inputs) const; + + static constexpr const char* _type_key = "tir.IndexMap"; + TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object); +}; + +/*! + * \brief Managed reference to IndexMapNode. + * \sa IndexMapNode + */ +class IndexMap : public ObjectRef { + public: + /*! + * \brief Constructor. + * \param src_iters The source indices. + * \param tgt_iters The target indices. + */ + explicit IndexMap(Array src_iters, Array tgt_iters); + /*! + * \brief Create an index map from a packed function + * \param ndim The number of dimensions + * \param func The function to be applied + * \return The created index map + */ + static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc(Array)> func); + TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode); +}; + +/*! + * \brief Tensor intrinsic for tensorization + */ +class TensorIntrinNode : public Object { + public: + /*! \brief The function to describe the computation. */ + PrimFunc description; + /*! \brief The intrinsic function for the lower-level implementation.
*/ + PrimFunc implementation; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("description", &description); + v->Visit("implementation", &implementation); + } + + static constexpr const char* _type_key = "tir.TensorIntrin"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); +}; + +class TensorIntrin : public ObjectRef { + public: + TVM_DLL explicit TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func); + + TVM_DLL static TensorIntrin Register(String name, PrimFunc desc_func, PrimFunc intrin_func); + + TVM_DLL static TensorIntrin Get(String name); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode) +}; + /*! * \brief Specialize parameters of PrimFunc. * \param func The PrimFunc to be specialized. diff --git a/include/tvm/tir/schedule/instruction.h b/include/tvm/tir/schedule/instruction.h index 5a9e687dc8..1af5ab07e6 100644 --- a/include/tvm/tir/schedule/instruction.h +++ b/include/tvm/tir/schedule/instruction.h @@ -121,6 +121,9 @@ class InstructionKindNode : public runtime::Object { // not visited: f_attrs_from_json } + /*! \brief Checks if the instruction kind is EnterPostproc */ + bool IsPostproc() const; + static constexpr const char* _type_key = "tir.InstructionKind"; TVM_DECLARE_FINAL_OBJECT_INFO(InstructionKindNode, runtime::Object); }; diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index ffd860d84c..498cd116ec 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -155,6 +155,12 @@ class ScheduleNode : public runtime::Object { * \return The corresponding loop sref */ virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0; + /*! + * \brief Check the existance of a specific BlockRV + * \param block_rv The BlockRV to be looked up + * \return Whether the corresponding block exists + */ + virtual bool HasBlock(const BlockRV& block_rv) const = 0; /*! * \brief Get the block/loop sref corresponding to the specific statement * \param stmt The statement to be looked up @@ -204,6 +210,14 @@ class ScheduleNode : public runtime::Object { */ virtual Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) = 0; + /*! + * \brief Sample a compute-at location on a BlockRV so that its producer can compute at that loop + * \param block_rv The consumer block to be computed at + * \param decision The sampling decision + * \return The sampled loop to be computed at + */ + virtual LoopRV SampleComputeLocation(const BlockRV& block_rv, + Optional decision = NullOpt) = 0; /******** Schedule: Get blocks & loops ********/ /*! @@ -339,6 +353,11 @@ class ScheduleNode : public runtime::Object { */ virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) = 0; + /******** Schedule: Data movement ********/ + virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) = 0; + virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) = 0; /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the @@ -449,7 +468,53 @@ class ScheduleNode : public runtime::Object { virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) = 0; /******** Schedule: Blockize & Tensorize ********/ + /*! 
+ * \brief Make subtree rooted by a specific loop into a block + * \param loop_rv The root of the subtree + * \return The new block + */ + virtual BlockRV Blockize(const LoopRV& loop_rv) = 0; + /*! + * \brief Tensorize the computation enclosed by loop with tensor_intrin + * \param loop_rv the loop/block to be tensorized + * \param intrin the tensor intrinsic + */ + virtual void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) = 0; + /*! + * \brief Tensorize the computation enclosed by loop with tensor_intrin + * \param loop_rv The loop/block to be tensorized + * \param intrin_name Name of the tensor intrinsic + */ + virtual void Tensorize(const LoopRV& loop_rv, const String& intrin_name) = 0; + /******** Schedule: Annotation ********/ + /*! + * \brief Annotate a loop with a key value pair + * \param loop The loop to be annotated + * \param ann_key The annotation key + * \param ann_val The annotation value, a string or a ExprRV + */ + virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) = 0; + /*! + * \brief Annotate a block with a key value pair + * \param loop The block to be annotated + * \param ann_key The annotation key + * \param ann_val The annotation value, a string or a ExprRV + */ + virtual void Annotate(const BlockRV& block_rv, const String& ann_key, + const ObjectRef& ann_val) = 0; + /*! + * \brief Unannotate a loop's annotation with key ann_key + * \param loop The loop to be unannotated + * \param ann_key The annotation key + */ + virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0; + /*! + * \brief Unannotate a block's annotation with key ann_key + * \param loop The block to be unannotated + * \param ann_key The annotation key + */ + virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0; /******** Schedule: Misc ********/ /*! \brief A no-op that marks the start of postprocessing phase of scheduling */ virtual void EnterPostproc() = 0; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 066496704e..f8229ffb1b 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1224,7 +1224,7 @@ class BlockRealize : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode); }; -/*! \brief namespace of possible attribute sin AttrStmt.attr_key */ +/*! \brief namespace of possible attributes in AttrStmt.attr_key */ namespace attr { // The above attr does not pass to ir stage. /*! \brief Mark launching extent of thread, used by device API. */ @@ -1357,6 +1357,43 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_ */ constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint"; +/*! + * \brief Mark that the loop should be further skip and bound to environment threads to enable + * cooperative fetching. + */ +constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch"; + +/*! + * \brief Mark a block as generated by cache_read or cache_write block. + * 0 means cache_read; 1 means cache_write. + * \sa meta_schedule_cache_type_read + * \sa meta_schedule_cache_type_write + */ +constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type"; + +/*! \sa meta_schedule_cache_type */ +constexpr const int meta_schedule_cache_type_read = 0; + +/*! \sa meta_schedule_cache_type */ +constexpr const int meta_schedule_cache_type_write = 1; + +/*! \brief Mark auto-parallel setting on the block. */ +constexpr const char* meta_schedule_parallel = "meta_schedule.parallel"; + +/*! 
\brief Mark auto-vectorize setting on the block. */ +constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize"; + +/*! \brief Mark auto-unroll setting on the block. */ +constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit"; + +/*! \brief Mark auto-unroll setting on the block. */ +constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit"; + +/*! \brief Pragma: auto-unroll, max_step */ +constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step"; + +/*! \brief Pragma: unroll explicit */ +constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit"; /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 47b3dda5a3..94a78e4e49 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -19,7 +19,14 @@ from . import database from . import builder from . import runner +from . import mutator +from . import postproc +from . import schedule_rule from . import space_generator from . import search_strategy from . import integration +from . import feature_extractor +from . import cost_model +from .search_strategy import MeasureCandidate, ReplayFuncConfig, ReplayTraceConfig from .tune_context import TuneContext +from .tune import tune_tir, tune_te diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py index 381051e85f..7278c458a6 100644 --- a/python/tvm/meta_schedule/builder/builder.py +++ b/python/tvm/meta_schedule/builder/builder.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule builders that translate IRModule to runtime.Module, and then export""" -from typing import List, Optional +from typing import List, Optional, Dict +from tvm.runtime import NDArray from tvm._ffi import register_object from tvm.ir import IRModule from tvm.runtime import Object @@ -36,12 +37,20 @@ class BuilderInput(Object): The IRModule to be built. target : Target The target to be built for. + params: Optional[Dict[str, NDArray]] + The parameters for Relay build module """ mod: IRModule target: Target + params: Optional[Dict[str, NDArray]] - def __init__(self, mod: IRModule, target: Target) -> None: + def __init__( + self, + mod: IRModule, + target: Target, + params: Optional[Dict[str, NDArray]] = None, + ) -> None: """Constructor. Parameters @@ -50,11 +59,14 @@ def __init__(self, mod: IRModule, target: Target) -> None: The IRModule to be built. target : Target The target to be built for. + params: Optional[Dict[str, NDArray]] + The parameters for Relay build module """ self.__init_handle_by_constructor__( _ffi_api.BuilderInput, # type: ignore # pylint: disable=no-member mod, target, + params, ) diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index 99dfaea560..a1f1724d48 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -15,19 +15,35 @@ # specific language governing permissions and limitations # under the License. 
"""Local builder that compile on the local host""" +import logging import os import tempfile -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union from tvm._ffi import register_func from tvm.ir import IRModule -from tvm.runtime import Module +from tvm.runtime import NDArray +from tvm.runtime import Module, load_param_dict, save_param_dict from tvm.target import Target from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind from ..utils import cpu_count, get_global_func_with_default_on_worker from .builder import BuilderInput, BuilderResult, PyBuilder +logger = logging.getLogger(__name__) + + +def _serialize_params(params: Optional[Dict[str, NDArray]]) -> Optional[bytearray]: + if params is None: + return None + return save_param_dict(params) + + +def _deserialize_params(params: Optional[bytearray]) -> Optional[Dict[str, NDArray]]: + if params is None: + return None + return load_param_dict(params) + class LocalBuilder(PyBuilder): """A builder that builds the given input on local host. @@ -52,7 +68,11 @@ class LocalBuilder(PyBuilder): .. code-block:: python - def default_build(mod: IRModule, target: Target) -> Module: + def default_build( + mod: IRModule, + target: Target, + params: Optional[Dict[str, NDArray]] + ) -> Module: ... T_EXPORT : typing._GenericAlias @@ -71,7 +91,7 @@ def default_export(mod: Module) -> str: please send the registration logic via initializer. """ - T_BUILD = Callable[[IRModule, Target], Module] + T_BUILD = Callable[[IRModule, Target, Optional[Dict[str, NDArray]]], Module] T_EXPORT = Callable[[Module], str] pool: PopenPoolExecutor @@ -109,7 +129,8 @@ def __init__( super().__init__() if max_workers is None: - max_workers = cpu_count() + max_workers = cpu_count(logical=True) + logger.info("LocalBuilder: max_workers = %d", max_workers) self.pool = PopenPoolExecutor( max_workers=max_workers, @@ -134,6 +155,7 @@ def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: self.f_export, build_input.mod, build_input.target, + _serialize_params(build_input.params), ) for build_input in build_inputs ], @@ -172,6 +194,7 @@ def _worker_func( _f_export: Union[None, str, T_EXPORT], mod: IRModule, target: Target, + params: Optional[bytearray], ) -> str: # Step 0. Get the registered functions f_build: LocalBuilder.T_BUILD = get_global_func_with_default_on_worker( @@ -183,14 +206,14 @@ def _worker_func( default_export, ) # Step 1. Build the IRModule - rt_mod: Module = f_build(mod, target) + rt_mod: Module = f_build(mod, target, _deserialize_params(params)) # Step 2. Export the Module artifact_path: str = f_export(rt_mod) return artifact_path @register_func("meta_schedule.builder.default_build") -def default_build(mod: IRModule, target: Target) -> Module: +def default_build(mod: IRModule, target: Target, params: Optional[Dict[str, NDArray]]) -> Module: """Default build function. Parameters @@ -199,6 +222,8 @@ def default_build(mod: IRModule, target: Target) -> Module: The IRModule to be built. target : Target The target to be built. + params : Optional[Dict[str, NDArray]] + The parameters to be used for the build. Must be None. 
Returns ------- @@ -211,6 +236,7 @@ def default_build(mod: IRModule, target: Target) -> Module: # pylint: enable=import-outside-toplevel + assert params is None if target.kind.name == "cuda": set_cuda_target_arch(target.attrs["arch"]) diff --git a/python/tvm/meta_schedule/cost_model/__init__.py b/python/tvm/meta_schedule/cost_model/__init__.py new file mode 100644 index 0000000000..8fc6f04ac9 --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.cost_model package. +""" +from .cost_model import CostModel, PyCostModel +from .random_model import RandomModel +from .xgb_model import XGBModel diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py new file mode 100644 index 0000000000..13ca203c90 --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule CostModel.""" + +from typing import List +import ctypes + +import numpy as np + +from tvm._ffi import register_object +from tvm.runtime import Object + +from .. import _ffi_api +from ..runner import RunnerResult +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate +from ..utils import _get_hex_address, check_override + + +@register_object("meta_schedule.CostModel") +class CostModel(Object): + """Cost model.""" + + def load(self, path: str) -> None: + """Load the cost model from given file location. + + Parameters + ---------- + path : str + The file path. + """ + _ffi_api.CostModelLoad(self, path) # type: ignore # pylint: disable=no-member + + def save(self, path: str) -> None: + """Save the cost model to given file location. + + Parameters + ---------- + path : str + The file path. 
+ """ + _ffi_api.CostModelSave(self, path) # type: ignore # pylint: disable=no-member + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + results : List[RunnerResult] + The running results of the measure candidates. + """ + _ffi_api.CostModelUpdate(self, tune_context, candidates, results) # type: ignore # pylint: disable=no-member + + def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + + Return + ------ + result : np.ndarray + The predicted normalized score. + """ + n = len(candidates) + results = np.zeros(shape=(n,), dtype="float64") + _ffi_api.CostModelPredict( # type: ignore # pylint: disable=no-member + self, + tune_context, + candidates, + results.ctypes.data_as(ctypes.c_void_p), + ) + return results + + +@register_object("meta_schedule.PyCostModel") +class PyCostModel(CostModel): + """An abstract CostModel with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, CostModel) + def f_load(path: str) -> None: + self.load(path) + + @check_override(self.__class__, CostModel) + def f_save(path: str) -> None: + self.save(path) + + @check_override(self.__class__, CostModel) + def f_update( + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + self.update(tune_context, candidates, results) + + @check_override(self.__class__, CostModel) + def f_predict( + tune_context: TuneContext, candidates: List[MeasureCandidate], return_ptr + ) -> None: + n = len(candidates) + return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_double)) + array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) + array_wrapper[:] = self.predict(tune_context, candidates) + assert ( + array_wrapper.dtype == "float64" + ), "ValueError: Invalid data type returned from CostModel Predict!" + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.CostModelPyCostModel, # type: ignore # pylint: disable=no-member + f_load, + f_save, + f_update, + f_predict, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/cost_model/metric.py b/python/tvm/meta_schedule/cost_model/metric.py new file mode 100644 index 0000000000..7eb6da6f07 --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/metric.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Cost model metrics for meta schedule""" +from typing import List +import numpy as np + + +def max_curve(trial_scores: np.ndarray) -> List[float]: + """f(n) = max([s[i] fo i < n]) + + Parameters + ---------- + trial_scores : List[float] + the score of i-th trial + + Returns + ------- + curve : List[float] + function values + """ + ret = np.empty(len(trial_scores)) + keep = -1e9 + for i, score in enumerate(trial_scores): + keep = max(keep, score) + ret[i] = keep + return ret diff --git a/python/tvm/meta_schedule/cost_model/random_model.py b/python/tvm/meta_schedule/cost_model/random_model.py new file mode 100644 index 0000000000..56c65f64af --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/random_model.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Random cost model +""" +from typing import List, Union, Tuple, Optional + +import numpy as np + +from ..runner import RunnerResult +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate +from ..cost_model import PyCostModel + + +class RandomModel(PyCostModel): + """Random cost model + + Parameters + ---------- + random_state : Union[Tuple[str, np.ndarray, int, int, float], dict] + The random state of the random number generator. + path : Optional[str] + The path of the random cost model. + max_range : Optional[int] + The maximum range of random results, [0, max_range]. + + Reference + --------- + https://numpy.org/doc/stable/reference/random/generated/numpy.random.get_state.html + """ + + random_state: Union[Tuple[str, np.ndarray, int, int, float], dict] + path: Optional[str] + + def __init__( + self, + *, + seed: Optional[int] = None, + path: Optional[str] = None, + max_range: Optional[int] = 100, + ): + super().__init__() + if path is not None: + self.load(path) + else: + np.random.seed(seed) + self.random_state = np.random.get_state() + self.max_range = max_range + + def load(self, path: str) -> None: + """Load the cost model from given file location. + + Parameters + ---------- + path : str + The file path. + """ + self.random_state = tuple(np.load(path, allow_pickle=True)) + + def save(self, path: str) -> None: + """Save the cost model to given file location. + + Parameters + ---------- + path : str + The file path. 
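+
+        Note
+        ----
+        The random state is written with ``np.save``, so numpy appends a ``.npy``
+        suffix to ``path`` if it does not already have one.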
+ """ + np.save(path, np.array(self.random_state, dtype=object), allow_pickle=True) + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + results : List[RunnerResult] + The running results of the measure candidates. + """ + + def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + + Return + ------ + result : np.ndarray + The predicted running results. + """ + np.random.set_state(self.random_state) + # todo(@zxybazh): Use numpy's RandState object: + # https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html#numpy.random.RandomState + result = np.random.rand(len(candidates)) * self.max_range + self.random_state = np.random.get_state() + return result diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py new file mode 100644 index 0000000000..5cc36db8a9 --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -0,0 +1,665 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +XGBoost-based cost model +""" +from typing import NamedTuple, Optional, Tuple, Callable, List, TYPE_CHECKING + +import os +import logging +import tempfile +from itertools import chain as itertools_chain +import numpy as np + +from ..runner import RunnerResult +from ..search_strategy import MeasureCandidate +from ..feature_extractor import FeatureExtractor +from ..cost_model import PyCostModel +from ..utils import cpu_count +from .metric import max_curve +from ...contrib.tar import tar, untar + +if TYPE_CHECKING: + from ..tune_context import TuneContext + import xgboost as xgb + + +logger = logging.getLogger(__name__) + + +def make_metric_sorter(focused_metric): + """ Make sure the focused metric is the first one. """ + + def metric_name_for_sort(name): + if focused_metric == name: + return "!" 
+        return name
+
+    def sort_key(key):
+        key, _ = key
+        return metric_name_for_sort(key)
+
+    return sort_key
+
+
+class PackSum:
+    """The pack-sum format
+
+    Parameters
+    ----------
+    dmatrix : xgb.DMatrix
+        A float64 array of shape [n, m],
+        where `n` is the packed number of blocks,
+        and `m` is the length of feature vector on each block
+    ids : np.ndarray
+        An int64 array of shape [n] containing nonnegative integers,
+        indicating the index of the sample that each block belongs to
+    """
+
+    dmatrix: "xgb.DMatrix"  # type: ignore # pylint: disable=invalid-name
+    ids: np.ndarray
+
+    def __init__(
+        self,
+        xs: List[np.ndarray],
+        ys: Optional[List[float]],
+    ):
+        """Create PackSum format given a batch of samples
+
+        Parameters
+        ----------
+        xs : List[np.ndarray]
+            A batch of input samples
+        ys : Optional[List[float]]
+            A batch of labels. None means no labels available.
+        """
+        import xgboost as xgb  # type: ignore # pylint: disable=import-outside-toplevel
+
+        repeats = [x.shape[0] for x in xs]
+        xs = np.concatenate(xs, axis=0)
+        self.ids = np.concatenate([[i] * repeat for i, repeat in enumerate(repeats)], axis=0)
+        if ys is None:
+            self.dmatrix = xgb.DMatrix(data=xs, label=None)
+        else:
+            ys = np.concatenate([[y] * repeat for y, repeat in zip(ys, repeats)], axis=0)
+            self.dmatrix = xgb.DMatrix(data=xs, label=ys)
+            self.dmatrix.set_weight(ys)
+
+    def predict_with_score(self, pred: np.ndarray) -> np.ndarray:
+        """Predict the labels given the block level prediction scores.
+
+        Parameters
+        ----------
+        pred : np.ndarray
+            The block level predictions
+
+        Returns
+        -------
+        result : np.ndarray
+            The predictions for each candidate.
+        """
+        return np.bincount(self.ids, weights=pred)
+
+    def obj_square_error(self, ys_pred: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        """Implement square error loss on pack-sum format as
+        a custom objective function for xgboost.
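+
+        The pack-sum prediction of a sample is the sum of its block-level scores, so the
+        per-block gradient is the residual of the sample the block belongs to; both
+        gradient and hessian are additionally scaled by the labels, matching the weights
+        set on the DMatrix.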
+ + Parameters + ---------- + ys_pred: np.ndarray + The predictions + + Returns + ------- + gradient: np.ndarray + The gradient according to the xgboost format + hessian: np.ndarray + The hessian according to the xgboost format + """ + # Making prediction + ys_pred = self.predict_with_score(ys_pred) + # Propagate prediction to each block + ys_pred = ys_pred[self.ids] + # The gradient and hessian + ys = self.dmatrix.get_label() # type: ignore # pylint: disable=invalid-name + gradient = ys_pred - ys + hessian = np.ones_like(gradient) + return gradient * ys, hessian * ys + + def rmse(self, ys_pred: np.ndarray) -> Tuple[str, float]: + """Evaluate RMSE (rooted mean square error) in the pack-sum format + + Parameters + ---------- + ys_pred: np.ndarray + The raw predictions + + Returns + ------- + name: str + The name of the metric + score: float + The score of the metric + """ + # Making prediction + ys_pred = self.predict_with_score(ys_pred) + # Propagate prediction to each block + ys_pred = ys_pred[self.ids] + # The RMSE + ys = self.dmatrix.get_label() # type: ignore # pylint: disable=invalid-name + square_error = np.square(ys_pred - ys) + rmse = np.sqrt(square_error.mean()) + return "p-rmse", rmse + + def average_peak_score( + self, + ys_pred: np.ndarray, + n: int, + ) -> Tuple[str, float]: + """Evaluate average-peak-score@N in the pack-sum format + + Parameters + ---------- + ys_pred: np.ndarray + The raw prediction + n : int + The N in average-peak-score@N + + Returns + ------- + name: str + The name of the metric + score: float + The score of the metric + """ + ys = self.dmatrix.get_label() # type: ignore # pylint: disable=invalid-name + ys = self.predict_with_score(ys) # type: ignore # pylint: disable=invalid-name + ys = ys / np.unique(self.ids, return_counts=True)[1] # type: ignore # pylint: disable=invalid-name + ys_pred = self.predict_with_score(ys_pred) + trials = np.argsort(ys_pred)[::-1][:n] + trial_scores = ys[trials] + curve = max_curve(trial_scores) / np.max(ys) + score = np.mean(curve) + return f"a-peak@{n}", score + + +class XGBConfig(NamedTuple): + """XGBoost model configuration + + Parameters + ---------- + max_depth : int + The maximum depth. + gamma : float + The gamma. + min_child_weight : float + The minimum child weight. + eta : float + The eta, learning rate. + seed : int + The random seed. + nthread : Optional[int], + The number of threads to use. + Default is None, which means to use physical number of cores. + """ + + def to_dict(self): + xgb_params = { + "max_depth": self.max_depth, + "gamma": self.gamma, + "min_child_weight": self.min_child_weight, + "eta": self.eta, + "seed": self.seed, + "nthread": self.nthread, + } + return xgb_params + + max_depth: int = 10 + gamma: float = 0.001 + min_child_weight: float = 0 + eta: float = 0.2 + seed: int = 43 + nthread: Optional[int] = None + + +class XGBModel(PyCostModel): + """XGBoost model + + Parameters + ---------- + extractor : FeatureExtractor + The feature extractor for the model. + config : XGBConfig + The XGBoost model config. + num_warmup_samples : int + The number of samples that are used for warmup, i.e., the first few samples are predicted + with random results. + early_stopping_rounds : int + The number of rounds for early stopping. + verbose_eval : int + The verbose level when doing evaluation. + average_peak_n : int + The number to calculate average peak score. 
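+    path : Optional[str]
+        The file path from which to load the cost model on construction, if given.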
+ """ + + # feature extractor + extractor: FeatureExtractor + # xgboost model config + config: XGBConfig + # behavior of randomness + num_warmup_samples: int + # evaluation + early_stopping_rounds: int + verbose_eval: int + average_peak_n: int + # states + cached_features: List[np.ndarray] + cached_mean_costs: np.ndarray + cached_normalizer: Optional[float] + booster: Optional["xgb.Booster"] + + def __init__( + self, + *, + # feature extractor + extractor: FeatureExtractor, + # xgboost model config + config: XGBConfig = XGBConfig(), + # load from disk + path: Optional[str] = None, + # behavior of randomness + num_warmup_samples: int = 100, + # evaluation + early_stopping_rounds: int = 50, + verbose_eval: int = 25, + average_peak_n: int = 32, + ): + super().__init__() + # feature extractor + self.extractor = extractor + # model-related + if config.nthread is None: + # use physical core number + config = config._replace(nthread=cpu_count(logical=False)) + self.config = config + # serialization-related + if path is not None: + self.load(path) + # behavior of randomness + self.num_warmup_samples = num_warmup_samples + # evaluation + self.early_stopping_rounds = early_stopping_rounds + self.verbose_eval = verbose_eval + self.average_peak_n = average_peak_n + # states + self.cached_features = [] + self.cached_mean_costs = np.empty((0,), dtype="float64") + self.cached_normalizer = None + self.booster = None + + def load(self, path: str) -> None: + """Load the cost model from given file location. + + Parameters + ---------- + path : str + The file path. + + Note + ---- + Since XGBoost model trains from scratch, each time we can only load the model without the + previous cached features / results so any call of update won't use previous training data. + """ + with tempfile.TemporaryDirectory() as tmpdirname: + untar(path, tmpdirname) + self.booster.load_model(os.path.join(tmpdirname, "model.bin")) + self.cached_features = list( + np.load(os.path.join(tmpdirname, "cached_features.npy"), allow_pickle=True) + ) + self.cached_mean_costs = np.load( + os.path.join(tmpdirname, "cached_mean_costs.npy"), allow_pickle=True + ) + self.cached_normalizer = np.min(self.cached_mean_costs) + if self.cached_normalizer <= 0: + raise ValueError("The minimum mean cost must be greater than 0!") + + def save(self, path: str) -> None: + """Save the cost model to given file location. + + Parameters + ---------- + path : str + The file path. + + Note + ---- + Since XGBoost model trains from scratch, each time we can only save the model without the + previous cached features / results so any call of update won't use previous training data. + """ + import xgboost as xgb # pylint: disable=import-outside-toplevel + + if self.booster is None: + # save all the paramaters + self.booster = xgb.Booster(self.config.to_dict()) + with tempfile.TemporaryDirectory() as tmpdirname: + self.booster.save_model(os.path.join(tmpdirname, "model.bin")) + np.save( + os.path.join(tmpdirname, "cached_features.npy"), + np.array(self.cached_features, dtype=object), + ) + np.save(os.path.join(tmpdirname, "cached_mean_costs.npy"), self.cached_mean_costs) + tar( + path, + [ + os.path.join(tmpdirname, "model.bin"), + os.path.join(tmpdirname, "cached_features.npy"), + os.path.join(tmpdirname, "cached_mean_costs.npy"), + ], + ) + + def update( + self, + tune_context: "TuneContext", + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + """Update the cost model given running results. 
+ + Parameters + ---------- + tune_context : TuneContext + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + results : List[RunnerResult] + The running results of the measure candidates. + """ + assert len(candidates) == len(results) + if len(candidates) == 0: + return + # extract feature and do validation + new_features = [ + x.numpy().astype("float32") + for x in self.extractor.extract_from(tune_context, candidates) + ] + new_mean_costs = [float(sum(x.run_secs) / len(x.run_secs)) for x in results] + if self.booster is not None and self.cached_normalizer is not None: + logger.debug( + "XGB validation: %s", + "\t".join( + f"{key}: {score:.6f}" + for key, score in self._validate( + xs=new_features, + ys=new_mean_costs, + ) + ), + ) + # use together with previous features + self.cached_features.extend(new_features) + self.cached_mean_costs = np.append(self.cached_mean_costs, new_mean_costs) + self.cached_normalizer = np.min(self.cached_mean_costs) + if self.cached_normalizer <= 0: + raise ValueError("The minimum mean cost must be greater than 0!") + # train xgb model + self._train( + xs=self.cached_features, + ys=self.cached_mean_costs, + ) + + def predict( + self, tune_context: "TuneContext", candidates: List[MeasureCandidate] + ) -> np.ndarray: + """Predict the normalized score using the cost model. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + + Return + ------ + result : np.ndarray + The predicted normalized score. + """ + n_measured = len(self.cached_features) + if self.booster is not None and n_measured >= self.num_warmup_samples: + features = self.extractor.extract_from(tune_context, candidates) + ret = self._predict(xs=[x.numpy().astype("float32") for x in features]) + else: + ret = np.random.uniform( + low=0, + high=1, + size=(len(candidates),), + ) + return ret.astype("float64") + + def _train( # type: ignore # pylint: disable=invalid-name + self, + xs: List[np.ndarray], + ys: List[float], + ) -> None: + import xgboost as xgb # type: ignore # pylint: disable=import-outside-toplevel + + self.d_train = PackSum( + xs=xs, + ys=self.cached_normalizer / ys, + ) + + def obj(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument + return self.d_train.obj_square_error(ys_pred) + + def rmse(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument + return self.d_train.rmse(ys_pred) + + def average_peak_score( + ys_pred: np.ndarray, d_train: "xgb.DMatrix" # type: ignore # pylint: disable = unused-argument + ): + return self.d_train.average_peak_score(ys_pred, self.average_peak_n) + + self.booster = xgb.train( + self.config.to_dict(), + self.d_train.dmatrix, + num_boost_round=10000, + obj=obj, + callbacks=[ + custom_callback( + early_stopping_rounds=self.early_stopping_rounds, + verbose_eval=self.verbose_eval, + fevals=[ + rmse, + average_peak_score, + ], + evals=[(self.d_train.dmatrix, "tr")], + ) + ], + ) + + del self.d_train + # todo(zxybazh): measure callback to save the model + + def _predict( # type: ignore # pylint: disable=invalid-name + self, + xs: List[np.ndarray], + ) -> np.ndarray: + d_test = PackSum(xs=xs, ys=None) + pred = self.booster.predict(d_test.dmatrix) + ret = d_test.predict_with_score(pred) + return ret + + def _validate( # type: ignore # pylint: disable=invalid-name + self, + xs: List[np.ndarray], + ys: List[float], + ) -> List[Tuple[str, float]]: + 
"""Evaluate the score of inputs. + + Parameters + ---------- + xs : List[np.ndarray] + A batch of input samples + ys : List[float] + A batch of labels + + Returns + ------- + scores: np.ndarray + The predicted result for all inputs. + """ + if self.booster is None or self.cached_normalizer is None: + return [] + + d_valid = PackSum( + xs=xs, + ys=self.cached_normalizer / ys, + ) + + def average_peak_score(ys_pred: np.ndarray): + return d_valid.average_peak_score(ys_pred, n=self.average_peak_n) + + ys_pred = self.booster.predict(d_valid.dmatrix) + eval_result: List[Tuple[str, float]] = [ + feval(ys_pred) + for feval in ( + average_peak_score, + d_valid.rmse, + ) + ] + eval_result.sort(key=make_metric_sorter("p-rmse")) + return eval_result + + +def custom_callback( + early_stopping_rounds: int, + verbose_eval: int, + fevals: List[Callable], + evals: List[Tuple["xgb.DMatrix", str]], + focused_metric: str = "tr-p-rmse", +): + """Callback function for xgboost to support multiple custom evaluation functions""" + sort_key = make_metric_sorter(focused_metric=focused_metric) + + state = {} + + def init(env: "xgb.core.CallbackEnv"): + """Internal function""" + booster: "xgb.Booster" = env.model + + state["best_iteration"] = 0 + state["best_score"] = float("inf") + if booster is None: + assert env.cvfolds is not None + return + if booster.attr("best_score") is not None: + state["best_score"] = float(booster.attr("best_score")) + state["best_iteration"] = int(booster.attr("best_iteration")) + state["best_msg"] = booster.attr("best_msg") + else: + booster.set_attr(best_iteration=str(state["best_iteration"])) + booster.set_attr(best_score=str(state["best_score"])) + + def callback(env: "xgb.core.CallbackEnv"): + # pylint:disable = import-outside-toplevel + import xgboost as xgb + from xgboost.callback import _fmt_metric + from xgboost.core import EarlyStopException + + try: + from xgboost.training import aggcv + except ImportError: + from xgboost.callback import _aggcv as aggcv + # pylint:enable = import-outside-toplevel + + if not state: + init(env) + booster: xgb.Booster = env.model + iteration: int = env.iteration + cvfolds: List[xgb.training.CVPack] = env.cvfolds + ##### Evaluation ##### + # `eval_result` is a list of (key, score) + eval_result: List[Tuple[str, float]] = [] + if cvfolds is None: + eval_result = itertools_chain.from_iterable( + [ + (key, float(value)) + for key, value in map( + lambda x: x.split(":"), + booster.eval_set( + evals=evals, + iteration=iteration, + feval=feval, + ).split()[1:], + ) + ] + for feval in fevals + ) + else: + eval_result = itertools_chain.from_iterable( + [ + (key, score) + for key, score, _std in aggcv( + fold.eval( + iteration=iteration, + feval=feval, + ) + for fold in cvfolds + ) + ] + for feval in fevals + ) + eval_result = list(eval_result) + eval_result.sort(key=sort_key) + + ##### Print eval result ##### + if verbose_eval and iteration % verbose_eval == 0: + info = [] + for key, score in eval_result: + if "null" not in key: + info.append(f"{key}: {score:.6f}") + logger.debug("XGB iter %3d: %s", iteration, "\t".join(info)) + + ##### Choose score and do early stopping ##### + score = None + for key, _score in eval_result: + if key == focused_metric: + score = _score + break + assert score is not None + + best_score = state["best_score"] + best_iteration = state["best_iteration"] + if score < best_score: + tab = "\t" # to work with f-string + msg = f"[{env.iteration}] {tab.join([_fmt_metric(x) for x in eval_result])}" + state["best_msg"] = msg + 
state["best_score"] = score + state["best_iteration"] = env.iteration + # save the property to attributes, so they will occur in checkpoint. + if env.model is not None: + env.model.set_attr( + best_score=str(state["best_score"]), + best_iteration=str(state["best_iteration"]), + best_msg=state["best_msg"], + ) + elif env.iteration - best_iteration >= early_stopping_rounds: + best_msg = state["best_msg"] + if verbose_eval and env.rank == 0: + logger.debug("XGB stopped. Best iteration: %s ", best_msg) + raise EarlyStopException(best_iteration) + + return callback diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index fd746e640c..b5ca5740b2 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -147,6 +147,21 @@ def from_json(json_obj: Any, workload: Workload) -> "TuningRecord": class Database(Object): """The abstract database interface.""" + def has_workload(self, mod: IRModule) -> bool: + """Check if the database has the given workload. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for. + + Returns + ------- + result : bool + Whether the database has the given workload. + """ + return _ffi_api.DatabaseHasWorkload(self, mod) # type: ignore # pylint: disable=no-member + def commit_workload(self, mod: IRModule) -> Workload: """Commit a workload to the database if missing. @@ -207,6 +222,10 @@ class PyDatabase(Database): def __init__(self): """Constructor.""" + @check_override(self.__class__, Database) + def f_has_workload(mod: IRModule) -> bool: + return self.has_workload(mod) + @check_override(self.__class__, Database) def f_commit_workload(mod: IRModule) -> Workload: return self.commit_workload(mod) @@ -225,6 +244,7 @@ def f_size() -> int: self.__init_handle_by_constructor__( _ffi_api.DatabasePyDatabase, # type: ignore # pylint: disable=no-member + f_has_workload, f_commit_workload, f_commit_tuning_record, f_get_top_k, diff --git a/python/tvm/meta_schedule/feature_extractor/__init__.py b/python/tvm/meta_schedule/feature_extractor/__init__.py new file mode 100644 index 0000000000..83ac7426cc --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.feature_extractor package. +Meta Schedule feature extractors that extracts features from +measure candidates for use in cost model. 
+""" +from .feature_extractor import FeatureExtractor, PyFeatureExtractor +from .per_store_feature import PerStoreFeature +from .random_feature_extractor import RandomFeatureExtractor diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py new file mode 100644 index 0000000000..bd7656e5be --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule FeatureExtractor.""" +from typing import List + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.runtime.ndarray import NDArray + +from .. import _ffi_api +from ..utils import _get_hex_address, check_override +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate + + +@register_object("meta_schedule.FeatureExtractor") +class FeatureExtractor(Object): + """Extractor for features from measure candidates for use in cost model.""" + + def extract_from( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + """Extract features from the given measure candidate. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for feature extraction. + candidates : List[MeasureCandidate] + The measure candidates to extract features from. + + Returns + ------- + features : List[NDArray] + The feature numpy ndarray extracted. 
+ """ + result = _ffi_api.FeatureExtractorExtractFrom( # type: ignore # pylint: disable=no-member + self, tune_context, candidates + ) + return result + + +@register_object("meta_schedule.PyFeatureExtractor") +class PyFeatureExtractor(FeatureExtractor): + """An abstract feature extractor with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, FeatureExtractor) + def f_extract_from( + tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + features = self.extract_from(tune_context, candidates) + return features + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.FeatureExtractorPyFeatureExtractor, # type: ignore # pylint: disable=no-member + f_extract_from, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py new file mode 100644 index 0000000000..30572ed5b9 --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""We extract one feature vector per BufferStoreNode statement in a TIR Stmt, +so we call this feature as "per-store" feature. +""" +from tvm._ffi import register_object + +from .. import _ffi_api +from .feature_extractor import FeatureExtractor + + +# /*! +# * \brief Create a feature extractor that extracts features from each BufferStore +# * \param buffers_per_store The number of buffers in each BufferStore; Pad or truncate if +# * necessary. +# * \param arith_intensity_curve_num_samples The number of samples used in the arithmetic intensity +# * curve. +# * \param cache_line_bytes The number of bytes in a cache line. +# * \return The feature extractor created. +# */ + + +@register_object("meta_schedule.PerStoreFeature") +class PerStoreFeature(FeatureExtractor): + """PerStoreFeature extracts one feature vector per BufferStoreNode + + Parameters + ---------- + buffers_per_store : int + The number of buffers in each BufferStore; Pad or truncate if necessary. + arith_intensity_curve_num_samples : int + The number of samples used in the arithmetic intensity curve. + cache_line_bytes : int + The number of bytes in a cache line. 
+ """ + + buffers_per_store: int + """The number of buffers in each BufferStore; Pad or truncate if necessary.""" + arith_intensity_curve_num_samples: int # pylint: disable=invalid-name + """The number of samples used in the arithmetic intensity curve.""" + cache_line_bytes: int + """The number of bytes in a cache line.""" + feature_vector_length: int + """Length of the feature vector.""" + + def __init__( + self, + buffers_per_store: int = 5, + arith_intensity_curve_num_samples: int = 10, + cache_line_bytes: int = 64, + ): + self.__init_handle_by_constructor__( + _ffi_api.FeatureExtractorPerStoreFeature, # type: ignore # pylint: disable=no-member + buffers_per_store, + arith_intensity_curve_num_samples, + cache_line_bytes, + ) diff --git a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py new file mode 100644 index 0000000000..f9f2f287fd --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Random Feature Extractor.""" +from typing import List, Union, Tuple + +import numpy as np +from tvm.runtime.ndarray import NDArray, array + +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate +from ..feature_extractor import PyFeatureExtractor + + +class RandomFeatureExtractor(PyFeatureExtractor): + """Random Feature Extractor + + Parameters + ---------- + feature_size : int + The size of each block's feature vector. + max_block_num : int + The maximum number of blocks in each schedule. + random_state : Union[Tuple[str, np.ndarray, int, int, float], dict] + The current random state of the f + """ + + feature_size: int + max_block_num: int + random_state: Union[Tuple[str, np.ndarray, int, int, float], dict] + + def __init__(self, *, feature_size: int = 30, max_block_num: int = 5, seed=0): + super().__init__() + assert max_block_num >= 1, "Max block number must be greater or equal to one!" 
+ self.max_block_num = max_block_num + self.feature_size = feature_size + np.random.seed(seed) + self.random_state = np.random.get_state() + + def extract_from( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + np.random.set_state(self.random_state) + result = [ + np.random.rand(np.random.randint(1, self.max_block_num + 1), self.feature_size) + for candidate in candidates + ] + self.random_state = np.random.get_state() + return [array(x) for x in result] diff --git a/python/tvm/meta_schedule/measure_callback/__init__.py b/python/tvm/meta_schedule/measure_callback/__init__.py new file mode 100644 index 0000000000..ef02cbdb55 --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.measure_callback package. +""" +from .measure_callback import MeasureCallback, PyMeasureCallback +from .add_to_database import AddToDatabase +from .echo_statistics import EchoStatistics +from .remove_build_artifact import RemoveBuildArtifact diff --git a/python/tvm/meta_schedule/measure_callback/add_to_database.py b/python/tvm/meta_schedule/measure_callback/add_to_database.py new file mode 100644 index 0000000000..ab61e87f64 --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/add_to_database.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A callback that adds the measurement results into the database""" +from tvm._ffi import register_object + +from .. 
import _ffi_api +from .measure_callback import MeasureCallback + + +@register_object("meta_schedule.AddToDatabase") +class AddToDatabase(MeasureCallback): + def __init__(self) -> None: + """A callback that adds the measurement results into the database""" + self.__init_handle_by_constructor__( + _ffi_api.MeasureCallbackAddToDatabase, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/measure_callback/echo_statistics.py b/python/tvm/meta_schedule/measure_callback/echo_statistics.py new file mode 100644 index 0000000000..867409f881 --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/echo_statistics.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A callback that echos the statistics of the tuning process to the console""" +from tvm._ffi import register_object + +from .. import _ffi_api +from .measure_callback import MeasureCallback + + +@register_object("meta_schedule.EchoStatistics") +class EchoStatistics(MeasureCallback): + def __init__(self) -> None: + """A callback that echos the statistics of the tuning process to the console""" + self.__init_handle_by_constructor__( + _ffi_api.MeasureCallbackEchoStatistics, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/measure_callback/measure_callback.py b/python/tvm/meta_schedule/measure_callback/measure_callback.py new file mode 100644 index 0000000000..fea888fac1 --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/measure_callback.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule MeasureCallback.""" + +from typing import TYPE_CHECKING, List + +from tvm._ffi import register_object +from tvm.runtime import Object + +from ..search_strategy import MeasureCandidate +from ..builder import BuilderResult +from ..runner import RunnerResult +from ..utils import _get_hex_address, check_override + +from .. 
import _ffi_api + +if TYPE_CHECKING: + from ..task_scheduler import TaskScheduler + + +@register_object("meta_schedule.MeasureCallback") +class MeasureCallback(Object): + """Rules to apply after measure results is available.""" + + def apply( + self, + task_scheduler: "TaskScheduler", + task_id: int, + measure_candidates: List[MeasureCandidate], + builder_results: List[BuilderResult], + runner_results: List[RunnerResult], + ) -> None: + """Apply a measure callback to the given schedule. + + Parameters + ---------- + task_scheduler: TaskScheduler + The task scheduler. + task_id: int + The task id. + measure_candidates: List[MeasureCandidate] + The measure candidates. + builder_results: List[BuilderResult] + The builder results by building the measure candidates. + runner_results: List[RunnerResult] + The runner results by running the built measure candidates. + """ + return _ffi_api.MeasureCallbackApply( # type: ignore # pylint: disable=no-member + self, + task_scheduler, + task_id, + measure_candidates, + builder_results, + runner_results, + ) + + +@register_object("meta_schedule.PyMeasureCallback") +class PyMeasureCallback(MeasureCallback): + """An abstract MeasureCallback with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, MeasureCallback) + def f_apply( + task_scheduler: "TaskScheduler", + task_id: int, + measure_candidates: List[MeasureCandidate], + builder_results: List[BuilderResult], + runner_results: List[RunnerResult], + ) -> None: + return self.apply( + task_scheduler, + task_id, + measure_candidates, + builder_results, + runner_results, + ) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.MeasureCallbackPyMeasureCallback, # type: ignore # pylint: disable=no-member + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"PyMeasureCallback({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py new file mode 100644 index 0000000000..4b2e1ab7f4 --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A callback that removes the build artifacts from the disk""" +from tvm._ffi import register_object + +from .. 
import _ffi_api +from .measure_callback import MeasureCallback + + +@register_object("meta_schedule.RemoveBuildArtifact") +class RemoveBuildArtifact(MeasureCallback): + def __init__(self) -> None: + """A callback that removes the build artifacts from the disk""" + self.__init_handle_by_constructor__( + _ffi_api.MeasureCallbackRemoveBuildArtifact, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py new file mode 100644 index 0000000000..ca319046b1 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.mutator package. +Meta Schedule mutator that mutates the trace to explore the +design space. +""" +from .mutator import Mutator, PyMutator +from .mutate_parallel import MutateParallel +from .mutate_unroll import MutateUnroll +from .mutate_tile_size import MutateTileSize diff --git a/python/tvm/meta_schedule/mutator/mutate_parallel.py b/python/tvm/meta_schedule/mutator/mutate_parallel.py new file mode 100644 index 0000000000..c66dddb825 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_parallel.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Mutator that mutates the parallel extent""" +from tvm._ffi.registry import register_object + +from .. 
import _ffi_api +from .mutator import Mutator + + +@register_object("meta_schedule.MutateParallel") +class MutateParallel(Mutator): + """Mutator that mutates the parallel extent""" + + def __init__(self, max_jobs_per_core: int) -> None: + """Mutator that mutates the parallel extent""" + self.__init_handle_by_constructor__( + _ffi_api.MutatorMutateParallel, # type: ignore # pylint: disable=no-member + max_jobs_per_core, + ) diff --git a/python/tvm/meta_schedule/mutator/mutate_tile_size.py b/python/tvm/meta_schedule/mutator/mutate_tile_size.py new file mode 100644 index 0000000000..9c94d44361 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_tile_size.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Mutator that mutates the tile size""" +from tvm._ffi.registry import register_object + +from .. import _ffi_api +from .mutator import Mutator + + +@register_object("meta_schedule.MutateTileSize") +class MutateTileSize(Mutator): + """Mutator that mutates the tile size""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.MutatorMutateTileSize, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/mutator/mutate_unroll.py b/python/tvm/meta_schedule/mutator/mutate_unroll.py new file mode 100644 index 0000000000..f81953d008 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_unroll.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Mutator that mutates auto unroll step""" +from tvm._ffi.registry import register_object + +from .. 
import _ffi_api +from .mutator import Mutator + + +@register_object("meta_schedule.MutateUnroll") +class MutateUnroll(Mutator): + """Mutator that mutates auto unroll step""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.MutatorMutateUnroll, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py new file mode 100644 index 0000000000..d3b0085911 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule Mutator.""" +from typing import Optional, TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Trace + +from .. import _ffi_api +from ..utils import _get_hex_address, check_override + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +class Mutator(Object): + """Mutator is designed to mutate the trace to explore the design space.""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the mutator with a tune context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the mutator. + """ + _ffi_api.MutatorInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def apply(self, trace: Trace) -> Optional[Trace]: + """Apply the mutator function to the given trace. + + Parameters + ---------- + trace : Trace + The given trace for mutation. + + Returns + ------- + trace : Optional[Trace] + None if mutator failed, otherwise return the mutated trace. 
+ """ + return _ffi_api.MutatorApply(self, trace, -1) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PyMutator") +class PyMutator(Mutator): + """An abstract mutator with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, Mutator) + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + @check_override(self.__class__, Mutator) + def f_apply(trace: Trace, _) -> Optional[Trace]: + return self.apply(trace) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.MutatorPyMutator, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py new file mode 100644 index 0000000000..96361e7391 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The tvm.meta_schedule.postproc package.""" +from .postproc import Postproc, PyPostproc +from .disallow_dynamic_loop import DisallowDynamicLoop +from .rewrite_cooperative_fetch import RewriteCooperativeFetch +from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll +from .rewrite_reduction_block import RewriteReductionBlock +from .rewrite_unbound_block import RewriteUnboundBlock +from .verify_gpu_code import VerifyGPUCode diff --git a/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py new file mode 100644 index 0000000000..5515d288e0 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""A postprocessor that checks if the IRModule has any loop with non-constant extent""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.DisallowDynamicLoop") +class DisallowDynamicLoop(Postproc): + """A postprocessor that checks if the IRModule has any loop with non-constant extent""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocDisallowDynamicLoop, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py new file mode 100644 index 0000000000..8e3b332c77 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule Postproc.""" + +from typing import TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Schedule + +from .. import _ffi_api +from ..utils import _get_hex_address, check_override + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.Postproc") +class Postproc(Object): + """Rules to apply a postprocessor to a schedule.""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the postprocessor with a tune context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the postprocessor. + """ + _ffi_api.PostprocInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def apply(self, sch: Schedule) -> bool: + """Apply a postprocessor to the given schedule. + + Parameters + ---------- + sch : Schedule + The schedule to be post processed. + + Returns + ------- + result : bool + Whether the postprocessor was successfully applied. 
+ """ + return _ffi_api.PostprocApply(self, sch) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PyPostproc") +class PyPostproc(Postproc): + """An abstract Postproc with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, Postproc) + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + @check_override(self.__class__, Postproc) + def f_apply(sch: Schedule) -> bool: + return self.apply(sch) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.PostprocPyPostproc, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py new file mode 100644 index 0000000000..e2d7c22123 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that rewrites the cooperative fetch annotation to actual +vectorized cooperative fetching in loop bindings.""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteCooperativeFetch") +class RewriteCooperativeFetch(Postproc): + """A postprocessor that rewrites the cooperative fetch annotation to actual vectorized + cooperative fetching in loop bindings. + """ + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteCooperativeFetch, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py new file mode 100644 index 0000000000..abe7288acb --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that applies parallelization, vectorization and auto unrolling +according to the annotation of each block""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteParallelVectorizeUnroll") +class RewriteParallelVectorizeUnroll(Postproc): + """A postprocessor that applies parallelization, vectorization and auto unrolling + according to the annotation of each block""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteParallelVectorizeUnroll, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py new file mode 100644 index 0000000000..7e15ed493c --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that rewrites reduction block by moving the init block out.""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteReductionBlock") +class RewriteReductionBlock(Postproc): + """A postprocessor that rewrites reduction block by moving the init block out.""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteReductionBlock, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py new file mode 100644 index 0000000000..f4113e5173 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that adds thread binding to unbound blocks""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteUnboundBlock") +class RewriteUnboundBlock(Postproc): + """A postprocessor that adds thread binding to unbound blocks""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteUnboundBlock, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/verify_gpu_code.py b/python/tvm/meta_schedule/postproc/verify_gpu_code.py new file mode 100644 index 0000000000..501e442319 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/verify_gpu_code.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that verifies if the GPU code is correct""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.VerifyGPUCode") +class VerifyGPUCode(Postproc): + """A postprocessor that verifies if the GPU code is correct""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocVerifyGPUCode, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/runner/local_runner.py b/python/tvm/meta_schedule/runner/local_runner.py index caa266f97e..6af403905c 100644 --- a/python/tvm/meta_schedule/runner/local_runner.py +++ b/python/tvm/meta_schedule/runner/local_runner.py @@ -16,7 +16,9 @@ # under the License. 
"""Local Runner""" from contextlib import contextmanager +import logging from typing import Callable, List, Optional, Union + import tvm from ...contrib.popen_pool import PopenPoolExecutor @@ -25,12 +27,14 @@ from .config import EvaluatorConfig from .runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult from .utils import ( - T_ARG_INFO_JSON_OBJ_LIST, T_ARGUMENT_LIST, + T_ARG_INFO_JSON_OBJ_LIST, alloc_argument_common, run_evaluator_common, ) +logger = logging.getLogger(__name__) + class LocalRunnerFuture(RunnerFuture): """Local based runner future @@ -214,6 +218,7 @@ def __init__( self.f_run_evaluator = f_run_evaluator self.f_cleanup = f_cleanup + logger.info("LocalRunner: max_workers = 1") self.pool = PopenPoolExecutor( max_workers=1, # one local worker timeout=timeout_sec, diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py index 3ba1c1dccf..0e786b3cc5 100644 --- a/python/tvm/meta_schedule/runner/rpc_runner.py +++ b/python/tvm/meta_schedule/runner/rpc_runner.py @@ -17,6 +17,7 @@ """RPC Runner""" import concurrent.futures from contextlib import contextmanager +import logging import os.path as osp from typing import Callable, List, Optional, Union @@ -31,12 +32,14 @@ from .config import EvaluatorConfig, RPCConfig from .runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult from .utils import ( - T_ARG_INFO_JSON_OBJ_LIST, T_ARGUMENT_LIST, + T_ARG_INFO_JSON_OBJ_LIST, alloc_argument_common, run_evaluator_common, ) +logger = logging.getLogger(__name__) + class RPCRunnerFuture(RunnerFuture): """RPC based runner future @@ -275,6 +278,7 @@ def __init__( self.f_alloc_argument = f_alloc_argument self.f_run_evaluator = f_run_evaluator self.f_cleanup = f_cleanup + logger.info("RPCRunner: max_workers = %d", max_workers) self.pool = PopenPoolExecutor( max_workers=max_workers, timeout=rpc_config.session_timeout_sec, diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py new file mode 100644 index 0000000000..3fe3f0fb3b --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -0,0 +1,23 @@ +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.schedule_rule package. +Meta Schedule schedule rules are used for modification of +blocks in a schedule. See also PostOrderApply. 
+""" +from .auto_inline import AutoInline +from .multi_level_tiling import MultiLevelTiling, ReuseType +from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll +from .random_compute_location import RandomComputeLocation +from .schedule_rule import PyScheduleRule, ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/auto_inline.py b/python/tvm/meta_schedule/schedule_rule/auto_inline.py new file mode 100644 index 0000000000..83828586bf --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/auto_inline.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Auto-Inline. Rule that inlines spatial blocks if it satisfies some conditions""" +from typing import List, Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.AutoInline") +class AutoInline(ScheduleRule): + """Rule that inlines spatial blocks if it satisfies some conditions + + Parameters + ---------- + into_producer : bool + If allows to inline a block into its producer + into_consumer : bool + If allows to inline a block into its consumer + into_cache_only : bool + If it only allows to inline into a block generated by cache_read/write + inline_const_tensor : bool + Always inline constant tensors + disallow_if_then_else : bool + Always disallow if-then-else-like constructs + require_injective : bool + Always require the read-to-write mapping to be ordered + require_ordered : bool + Always require the read-to-write mapping to be injective + disallow_op : Optional[List[str]] + The operators that are disallowed in auto inline + """ + + def __init__( + self, + into_producer: bool, + into_consumer: bool, + into_cache_only: bool, + inline_const_tensor: bool, + disallow_if_then_else: bool, + require_injective: bool, + require_ordered: bool, + disallow_op: Optional[List[str]] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleAutoInline, # type: ignore # pylint: disable=no-member + into_producer, + into_consumer, + into_cache_only, + inline_const_tensor, + disallow_if_then_else, + require_injective, + require_ordered, + disallow_op, + ) diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py new file mode 100644 index 0000000000..669ede242e --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Multi-level tiling with reuse.""" +from typing import Any, Dict, List, Literal, NamedTuple, Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +class ReuseType(NamedTuple): + """Reuse type.""" + + req: Literal["no", "may", "must"] + levels: List[int] + scope: str + + def as_dict(self) -> Dict[str, Any]: + """Return the dict representation of the reuse type.""" + return { + "req": self.req, + "levels": self.levels, + "scope": self.scope, + } + + +@register_object("meta_schedule.MultiLevelTiling") +class MultiLevelTiling(ScheduleRule): + """Multi-level tiling with reuse. + + Parameters + ---------- + structure : str + The tiling structure. Recommended: + - 'SSRSRS' on CPU + - 'SSSRRSRS' on GPU + tile_binds : Optional[List[str]] + For each level of tiles, which thread axis it is bound to. Recommended: + - None on CPU + - [blockIdx.x, vthread.x, threadIdx.x] on GPU + max_innermost_factor : Optional[int] + The maximum size of the innermost factor. None means no limit + vector_load_max_len : Optional[int] + The length of the vector lane in vectorized cooperative fetching. + None means vectorization is disabled + reuse_read : Optional[ReuseType] + Data reuse configuration for reading. None means no reuse. + reuse_write : Optional[ReuseType] + Data reuse configuration for writing. None means no reuse. + """ + + def __init__( + self, + structure: str, + tile_binds: Optional[List[str]] = None, + max_innermost_factor: Optional[int] = None, + vector_load_max_len: Optional[int] = None, + reuse_read: Optional[ReuseType] = None, + reuse_write: Optional[ReuseType] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleMultiLevelTiling, # type: ignore # pylint: disable=no-member + structure, + tile_binds, + max_innermost_factor, + vector_load_max_len, + reuse_read.as_dict() if reuse_read is not None else None, + reuse_write.as_dict() if reuse_write is not None else None, + ) diff --git a/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py new file mode 100644 index 0000000000..36513022a9 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Rule that mark parallelize, vectorize and unroll to each block correspondingly""" +from typing import List, Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.ParallelizeVectorizeUnroll") +class ParallelizeVectorizeUnroll(ScheduleRule): + """Rule that mark parallelize, vectorize and unroll to each block correspondingly + + Parameters + ---------- + max_jobs_per_core: int + The maximum number of jobs to be launched per CPU core. It sets the uplimit of CPU + parallelism, i.e. `num_cores * max_jobs_per_core`. + Use -1 to disable parallelism. + max_vectorize_extent: int + The maximum extent to be vectorized. It sets the uplimit of the CPU vectorization. + Use -1 to disable vectorization. + unroll_max_steps: Optional[List[int]] + The maximum number of unroll steps to be done. + Use None to disable unroll + unroll_explicit: bool + Whether to explicitly unroll the loop, or just add a unroll pragma + """ + + def __init__( + self, + max_jobs_per_core: int = 16, + max_vectorize_extent: int = 16, + unroll_max_steps: Optional[List[int]] = None, + unroll_explicit: bool = True, + ) -> None: + if unroll_max_steps is None: + unroll_max_steps = [] + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleParallelizeVectorizeUnroll, # type: ignore # pylint: disable=no-member + max_jobs_per_core, + max_vectorize_extent, + unroll_max_steps, + unroll_explicit, + ) diff --git a/python/tvm/meta_schedule/schedule_rule/random_compute_location.py b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py new file mode 100644 index 0000000000..2355b0bfa8 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Rule that randomly select a compute-at location for a free block""" +from tvm._ffi import register_object + +from .. 
import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.RandomComputeLocation") +class RandomComputeLocation(ScheduleRule): + """A rule that randomly select a compute-at location for a free block""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleRandomComputeLocation, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py new file mode 100644 index 0000000000..b995e5acb6 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Meta Schedule schedule rules are used for modification of +blocks in a schedule. See also PostOrderApply. +""" +from typing import TYPE_CHECKING, List + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Schedule, BlockRV + +from ..utils import _get_hex_address, check_override +from .. import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.ScheduleRule") +class ScheduleRule(Object): + """Rules to modify a block in a schedule.""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the schedule rule with a tune context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the schedule rule. + """ + _ffi_api.ScheduleRuleInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + """Apply a schedule rule to the specific block in the given schedule. + + Parameters + ---------- + sch : Schedule + The schedule to be modified. + block : BlockRV + The specific block to apply the schedule rule. + + Returns + ------- + design_spaces : List[Schedule] + The list of schedules generated by applying the schedule rule. 
+ """ + return _ffi_api.ScheduleRuleApply( # type: ignore # pylint: disable=no-member + self, sch, block + ) + + +@register_object("meta_schedule.PyScheduleRule") +class PyScheduleRule(ScheduleRule): + """An abstract schedule rule with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, ScheduleRule) + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + @check_override(self.__class__, ScheduleRule) + def f_apply(sch: Schedule, block: BlockRV) -> List[Schedule]: + return self.apply(sch, block) + + def f_as_string() -> str: + return self.__str__() + + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRulePyScheduleRule, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py index 609baa2677..6102ebc41a 100644 --- a/python/tvm/meta_schedule/search_strategy/__init__.py +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -20,5 +20,7 @@ to generate measure candidates. """ -from .search_strategy import SearchStrategy, PySearchStrategy -from .replay_trace import ReplayTrace +from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate +from .replay_trace import ReplayTrace, ReplayTraceConfig +from .replay_func import ReplayFunc, ReplayFuncConfig +from .evolutionary_search import EvolutionarySearch diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py new file mode 100644 index 0000000000..67bb2dd85a --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Evolutionary Search Strategy""" + +from typing import TYPE_CHECKING, Dict + +from tvm._ffi import register_object + +from .search_strategy import SearchStrategy +from ..mutator import Mutator +from ..database import Database + +from .. import _ffi_api + +if TYPE_CHECKING: + from ..cost_model import CostModel + + +@register_object("meta_schedule.EvolutionarySearch") +class EvolutionarySearch(SearchStrategy): + """ + Replay Trace Search Strategy is a search strategy that always replays the trace by removing its + decisions so that the decisions would be randomly re-generated. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. 
+ population : int + The initial population of traces from measured samples and randomly generated samples. + max_replay_fail_cnt : int + The maximum number of failures allowed when replaying traces. + init_measured_ratio : float + The ratio of measured samples in the initial population. + genetic_algo_iters : int + The number of iterations of the genetic algorithm. + max_evolve_fail_cnt : int + The maximum number of retries when mutation fails. + p_mutate : float + The probability of mutation. + eps_greedy : float + The ratio of greedily selected samples in the final picks. + database : Database + The database used in the search. + cost_model : CostModel + The cost model used in the search. + """ + + num_trials_per_iter: int + num_trials_total: int + population: int + init_measured_ratio: float + genetic_algo_iters: int + max_replay_fail_cnt: int + max_evolve_fail_cnt: int + p_mutate: float + eps_greedy: float + database: Database + cost_model: "CostModel" + + def __init__( + self, + *, + num_trials_per_iter: int, + num_trials_total: int, + database: Database, + cost_model: "CostModel", + population: int = 2048, + max_replay_fail_cnt: int = 64, + init_measured_ratio: float = 0.2, + genetic_algo_iters: int = 10, + max_evolve_fail_cnt: int = 10, + p_mutate: float = 0.85, + eps_greedy: float = 0.25, + ): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + population, + max_replay_fail_cnt, + init_measured_ratio, + genetic_algo_iters, + max_evolve_fail_cnt, + p_mutate, + eps_greedy, + database, + cost_model, + ) diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py new file mode 100644 index 0000000000..34eadc7a3f --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/replay_func.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Replay Func Search Strategy""" +from typing import NamedTuple + +from tvm._ffi import register_object + +from .. import _ffi_api +from .search_strategy import SearchStrategy + + +@register_object("meta_schedule.ReplayFunc") +class ReplayFunc(SearchStrategy): + """ + Replay Func Search Strategy is a search strategy that generates measure candidates by + calling a design space generator and transforming the design space. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials.
+ """ + + num_trials_per_iter: int + num_trials_total: int + + def __init__( + self, + num_trials_per_iter: int, + num_trials_total: int, + ): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyReplayFunc, # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + ) + + +class ReplayFuncConfig(NamedTuple): + """Configuration for ReplayFunc""" + + num_trials_per_iter: int + num_trials_total: int + + def create_strategy(self) -> ReplayFunc: + return ReplayFunc(self.num_trials_per_iter, self.num_trials_total) diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index 15f8295f25..f550135460 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Replay Trace Search Strategy""" +from typing import NamedTuple from tvm._ffi import register_object from .search_strategy import SearchStrategy @@ -41,7 +42,17 @@ class ReplayTrace(SearchStrategy): def __init__(self, num_trials_per_iter: int, num_trials_total: int): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.ReplayTrace, # type: ignore # pylint: disable=no-member + _ffi_api.SearchStrategyReplayTrace, # pylint: disable=no-member num_trials_per_iter, num_trials_total, ) + + +class ReplayTraceConfig(NamedTuple): + """Configuration for ReplayTrace""" + + num_trials_per_iter: int + num_trials_total: int + + def create_strategy(self) -> ReplayTrace: + return ReplayTrace(self.num_trials_per_iter, self.num_trials_total) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 6cee09edd4..25a03aaf87 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -48,7 +48,11 @@ class MeasureCandidate(Object): sch: Schedule args_info: List[ArgInfo] - def __init__(self, sch: Schedule, args_info: List[ArgInfo]) -> None: + def __init__( + self, + sch: Schedule, + args_info: List[ArgInfo], + ) -> None: """Constructor. Parameters @@ -72,10 +76,7 @@ class SearchStrategy(Object): before usage and post-tuned after usage. """ - def initialize_with_tune_context( - self, - tune_context: "TuneContext", - ) -> None: + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: """Initialize the search strategy with tuning context. Parameters diff --git a/python/tvm/meta_schedule/space_generator/__init__.py b/python/tvm/meta_schedule/space_generator/__init__.py index af759d43b3..fc08cd491d 100644 --- a/python/tvm/meta_schedule/space_generator/__init__.py +++ b/python/tvm/meta_schedule/space_generator/__init__.py @@ -19,7 +19,7 @@ Meta Schedule design space generators that generates design space for generation of measure candidates. 
""" - from .space_generator import SpaceGenerator, PySpaceGenerator from .space_generator_union import SpaceGeneratorUnion from .schedule_fn import ScheduleFn +from .post_order_apply import PostOrderApply diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py new file mode 100644 index 0000000000..a9b2d56031 --- /dev/null +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Post Order Apply Space Generator.""" + + +from tvm._ffi import register_object +from .space_generator import SpaceGenerator +from .. import _ffi_api + + +@register_object("meta_schedule.PostOrderApply") +class PostOrderApply(SpaceGenerator): + """ + PostOrderApply is the design space generator that generates design spaces by applying schedule + rules to blocks in post-DFS order. + """ + + def __init__(self): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SpaceGeneratorPostOrderApply, # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index e37fd14ba4..2172613ce1 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -36,10 +36,7 @@ class SpaceGenerator(Object): """The abstract design space generator interface.""" - def initialize_with_tune_context( - self, - tune_context: "TuneContext", - ) -> None: + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: """Initialize the design space generator with tuning context. Parameters diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index 391011b4f5..a63d9a3f21 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -16,13 +16,15 @@ # under the License. """Round Robin Task Scheduler""" -from typing import List, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING from tvm._ffi import register_object +from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from ..builder import Builder from ..runner import Runner from ..database import Database +from ..cost_model import CostModel from .task_scheduler import TaskScheduler from .. import _ffi_api @@ -33,7 +35,21 @@ @register_object("meta_schedule.RoundRobin") class RoundRobin(TaskScheduler): - """Round Robin Task Scheduler""" + """Round Robin Task Scheduler + + Parameters + ---------- + tasks: List[TuneContext] + The list of tune context to process. + builder: Builder + The builder of the scheduler. 
+ runner: Runner + The runner of the scheduler. + database: Database + The database of the scheduler. + measure_callbacks: Optional[List[MeasureCallback]] = None + The list of measure callbacks of the scheduler. + """ def __init__( self, @@ -41,6 +57,8 @@ def __init__( builder: Builder, runner: Runner, database: Database, + cost_model: Optional[CostModel] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, ) -> None: """Constructor. @@ -54,6 +72,8 @@ def __init__( The runner. database : Database The database. + measure_callbacks: Optional[List[MeasureCallback]] + The list of measure callbacks of the scheduler. """ self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member @@ -61,4 +81,6 @@ def __init__( builder, runner, database, + cost_model, + measure_callbacks, ) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index aeea154cfe..dd8e3fe89b 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -16,14 +16,16 @@ # under the License. """Auto-tuning Task Scheduler""" -from typing import List +from typing import List, Optional from tvm._ffi import register_object +from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from tvm.runtime import Object from ..runner import Runner from ..builder import Builder from ..database import Database +from ..cost_model import CostModel from ..tune_context import TuneContext from .. import _ffi_api from ..utils import check_override @@ -43,12 +45,16 @@ class TaskScheduler(Object): The runner of the scheduler. database: Database The database of the scheduler. + measure_callbacks: List[MeasureCallback] = None + The list of measure callbacks of the scheduler. """ tasks: List[TuneContext] builder: Builder runner: Runner database: Database + cost_model: Optional[CostModel] + measure_callbacks: List[MeasureCallback] def tune(self) -> None: """Auto-tuning.""" @@ -59,7 +65,7 @@ def next_task_id(self) -> int: Returns ------- - int + next_task_id : int The next task id. """ return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member @@ -94,7 +100,7 @@ def _is_task_running(self, task_id: int) -> bool: Returns ------- - bool + running : bool Whether the task is running. """ return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # type: ignore # pylint: disable=no-member @@ -120,6 +126,8 @@ def __init__( builder: Builder, runner: Runner, database: Database, + cost_model: Optional[CostModel] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, ): """Constructor. @@ -133,6 +141,10 @@ def __init__( The runner of the scheduler. database: Database The database of the scheduler. + cost_model: Optional[CostModel] + The cost model of the scheduler. + measure_callbacks: List[MeasureCallback] + The list of measure callbacks of the scheduler. 
""" @check_override(self.__class__, TaskScheduler, required=False) @@ -173,6 +185,8 @@ def f_join_running_task(task_id: int) -> None: builder, runner, database, + cost_model, + measure_callbacks, f_tune, f_initialize_task, f_set_task_stopped, diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py index 7e516a510f..b64891a385 100644 --- a/python/tvm/meta_schedule/testing/__init__.py +++ b/python/tvm/meta_schedule/testing/__init__.py @@ -15,5 +15,8 @@ # specific language governing permissions and limitations # under the License. """Testing utilities in meta schedule""" +from . import te_workload +from . import schedule_rule from .local_rpc import LocalRPC -from .relay_workload import get_network +from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model +from .te_workload import create_te_workload diff --git a/python/tvm/meta_schedule/testing/byoc_trt.py b/python/tvm/meta_schedule/testing/byoc_trt.py new file mode 100644 index 0000000000..01d4a1fe3c --- /dev/null +++ b/python/tvm/meta_schedule/testing/byoc_trt.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TensorRT-MetaSchedule integration""" +# pylint: disable=import-outside-toplevel + +from typing import Dict, List, TYPE_CHECKING + +if TYPE_CHECKING: + from tvm.ir import IRModule + from tvm.target import Target + from tvm.runtime import NDArray, Module, Device + from tvm.meta_schedule.runner import EvaluatorConfig + + +def build_relay( + mod: "IRModule", + target: "Target", + params: Dict[str, "NDArray"], +) -> "Module": + """Build a Relay IRModule + + Parameters + ---------- + mod : IRModule + The Relay IRModule to build. + target : Target + The target to build the module for. + params : Dict[str, NDArray] + The parameter dict to build the module with. + + Returns + ------- + mod : runtime.Module + The built module. + """ + from tvm.relay.build_module import _build_module_no_factory as relay_build + from tvm.runtime import Module + + result = relay_build(mod, target=target, target_host=None, params=params) + assert isinstance(result, Module) + return result + + +def build_relay_with_tensorrt( + mod: "IRModule", + target: "Target", + params: Dict[str, "NDArray"], +) -> "Module": + """Build a Relay IRModule with TensorRT BYOC + + Parameters + ---------- + mod : IRModule + The Relay IRModule to build. + + target : Target + The target to build the module for. + + params : Dict[str, NDArray] + The parameter dict to build the module with. + + Returns + ------- + mod : runtime.Module + The built module. 
+ """ + from tvm.ir.transform import PassContext + from tvm.relay.op.contrib import tensorrt + from tvm.relay.build_module import _build_module_no_factory as relay_build + from tvm.runtime import Module + + mod, config = tensorrt.partition_for_tensorrt(mod, params) + with PassContext( + opt_level=3, + config={"relay.ext.tensorrt.options": config}, + ): + result = relay_build(mod, target=target, target_host=None, params=params) + assert isinstance(result, Module) + return result + + +def run_with_graph_executor( + rt_mod: "Module", + device: "Device", + evaluator_config: "EvaluatorConfig", + repeated_args: List["NDArray"], +) -> List[float]: + """Run a Relay module with GraphExecutor + + Parameters + ---------- + rt_mod : Module + The Relay module to run. + device : Device + The device to run the module on. + evaluator_config : EvaluatorConfig + The evaluator configuration to run the module with. + repeated_args : List[NDArray] + The list of repeated arguments to run the module with. + + Returns + ------- + results : List[float] + The list of results. + """ + import itertools + from tvm.contrib.graph_executor import GraphModule + + graph_mod = GraphModule(rt_mod["default"](device)) + evaluator = graph_mod.module.time_evaluator( + func_name="run", + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs = [] + for args in repeated_args: + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + return costs diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py index 1eb9950f7f..ec6f65aff0 100644 --- a/python/tvm/meta_schedule/testing/relay_workload.py +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -15,13 +15,178 @@ # specific language governing permissions and limitations # under the License. 
"""Workloads in Relay IR""" +from enum import Enum from typing import Dict, Tuple -import tvm.relay.testing # pylint: disable=unused-import from tvm import relay from tvm.ir import IRModule from tvm.runtime import NDArray +# Model types supported in Torchvision +class MODEL_TYPE(Enum): # pylint: disable=invalid-name + IMAGE_CLASSIFICATION = (1,) + VIDEO_CLASSIFICATION = (2,) + SEGMENTATION = (3,) + OBJECT_DETECTION = (4,) + + +# Specify the type of each model +MODEL_TYPES = { + # Image classification models + "resnet50": MODEL_TYPE.IMAGE_CLASSIFICATION, + "alexnet": MODEL_TYPE.IMAGE_CLASSIFICATION, + "vgg16": MODEL_TYPE.IMAGE_CLASSIFICATION, + "squeezenet1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet121": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet161": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet169": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet201": MODEL_TYPE.IMAGE_CLASSIFICATION, + "inception_v3": MODEL_TYPE.IMAGE_CLASSIFICATION, + "googlenet": MODEL_TYPE.IMAGE_CLASSIFICATION, + "shufflenet_v2_x1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v3_large": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v3_small": MODEL_TYPE.IMAGE_CLASSIFICATION, + "resnext50_32x4d": MODEL_TYPE.IMAGE_CLASSIFICATION, + "wide_resnet50_2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mnasnet1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b1": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b3": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b4": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b5": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b6": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b7": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_400mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_800mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_1_6gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_3_2gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_8gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_16gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_32gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_400mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_800mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_1_6gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_3_2gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_8gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_16gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_32gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + # Semantic Segmentation models + "fcn_resnet50": MODEL_TYPE.SEGMENTATION, + "fcn_resnet101": MODEL_TYPE.SEGMENTATION, + "deeplabv3_resnet50": MODEL_TYPE.SEGMENTATION, + "deeplabv3_resnet101": MODEL_TYPE.SEGMENTATION, + "deeplabv3_mobilenet_v3_large": MODEL_TYPE.SEGMENTATION, + "lraspp_mobilenet_v3_large": MODEL_TYPE.SEGMENTATION, + # Object detection models + # @Sung: Following networks are not runnable since Torch frontend cannot handle aten::remainder. 
+ # "retinanet_resnet50_fpn", "keypointrcnn_resnet50_fpn", + "fasterrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "fasterrcnn_mobilenet_v3_large_fpn": MODEL_TYPE.OBJECT_DETECTION, + "fasterrcnn_mobilenet_v3_large_320_fpn": MODEL_TYPE.OBJECT_DETECTION, + "retinanet_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "maskrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "keypointrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "ssd300_vgg16": MODEL_TYPE.OBJECT_DETECTION, + "ssdlite320_mobilenet_v3_large": MODEL_TYPE.OBJECT_DETECTION, + # Video classification + "r3d_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + "mc3_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + "r2plus1d_18": MODEL_TYPE.VIDEO_CLASSIFICATION, +} + + +def get_torch_model( + model_name: str, + input_shape: Tuple[int, ...], + output_shape: Tuple[int, int], # pylint: disable=unused-argument + dtype: str = "float32", +) -> Tuple[IRModule, Dict[str, NDArray]]: + """Load model from torch model zoo + Parameters + ---------- + model_name : str + The name of the model to load + input_shape: Tuple[int, ...] + Tuple for input shape + output_shape: Tuple[int, int] + Tuple for output shape + dtype: str + Tensor data type + """ + + assert dtype == "float32" + + import torch # type: ignore # pylint: disable=import-error,import-outside-toplevel + from torchvision import models # type: ignore # pylint: disable=import-error,import-outside-toplevel + + def do_trace(model, inp): + model_trace = torch.jit.trace(model, inp) + model_trace.eval() + return model_trace + + # Load model from torchvision + if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: + model = getattr(models, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + model = getattr(models.segmentation, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + model = getattr(models.detection, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION: + model = getattr(models.video, model_name)() + else: + raise ValueError("Unsupported model in Torch model zoo.") + + # Setup input + input_data = torch.randn(input_shape).type(torch.float32) + shape_list = [("input0", input_shape)] + + # Get trace. Depending on the model type, wrapper may be necessary. 
+ if MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + + class TraceWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return out["out"] + + wrapped_model = TraceWrapper(model) + wrapped_model.eval() + with torch.no_grad(): + scripted_model = do_trace(wrapped_model, input_data) + + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + + def dict_to_tuple(out_dict): + if "masks" in out_dict.keys(): + return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"] + return out_dict["boxes"], out_dict["scores"], out_dict["labels"] + + class TraceWrapper(torch.nn.Module): # type: ignore + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return dict_to_tuple(out[0]) + + wrapped_model = TraceWrapper(model) + wrapped_model.eval() + with torch.no_grad(): + _ = wrapped_model(input_data) + scripted_model = do_trace(wrapped_model, input_data) + else: + scripted_model = do_trace(model, input_data) + + # Convert torch model to relay module + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + return mod, params + def get_network( name: str, @@ -30,6 +195,8 @@ def get_network( dtype: str = "float32", ) -> Tuple[IRModule, Dict[str, NDArray], Tuple[int, int, int, int], Tuple[int, int]]: """Get the symbol definition and random weight of a network""" + import tvm.relay.testing # pylint: disable=import-outside-toplevel,unused-import + # meta-schedule prefers NHWC layout if layout == "NHWC": image_shape = (224, 224, 3) diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py new file mode 100644 index 0000000000..03973488ac --- /dev/null +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
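A quick sanity-check call of `get_torch_model` above might look like the following sketch (torch and torchvision must be installed; the shapes are the usual torchvision defaults and are only illustrative):

    from tvm.meta_schedule.testing.relay_workload import get_torch_model

    mod, params = get_torch_model(
        "resnet50",
        input_shape=(1, 3, 224, 224),
        output_shape=(1, 1000),   # currently unused by the helper
        dtype="float32",
    )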
+"""Default schedule rules""" +from typing import List + +from tvm.meta_schedule.schedule_rule import ( + AutoInline, + MultiLevelTiling, + ParallelizeVectorizeUnroll, + RandomComputeLocation, + ReuseType, + ScheduleRule, +) +from tvm.target import Target + + +def get(target: Target) -> List[ScheduleRule]: + """Default schedule rules""" + if target.kind.name == "llvm": + return [ + auto_inline(target), + multi_level_tiling(target), + parallel_vectorize_unroll(target), + ] + if target.kind.name == "cuda": + return [ + auto_inline(target), + multi_level_tiling(target), + auto_inline_after_tiling(target), + parallel_vectorize_unroll(target), + ] + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def auto_inline(target: Target) -> ScheduleRule: + """Default schedule rules for auto inline""" + if target.kind.name == "llvm": + return AutoInline( + into_producer=False, + into_consumer=True, + into_cache_only=False, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ) + if target.kind.name == "cuda": + return AutoInline( + into_producer=False, + into_consumer=True, + into_cache_only=False, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def auto_inline_after_tiling(target: Target) -> ScheduleRule: + """Default schedule rules for auto inline after tiling""" + if target.kind.name == "llvm": + return AutoInline( + into_producer=True, + into_consumer=True, + into_cache_only=True, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ) + if target.kind.name == "cuda": + return AutoInline( + into_producer=True, + into_consumer=True, + into_cache_only=True, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def multi_level_tiling(target: Target) -> ScheduleRule: + """Default schedule rules for with multi-level tiling and reuse""" + if target.kind.name == "llvm": + return MultiLevelTiling( + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_max_len=None, + reuse_read=None, + reuse_write=ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ) + if target.kind.name == "cuda": + return MultiLevelTiling( + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], + max_innermost_factor=64, + vector_load_max_len=4, + reuse_read=ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=ReuseType( + req="must", + levels=[3], + scope="local", + ), + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def parallel_vectorize_unroll(target: Target) -> ScheduleRule: + """Default schedule rules for with parallel-vectorize-unroll""" + if target.kind.name == "llvm": + return ParallelizeVectorizeUnroll( + max_jobs_per_core=16, + max_vectorize_extent=32, + unroll_max_steps=[0, 16, 64, 512], + unroll_explicit=True, + ) + if target.kind.name == "cuda": + return ParallelizeVectorizeUnroll( + max_jobs_per_core=-1, + max_vectorize_extent=-1, + unroll_max_steps=[0, 16, 64, 512, 1024], + unroll_explicit=True, + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def random_compute_location(target: Target) 
-> ScheduleRule: + """Default schedule rules for with random-compute-location""" + if target.kind.name == "llvm": + return RandomComputeLocation() + raise NotImplementedError(f"{target.kind.name} is not supported") diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py new file mode 100644 index 0000000000..4abf090ddf --- /dev/null +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List, Union + +from tvm.ir import IRModule +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.space_generator import PostOrderApply +from tvm.target import Target +from tvm.tir import PrimFunc, Schedule +from tvm.tir.schedule import Trace + +from . import schedule_rule as sch_rule + + +def create_context(mod: Union[IRModule, PrimFunc], target: Target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=sch_rule.get(target), + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for rule in ctx.sch_rules: + rule.initialize_with_tune_context(ctx) + return ctx + + +def check_trace(spaces: List[Schedule], expected: List[List[str]]): + expected_traces = {"\n".join(t) for t in expected} + actual_traces = set() + for space in spaces: + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + str_trace = "\n".join(str(trace).strip().splitlines()) + actual_traces.add(str_trace) + assert str_trace in expected_traces, "\n" + str_trace + assert len(expected_traces) == len(actual_traces) + + +def debug_print_spaces(spaces: List[Schedule], trace_as_list: bool) -> None: + for i, space in enumerate(spaces): + print(f"##### Space {i}") + print(space.mod.script()) + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + if trace_as_list: + print(str(trace).strip().splitlines()) + else: + print(trace) diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py new file mode 100644 index 0000000000..d57bea86e4 --- /dev/null +++ b/python/tvm/meta_schedule/testing/te_workload.py @@ -0,0 +1,831 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Workloads in TE""" +# pylint: disable=missing-docstring +from typing import Tuple + +from tvm import te, tir, topi + + +def batch_matmul_nkkm( # pylint: disable=invalid-name,missing-docstring + B: int, + N: int, + M: int, + K: int, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + x = te.placeholder((B, N, K), name="X") + y = te.placeholder((B, K, M), name="Y") + k = te.reduce_axis((0, K), name="k") + z = te.compute( # pylint: disable=invalid-name + (B, N, M), + lambda b, i, j: te.sum(x[b][i][k] * y[b][k][j], axis=[k]), + name="Z", + ) + return (x, y, z) + + +def conv1d_nlc( # pylint: disable=invalid-name,missing-docstring + N: int, + L: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, L, CI), name="inputs") + weight = te.placeholder((kernel_size, CI // groups, CO), name="weight") + + batch_size, in_len, _ = inputs.shape + k_len, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1 + rc = te.reduce_axis((0, channel_per_group), name="rc") + rl = te.reduce_axis((0, k_len), name="rl") + + padded = topi.nn.pad(inputs, [0, padding, 0]) + output = te.compute( + (batch_size, out_len, out_channel), + lambda n, l, co: te.sum( + ( + padded[ + n, + l * stride + rl * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rl, rc, co] + ), + axis=[rl, rc], + ), + name="conv1d_nlc", + ) + return (inputs, weight, output) + + +def conv2d_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, CI), name="inputs") + weight = te.placeholder((kernel_size, kernel_size, CI // groups, CO), name="weight") + batch_size, in_h, in_w, _ = inputs.shape + k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + rc = te.reduce_axis((0, channel_per_group), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, co: te.sum( + ( + padded[ + n, + h * stride + rh * dilation, + w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rh, rw, rc, co] + ), + axis=[rh, rw, rc], + ), + name="conv2d_nhwc", + ) + return (inputs, weight, output) + + +def conv3d_ndhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + D: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> 
Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, D, H, W, CI)) + weight = te.placeholder((kernel_size, kernel_size, kernel_size, CI // groups, CO)) + batch_size, in_d, in_h, in_w, _ = inputs.shape + k_d, k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_d = (in_d + 2 * padding - dilation * (k_d - 1) - 1) // stride + 1 + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rd = te.reduce_axis((0, k_d), name="rd") + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + rc = te.reduce_axis((0, channel_per_group), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, padding, 0]) + output = te.compute( + (batch_size, out_d, out_h, out_w, out_channel), + lambda n, d, h, w, co: te.sum( + ( + padded[ + n, + d * stride + rd * dilation, + h * stride + rh * dilation, + w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rd, rh, rw, rc, co] + ), + axis=[rd, rh, rw, rc], + ), + name="conv3d_ndhwc", + ) + return (inputs, weight, output) + + +def depthwise_conv2d_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + C: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + factor: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, C)) + weight = te.placeholder((factor, kernel_size, kernel_size, C)) + batch_size, in_h, in_w, in_channel = inputs.shape + factor, k_h, k_w, in_channel = weight.shape + out_channel = in_channel * factor + assert int(factor) == 1, "Not optimized for factor != 1" + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, c: te.sum( + ( + padded[ + n, + h * stride + rh * dilation, + w * stride + rw * dilation, + c // factor, + ] + * weight[c % factor, rh, rw, c // factor] + ), + axis=[rh, rw], + ), + name="depth_conv2d_nhwc", + ) + return (inputs, weight, output) + + +def conv2d_transpose_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, CI), name="inputs") + weight = te.placeholder((kernel_size, kernel_size, CI, CO), name="weight") + + batch, in_h, in_w, in_c = inputs.shape + filter_h, filter_w, in_c, out_c = weight.shape + stride_h, stride_w = (stride, stride) + + # compute padding + fpad_top, fpad_left, fpad_bottom, fpad_right = topi.nn.get_pad_tuple( + padding, (filter_h, filter_w) + ) + bpad_top = filter_h - 1 - fpad_top + bpad_bottom = filter_h - 1 - fpad_bottom + bpad_left = filter_w - 1 - fpad_left + bpad_right = filter_w - 1 - fpad_right + + # padding stage + padded = topi.nn.pad( + inputs, + [ + 0, + (bpad_top + stride_h - 1) // stride_h, + (bpad_left + stride_w - 1) // stride_w, + 0, + ], + [ + 0, + (bpad_bottom + stride_h - 1) // stride_h, + (bpad_right + stride_w - 1) // stride_w, + 0, + ], + ) + + # remove extra padding introduced by dilatation + idx_div = te.indexdiv + idx_mod = te.indexmod + border_h 
= idx_mod(stride_h - idx_mod(bpad_top, stride_h), stride_h) + border_w = idx_mod(stride_w - idx_mod(bpad_left, stride_w), stride_w) + + # dilation stage + strides = [1, stride_h, stride_w, 1] + n = len(padded.shape) + + # We should embed this dilation directly into te.compute rather than creating a new te.compute. + # Only in this way can we use unroll to eliminate the multiplication of zeros. + def _dilate(*indices): + not_zero = [] + index_tuple = [] + for i in range(n): + if not strides[i] == 1: + index_tuple.append(idx_div(indices[i], strides[i])) + not_zero.append(idx_mod(indices[i], strides[i]).equal(0)) + else: + index_tuple.append(indices[i]) + if not_zero: + not_zero = te.all(*not_zero) + return te.if_then_else(not_zero, padded(*index_tuple), tir.const(0.0, padded.dtype)) + return padded(*index_tuple) + + # convolution stage + out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + rc = te.reduce_axis((0, in_c), name="rc") + rh = te.reduce_axis((0, filter_h), name="rh") + rw = te.reduce_axis((0, filter_w), name="rw") + + output = te.compute( + (batch, out_h, out_w, out_c), + lambda n, h, w, co: te.sum( + _dilate(n, h + rh + border_h, w + rw + border_w, rc) + * weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co], + axis=[rh, rw, rc], + ), + name="conv2d_transpose_nhwc", + ) + return (inputs, weight, output) + + +def conv2d_capsule_nhwijc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + capsule_size: int = 4, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name="inputs") + weight = te.placeholder( + (kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name="weight" + ) + batch_size, in_h, in_w, _, _, in_channel = inputs.shape + k_h, k_w, _, _, _, out_channel = weight.shape + + out_h = (in_h + 2 * padding - kernel_size) // stride + 1 + out_w = (in_w + 2 * padding - kernel_size) // stride + 1 + + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + cap_k = te.reduce_axis((0, capsule_size), name="cap_k") + rc = te.reduce_axis((0, in_channel), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0, 0, 0]) + output = te.compute( + (batch_size, out_h, out_w, capsule_size, capsule_size, out_channel), + lambda n, h, w, cap_i, cap_j, co: te.sum( + ( + padded[n, h * stride + rh, w * stride + rw, cap_i, cap_k, rc] + * weight[rh, rw, cap_k, cap_j, rc, co] + ), + axis=[rh, rw, cap_k, rc], + ), + name="conv2d_capsule_nhwijc", + ) + return (inputs, weight, output) + + +def norm_bmn( # pylint: disable=invalid-name,missing-docstring + B: int, + M: int, + N: int, +) -> Tuple[te.Tensor, te.Tensor]: + a = te.placeholder((B, M, N), name="A") + i = te.reduce_axis((0, M), name="i") + j = te.reduce_axis((0, N), name="j") + c = te.compute( + (B,), + lambda b: te.sum(a[b][i][j] * a[b][i][j], axis=[i, j]), + name="C", + ) + d = te.compute((B,), lambda b: te.sqrt(c[b]), name="D") + return (a, d) + + +def conv2d_nhwc_without_layout_rewrite( # pylint: disable=invalid-name + Input: int, + Filter: int, + stride: int, + padding: int, + dilation: int, + out_dtype="float32", +): + """A copy of `topi.nn.conv2d_nhwc` but without the 'layout_free` attribute. + We use this in single op and subgraph evaluation + because we don't want to introduce graph level optimization. 
+ """ + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_height, in_width, in_channel = Input.shape # type: ignore + kernel_h, kernel_w, _channel, num_filter = Filter.shape # type: ignore + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = topi.nn.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + out_channel = num_filter + out_height = topi.utils.simplify( + (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1 + ) + out_width = topi.utils.simplify( + (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1 + ) + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + PaddedInput = topi.nn.pad(Input, pad_before, pad_after, name="PaddedInput") + rc = te.reduce_axis((0, in_channel), name="rc") + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + Output = te.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: te.sum( + PaddedInput[ + nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc + ].astype(out_dtype) + * Filter[ry, rx, rc, ff].astype(out_dtype), # type: ignore + axis=[ry, rx, rc], + ), + name="Conv2dOutput", + tag="conv2d_nhwc", + ) + return Output + + +def conv2d_nhwc_bn_relu( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + strides: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]: + data = te.placeholder((N, H, W, CI), name="data") + kernel = te.placeholder((kernel_size, kernel_size, CI, CO), name="kernel") + bias = te.placeholder((CO,), name="bias") + bn_scale = te.placeholder((CO,), name="bn_scale") + bn_offset = te.placeholder((CO,), name="bn_offset") + OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] + bias[l], name="bias_add" + ) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] * bn_scale[l], name="bn_mul" + ) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] + bn_offset[l], name="bn_add" + ) + out = topi.nn.relu(conv) + return (data, kernel, bias, bn_offset, bn_scale, out) + + +def transpose_batch_matmul( # pylint: disable=invalid-name,missing-docstring + batch: int, + seq_len: int, + n_head: int, + n_dim: int, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + query = te.placeholder((batch, seq_len, n_head, n_dim), name="query") + value = te.placeholder((batch, seq_len, n_head, n_dim), name="value") + query_T = te.compute( + (batch, n_head, seq_len, n_dim), + lambda b, h, l, d: query[b, l, h, d], + name="query_T", + ) + value_T = te.compute( + (batch, n_head, n_dim, seq_len), + lambda b, h, d, l: value[b, l, h, d], + name="value_T", + ) + k = te.reduce_axis((0, n_dim), name="k") + out = te.compute( + (batch, n_head, seq_len, seq_len), + lambda b, h, i, j: 
te.sum(query_T[b, h, i, k] * value_T[b, h, k, j], axis=[k]), + name="C", + ) + return (query, value, out) + + +def conv2d_winograd_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + tile_size = 4 # _infer_tile_size(data, kernel) + inputs = te.placeholder((N, H, W, CI), name="inputs") + N, H, W, CI = topi.utils.get_const_tuple(inputs.shape) + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation" + + KH = KW = kernel_size + HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW)) + HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride + assert HSTR == 1 and WSTR == 1 and KH == KW + + data_pad = topi.nn.pad(inputs, (0, HPAD, WPAD, 0), (0, HPAD, WPAD, 0), name="data_pad") + + r = KW + m = tile_size + alpha = m + r - 1 + A, B, _G = topi.nn.winograd_util.winograd_transform_matrices(m, r, "float32") + + H = (H + 2 * HPAD - KH) // HSTR + 1 + W = (W + 2 * WPAD - KW) // WSTR + 1 + nH, nW = (H + m - 1) // m, (W + m - 1) // m + P = N * nH * nW + _rkh = te.reduce_axis((0, KH), name="r_kh") + _rkw = te.reduce_axis((0, KW), name="r_kw") + kshape = (alpha, alpha, CI, CO) + kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight") + + idxdiv = te.indexdiv + idxmod = te.indexmod + # pack input tile + input_tile = te.compute( + (alpha, alpha, P, CI), + lambda eps, nu, p, ci: data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps][ + idxmod(p, nW) * m + nu + ][ci], + name="input_tile", + ) + + # transform data + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_b") + data_pack = te.compute( + (alpha, alpha, P, CI), + lambda eps, nu, p, ci: te.sum( + input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] + ), + name="data_pack", + attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"]}, + ) + + # do batch gemm + ci = te.reduce_axis((0, CI), name="ci") + bgemm = te.compute( + (alpha, alpha, P, CO), + lambda eps, nu, p, co: te.sum( + data_pack[eps][nu][p][ci] * kernel_pack[eps][nu][ci][co], axis=[ci] + ), + name="bgemm", + ) + + # inverse transform + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_b") + inverse = te.compute( + (m, m, P, CO), + lambda vh, vw, p, co: te.sum( + bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] + ), + name="inverse", + attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"]}, + ) + + # output + output = te.compute( + (N, H, W, CO), + lambda n, h, w, co: inverse[ + idxmod(h, m), idxmod(w, m), n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), co + ], + name="conv2d_winograd", + ) + + return (inputs, kernel_pack, output) + + +def matmul(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A") + b = te.placeholder((k, m), name="B") + k = te.reduce_axis((0, k), name="k") + c = te.compute( + (n, m), + lambda i, j: te.sum(a[i, k] * b[k, j], axis=[k]), + name="C", + ) + return (a, b, c) + + +def matmul_relu(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A") + b = te.placeholder((m, k), name="B") + k = te.reduce_axis((0, k), name="k") + c = te.compute( + (n, m), + lambda i, j: te.sum(a[i, k] * b[k, j], 
axis=[k]), + name="C", + ) + d = topi.nn.relu(c) # pylint: disable=invalid-name + return (a, b, d) + + +def conv2d_nchw( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + co: int, + kh: int, + kw: int, + stride: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + x = te.placeholder((n, ci, h, w), name="X") + w = te.placeholder((co, ci, kh, kw), name="W") + y = topi.nn.conv2d_nchw(Input=x, Filter=w, stride=stride, padding=padding, dilation=dilation) + return (x, w, y) + + +def conv2d_nchw_bias_bn_relu( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + co: int, + kh: int, + kw: int, + stride: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]: + oh = (h + 2 * padding - (kh - 1) * dilation - 1) // stride + 1 # pylint: disable=invalid-name + ow = (w + 2 * padding - (kw - 1) * dilation - 1) // stride + 1 # pylint: disable=invalid-name + x = te.placeholder((n, ci, h, w), name="X") + w = te.placeholder((co, ci, kh, kw), name="W") + b = te.placeholder((co, 1, 1), name="B") + bn_scale = te.placeholder((co, 1, 1), name="bn_scale") + bn_offset = te.placeholder((co, 1, 1), name="bn_offset") + y = topi.nn.conv2d_nchw(Input=x, Filter=w, stride=stride, padding=padding, dilation=dilation) + y = te.compute((n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] + b[j, 0, 0], name="bias_add") + y = te.compute( + (n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] * bn_scale[j, 0, 0], name="bn_mul" + ) + y = te.compute( + (n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] + bn_offset[j, 0, 0], name="bn_add" + ) + y = topi.nn.relu(y) + return (x, w, b, bn_scale, bn_offset, y) + + + +def max_pool2d_nchw( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + padding: int, +) -> Tuple[te.Tensor, te.Tensor]: # pylint: disable=invalid-name + x = te.placeholder((n, ci, h, w), name="X") + y = topi.nn.pool2d(x, [2, 2], [1, 1], [1, 1], [padding, padding, padding, padding], "max") + return (x, y) + + +def create_te_workload(name: str, idx: int) -> tir.PrimFunc: + workload_func, params = CONFIGS[name] + return te.create_prim_func(workload_func(*params[idx])) # type: ignore + + +CONFIGS = { + "C1D": ( + conv1d_nlc, + [ + # derived from conv2d_shapes + (1, 256, 64, 128, 3, 2, 1), + # (1, 256, 64, 128, 1, 2, 0), + # (1, 256, 64, 64, 1, 1, 0), + # (1, 128, 128, 256, 3, 2, 1), + (1, 128, 128, 256, 1, 2, 0), + # (1, 128, 128, 128, 3, 1, 1), + # (1, 64, 256, 512, 3, 2, 1), + # (1, 64, 256, 512, 1, 2, 0), + (1, 64, 256, 256, 5, 1, 2), + (1, 32, 512, 512, 3, 1, 1), + ], + ), + "C2D": ( + conv2d_nhwc, + [ + # all conv2d layers in resnet-18 + (1, 224, 224, 3, 64, 7, 2, 3), + # (1, 56, 56, 64, 128, 3, 2, 1), + # (1, 56, 56, 64, 128, 1, 2, 0), + # (1, 56, 56, 64, 64, 3, 1, 1), + (1, 56, 56, 64, 64, 1, 1, 0), + # (1, 28, 28, 128, 256, 3, 2, 1), + # (1, 28, 28, 128, 256, 1, 2, 0), + # (1, 28, 28, 128, 128, 3, 1, 1), + # (1, 14, 14, 256, 512, 3, 2, 1), + # (1, 14, 14, 256, 512, 1, 2, 0), + (1, 14, 14, 256, 256, 3, 1, 1), + (1, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "C3D": ( + conv3d_ndhwc, + [ + # Derived from conv2d_shapes. 
Use depth=16 for all configurations + (1, 16, 224, 224, 3, 64, 7, 2, 3), + # (1, 16, 56, 56, 64, 128, 3, 2, 1), + # (1, 16, 56, 56, 64, 128, 1, 2, 0), + # (1, 16, 56, 56, 64, 64, 3, 1, 1), + (1, 16, 56, 56, 64, 64, 1, 1, 0), + # (1, 16, 28, 28, 128, 256, 3, 2, 1), + # (1, 16, 28, 28, 128, 256, 1, 2, 0), + # (1, 16, 28, 28, 128, 128, 3, 1, 1), + # (1, 16, 14, 14, 256, 512, 3, 2, 1), + # (1, 16, 14, 14, 256, 512, 1, 2, 0), + (1, 16, 14, 14, 256, 256, 3, 1, 1), + (1, 16, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "GMM": ( + batch_matmul_nkkm, + [ + (1, 128, 128, 128), + (1, 512, 32, 512), + (1, 512, 512, 512), + (1, 1024, 1024, 1024), + ], + ), + "GRP": ( + conv2d_nhwc, + [ + # Derived from conv2d_shapes. Use group=4 for all configurations + (1, 56, 56, 64, 128, 3, 2, 1, 1, 4), + # (1, 56, 56, 64, 128, 1, 2, 0 , 1, 4), + # (1, 56, 56, 64, 64, 3, 1, 1 , 1, 4), + (1, 56, 56, 64, 64, 1, 1, 0, 1, 4), + # (1, 28, 28, 128, 256, 3, 2, 1, 1, 4), + # (1, 28, 28, 128, 256, 1, 2, 0, 1, 4), + # (1, 28, 28, 128, 128, 3, 1, 1, 1, 4), + # (1, 14, 14, 256, 512, 3, 2, 1, 1, 4), + # (1, 14, 14, 256, 512, 1, 2, 0, 1, 4), + (1, 14, 14, 256, 256, 3, 1, 1, 1, 4), + (1, 7, 7, 512, 512, 3, 1, 1, 1, 4), + ], + ), + "DIL": ( + conv2d_nhwc, + [ + # Derived from conv2d_shapes. Use dilation=2 for all configurations + (1, 224, 224, 3, 64, 7, 2, 3, 2), + # (1, 56, 56, 64, 128, 3, 2, 1 , 2), + # (1, 56, 56, 64, 128, 1, 2, 0 , 2), + # (1, 56, 56, 64, 64, 3, 1, 1 , 2), + (1, 56, 56, 64, 64, 1, 1, 0, 2), + # (1, 28, 28, 128, 256, 3, 2, 1, 2), + # (1, 28, 28, 128, 256, 1, 2, 0, 2), + # (1, 28, 28, 128, 128, 3, 1, 1, 2), + # (1, 14, 14, 256, 512, 3, 2, 1, 2), + # (1, 14, 14, 256, 512, 1, 2, 0, 2), + (1, 14, 14, 256, 256, 3, 1, 1, 2), + (1, 7, 7, 512, 512, 3, 1, 1, 2), + ], + ), + "DEP": ( + depthwise_conv2d_nhwc, + [ + # all depthwise conv2d layers in mobilenet + (1, 112, 112, 32, 3, 1, 1), + (1, 112, 112, 64, 3, 2, 1), + # (1, 56, 56, 128, 3, 1, 1), + # (1, 56, 56, 128, 3, 2, 1), + # (1, 28, 28, 256, 3, 1, 1), + # (1, 28, 28, 256, 3, 2, 1), + # (1, 14, 14, 512, 3, 1, 1), + (1, 14, 14, 512, 3, 2, 1), + (1, 7, 7, 1024, 3, 1, 1), + ], + ), + "T2D": ( + conv2d_transpose_nhwc, + [ + # all conv2d tranpose layers in DCGAN + (1, 4, 4, 512, 256, 4, 2, 1), + (1, 8, 8, 256, 128, 4, 2, 1), + (1, 16, 16, 128, 64, 4, 2, 1), + (1, 32, 32, 64, 3, 4, 2, 1), + ], + ), + "CAP": ( + conv2d_capsule_nhwijc, + [ + # all conv2d capsule layers in matrix capsules withemrouting (ICLR 2018) + (1, 16, 16, 32, 32, 3, 2, 1), + (1, 8, 8, 32, 32, 3, 1, 1), + (1, 16, 16, 8, 16, 3, 2, 1), + (1, 8, 8, 16, 16, 3, 1, 1), + ], + ), + "NRM": ( + norm_bmn, + [ + (1, 256, 256), + (1, 512, 512), + (1, 1024, 1024), + (1, 4096, 1024), + ], + ), + "C2d-BN-RELU": ( + conv2d_nhwc_bn_relu, + [ + (1, 224, 224, 3, 64, 7, 2, 3), + (1, 56, 56, 64, 128, 3, 2, 1), + (1, 28, 28, 128, 256, 1, 2, 0), + (1, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "TBG": ( + transpose_batch_matmul, + [ + (1, 128, 12, 64), + (1, 128, 16, 64), + (1, 64, 12, 128), + (1, 128, 12, 128), + ], + ), +} diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py new file mode 100644 index 0000000000..14412b076a --- /dev/null +++ b/python/tvm/meta_schedule/tune.py @@ -0,0 +1,450 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""User-facing Tuning API""" + +from contextlib import contextmanager +import logging +import os.path +import tempfile +from typing import Callable, Generator, List, Optional, Union + +from tvm.ir.module import IRModule +from tvm.target.target import Target +from tvm.te import Tensor, create_prim_func +from tvm.tir import PrimFunc, Schedule + +from . import schedule_rule +from . import measure_callback +from . import postproc +from .builder import Builder, LocalBuilder +from .database import Database, JSONDatabase, TuningRecord +from .measure_callback import MeasureCallback +from .runner import LocalRunner, Runner +from .search_strategy import ReplayFuncConfig, ReplayTraceConfig +from .space_generator import PostOrderApply +from .task_scheduler import RoundRobin, TaskScheduler +from .tune_context import TuneContext + + +logger = logging.getLogger(__name__) + + +SearchStrategyConfig = Union[ + ReplayFuncConfig, + ReplayTraceConfig, +] + +TYPE_F_TUNE_CONTEXT = Callable[ # pylint: disable=invalid-name + [ + IRModule, + Target, + SearchStrategyConfig, + str, + ], + TuneContext, +] + +TYPE_F_TASK_SCHEDULER = Callable[ # pylint: disable=invalid-name + [ + List[TuneContext], + Builder, + Runner, + Database, + List[MeasureCallback], + ], + TaskScheduler, +] + + +def _parse_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: + if isinstance(mod, PrimFunc): + mod = mod.with_attr("global_symbol", "main") + mod = mod.with_attr("tir.noalias", True) + mod = IRModule({"main": mod}) + if not isinstance(mod, IRModule): + raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") + return mod + + +def _parse_target(target: Union[str, Target]) -> Target: + if isinstance(target, str): + target = Target(target) + if not isinstance(target, Target): + raise TypeError(f"Expected `target` to be str or Target, but gets: {target}") + return target + + +@contextmanager +def _work_dir_context(work_dir: Optional[str]) -> Generator[str, None, None]: + if work_dir is not None and not os.path.isdir(work_dir): + raise ValueError(f"`work_dir` must be a directory, but gets: {work_dir}") + temp_dir = None + try: + if work_dir is not None: + yield work_dir + else: + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir.name + finally: + if temp_dir is not None: + temp_dir.cleanup() + + +def _parse_builder(builder: Optional[Builder]) -> Builder: + if builder is None: + builder = LocalBuilder() + if not isinstance(builder, Builder): + raise TypeError(f"Expected `builder` to be Builder, but gets: {builder}") + return builder + + +def _parse_runner(runner: Optional[Runner]) -> Runner: + if runner is None: + runner = LocalRunner() + if not isinstance(runner, Runner): + raise TypeError(f"Expected `runner` to be Runner, but gets: {runner}") + return runner + + +def _parse_database(database: Optional[Database], path: str) -> Database: + if database is None: + database = JSONDatabase( + path_workload=os.path.join(path, "workload.json"), + 
path_tuning_record=os.path.join(path, "tuning_record.json"), + ) + if not isinstance(database, Database): + raise TypeError(f"Expected `database` to be Database, but gets: {database}") + return database + + +def _parse_measure_callbacks( + measure_callbacks: Optional[List[MeasureCallback]], +) -> List[MeasureCallback]: + if measure_callbacks is None: + measure_callbacks = [ + measure_callback.AddToDatabase(), + measure_callback.RemoveBuildArtifact(), + measure_callback.EchoStatistics(), + ] + if not isinstance(measure_callbacks, (list, tuple)): + raise TypeError( + f"Expected `measure_callbacks` to be List[MeasureCallback], " + f"but gets: {measure_callbacks}" + ) + measure_callbacks = list(measure_callbacks) + for i, callback in enumerate(measure_callbacks): + if not isinstance(callback, MeasureCallback): + raise TypeError( + f"Expected `measure_callbacks` to be List[MeasureCallback], " + f"but measure_callbacks[{i}] is: {callback}" + ) + return measure_callbacks + + +def _parse_f_tune_context(f_tune_context: Optional[TYPE_F_TUNE_CONTEXT]) -> TYPE_F_TUNE_CONTEXT: + def default_llvm( + mod: IRModule, + target: Target, + config: SearchStrategyConfig, + task_name: str, + ) -> TuneContext: + return TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + search_strategy=config.create_strategy(), + sch_rules=[ + schedule_rule.AutoInline( + into_producer=False, + into_consumer=True, + into_cache_only=False, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ), + schedule_rule.MultiLevelTiling( + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_max_len=None, + reuse_read=None, + reuse_write=schedule_rule.ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ), + schedule_rule.ParallelizeVectorizeUnroll( + max_jobs_per_core=16, + max_vectorize_extent=32, + unroll_max_steps=[0, 16, 64, 512], + unroll_explicit=True, + ), + ], + postprocs=[ + postproc.RewriteParallelVectorizeUnroll(), + postproc.RewriteReductionBlock(), + ], + mutators=[], + task_name=task_name, + rand_state=-1, + num_threads=None, + ) + + def default_cuda( + mod: IRModule, + target: Target, + config: SearchStrategyConfig, + task_name: str, + ) -> TuneContext: + return TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + search_strategy=config.create_strategy(), + sch_rules=[ + schedule_rule.AutoInline( + into_producer=False, + into_consumer=True, + into_cache_only=False, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ), + schedule_rule.MultiLevelTiling( + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], + max_innermost_factor=64, + vector_load_max_len=4, + reuse_read=schedule_rule.ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=schedule_rule.ReuseType( + req="must", + levels=[3], + scope="local", + ), + ), + schedule_rule.AutoInline( + into_producer=True, + into_consumer=True, + into_cache_only=True, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ), + schedule_rule.ParallelizeVectorizeUnroll( + max_jobs_per_core=-1, # disable parallelize + max_vectorize_extent=-1, # disable vectorize + unroll_max_steps=[0, 16, 64, 512, 1024], + unroll_explicit=True, + ), + ], + postprocs=[ + postproc.RewriteCooperativeFetch(), + 
postproc.RewriteUnboundBlock(), + postproc.RewriteParallelVectorizeUnroll(), + postproc.RewriteReductionBlock(), + postproc.VerifyGPUCode(), + ], + mutators=[], + task_name=task_name, + rand_state=-1, + num_threads=None, + ) + + def default( + mod: IRModule, + target: Target, + config: SearchStrategyConfig, + task_name: str, + ) -> TuneContext: + if target.kind.name == "llvm": + return default_llvm(mod, target, config, task_name) + if target.kind.name == "cuda": + return default_cuda(mod, target, config, task_name) + raise NotImplementedError(f"Unsupported target: {target.kind.name}") + + if f_tune_context is None: + return default + return f_tune_context + + +def _parse_f_task_scheduler( + f_task_scheduler: Optional[TYPE_F_TASK_SCHEDULER], +) -> TYPE_F_TASK_SCHEDULER: + def default( + tasks: List[TuneContext], + builder: Builder, + runner: Runner, + database: Database, + measure_callbacks: List[MeasureCallback], + ) -> TaskScheduler: + return RoundRobin( + tasks=tasks, + builder=builder, + runner=runner, + database=database, + measure_callbacks=measure_callbacks, + ) + + if f_task_scheduler is None: + return default + return f_task_scheduler + + +def tune_tir( + mod: Union[IRModule, PrimFunc], + target: Union[str, Target], + config: SearchStrategyConfig, + *, + task_name: str = "main", + work_dir: Optional[str] = None, + builder: Optional[Builder] = None, + runner: Optional[Runner] = None, + database: Optional[Database] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, + f_tune_context: Optional[TYPE_F_TUNE_CONTEXT] = None, + f_task_scheduler: Optional[TYPE_F_TASK_SCHEDULER] = None, +) -> Optional[Schedule]: + """Tune a TIR IRModule with a given target. + + Parameters + ---------- + mod : Union[IRModule, PrimFunc] + The module to tune. + target : Union[str, Target] + The target to tune for. + config : SearchStrategyConfig + The search strategy config. + task_name : str + The name of the task. + work_dir : Optional[str] + The working directory to save intermediate results. + builder : Optional[Builder] + The builder to use. + runner : Optional[Runner] + The runner to use. + database : Optional[Database] + The database to use. + measure_callbacks : Optional[List[MeasureCallback]] + The callbacks used during tuning. + f_tune_context : Optional[TYPE_F_TUNE_CONTEXT] + The function to create TuneContext. + f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER] + The function to create TaskScheduler. + + Returns + ------- + sch : Optional[Schedule] + The tuned schedule. 
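For reference, a minimal end-to-end call of `tune_tir` could look like the sketch below; the workload comes from the te_workload helpers added earlier, and the `ReplayTraceConfig` field names are an assumption of this sketch rather than something fixed by this patch:

    from tvm.meta_schedule.search_strategy import ReplayTraceConfig
    from tvm.meta_schedule.testing import create_te_workload
    from tvm.meta_schedule.tune import tune_tir

    sch = tune_tir(
        mod=create_te_workload("GMM", 0),   # a small batched-matmul PrimFunc
        target="llvm",
        config=ReplayTraceConfig(num_trials_per_iter=32, num_trials_total=32),
    )
    if sch is None:
        print("No valid schedule found")
    else:
        print(sch.mod.script())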
+ """ + + with _work_dir_context(work_dir) as path: + logger.info("Working directory: %s", path) + mod = _parse_mod(mod) + target = _parse_target(target) + builder = _parse_builder(builder) + runner = _parse_runner(runner) + database = _parse_database(database, path) + measure_callbacks = _parse_measure_callbacks(measure_callbacks) + tune_context = _parse_f_tune_context(f_tune_context)(mod, target, config, task_name) + task_scheduler = _parse_f_task_scheduler(f_task_scheduler)( + [tune_context], + builder, + runner, + database, + measure_callbacks, + ) + task_scheduler.tune() + workload = database.commit_workload(mod) + bests: List[TuningRecord] = database.get_top_k(workload, top_k=1) + if not bests: + return None + assert len(bests) == 1 + sch = Schedule(mod) + bests[0].trace.apply_to_schedule(sch, remove_postproc=False) + return sch + + +def tune_te( + tensors: List[Tensor], + target: Union[str, Target], + config: SearchStrategyConfig, + *, + task_name: str = "main", + work_dir: Optional[str] = None, + builder: Optional[Builder] = None, + runner: Optional[Runner] = None, + database: Optional[Database] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, + f_tune_context: Optional[TYPE_F_TUNE_CONTEXT] = None, + f_task_scheduler: Optional[TYPE_F_TASK_SCHEDULER] = None, +) -> Optional[Schedule]: + """Tune a TE compute DAG with a given target. + + Parameters + ---------- + tensor : List[Tensor] + The list of input/output tensors of the TE compute DAG. + target : Union[str, Target] + The target to tune for. + config : SearchStrategyConfig + The search strategy config. + task_name : str + The name of the task. + work_dir : Optional[str] + The working directory to save intermediate results. + builder : Optional[Builder] + The builder to use. + runner : Optional[Runner] + The runner to use. + database : Optional[Database] + The database to use. + measure_callbacks : Optional[List[MeasureCallback]] + The callbacks used during tuning. + f_tune_context : Optional[TYPE_F_TUNE_CONTEXT] + The function to create TuneContext. + f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER] + The function to create TaskScheduler. + + Returns + ------- + sch : Optional[Schedule] + The tuned schedule. + """ + return tune_tir( + mod=create_prim_func(tensors), + target=target, + config=config, + task_name=task_name, + work_dir=work_dir, + builder=builder, + runner=runner, + database=database, + measure_callbacks=measure_callbacks, + f_tune_context=f_tune_context, + f_task_scheduler=f_task_scheduler, + ) diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 0f3cfac1a8..196b1c16b6 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -16,19 +16,23 @@ # under the License. """Meta Schedule tuning context.""" -from typing import Optional, TYPE_CHECKING +from typing import Optional, List, Dict, TYPE_CHECKING from tvm import IRModule from tvm._ffi import register_object from tvm.meta_schedule.utils import cpu_count from tvm.runtime import Object from tvm.target import Target +from tvm.tir import PrimFunc from . import _ffi_api if TYPE_CHECKING: from .space_generator import SpaceGenerator from .search_strategy import SearchStrategy + from .schedule_rule import ScheduleRule + from .postproc import Postproc + from .mutator import Mutator @register_object("meta_schedule.TuneContext") @@ -50,6 +54,12 @@ class TuneContext(Object): The design space generator. 
search_strategy : Optional[SearchStrategy] = None The search strategy. + sch_rules: Optional[List[ScheduleRule]] = None, + The schedule rules. + postprocs: Optional[List[Postproc"]] = None, + The postprocessors. + mutator_probs: Optional[Dict[Mutator, float]] + Mutators and their probability mass. task_name : Optional[str] = None The name of the tuning task. rand_state : int = -1 @@ -68,42 +78,31 @@ class TuneContext(Object): mod: Optional[IRModule] target: Optional[Target] - space_generator: "SpaceGenerator" - search_strategy: "SearchStrategy" - task_name: Optional[str] + space_generator: Optional["SpaceGenerator"] + search_strategy: Optional["SearchStrategy"] + sch_rules: List["ScheduleRule"] + postprocs: List["Postproc"] + mutator_probs: Optional[Dict["Mutator", float]] + task_name: str rand_state: int num_threads: int def __init__( self, mod: Optional[IRModule] = None, + *, target: Optional[Target] = None, space_generator: Optional["SpaceGenerator"] = None, search_strategy: Optional["SearchStrategy"] = None, - task_name: Optional[str] = None, + sch_rules: Optional[List["ScheduleRule"]] = None, + postprocs: Optional[List["Postproc"]] = None, + mutator_probs: Optional[Dict["Mutator", float]] = None, + task_name: str = "main", rand_state: int = -1, num_threads: Optional[int] = None, ): - """Constructor. - - Parameters - ---------- - mod : Optional[IRModule] = None - The workload to be optimized. - target : Optional[Target] = None - The target to be optimized for. - space_generator : Optional[SpaceGenerator] = None - The design space generator. - search_strategy : Optional[SearchStrategy] = None - The search strategy. - task_name : Optional[str] = None - The name of the tuning task. - rand_state : int = -1 - The random state. - Need to be in integer in [1, 2^31-1], -1 means using random number. - num_threads : Optional[int] = None - The number of threads to be used, None means using the logical cpu count. - """ + if isinstance(mod, PrimFunc): + mod = IRModule.from_expr(mod) if num_threads is None: num_threads = cpu_count() @@ -113,6 +112,9 @@ def __init__( target, space_generator, search_strategy, + sch_rules, + postprocs, + mutator_probs, task_name, rand_state, num_threads, diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index a9ef514543..64a2965479 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -15,12 +15,14 @@ # specific language governing permissions and limitations # under the License. 
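Putting the new TuneContext arguments together, a hand-built context for a single workload might be assembled roughly as follows, mirroring the `create_context` helper above (the target string, workload, and rule set are placeholders):

    from tvm.meta_schedule import TuneContext
    from tvm.meta_schedule.space_generator import PostOrderApply
    from tvm.meta_schedule.testing import create_te_workload, schedule_rule
    from tvm.target import Target

    target = Target("llvm")
    ctx = TuneContext(
        mod=create_te_workload("C1D", 0),   # a PrimFunc is wrapped into an IRModule automatically
        target=target,
        space_generator=PostOrderApply(),
        sch_rules=schedule_rule.get(target),
        task_name="demo",
    )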
"""Utilities for meta schedule""" +from typing import Any, Callable, List, Optional, Union + +import ctypes import json import os import shutil -from typing import Any, Callable, List, Optional, Union - import psutil # type: ignore + import tvm from tvm._ffi import get_global_func, register_func from tvm.error import TVMError @@ -31,7 +33,7 @@ @register_func("meta_schedule.cpu_count") -def cpu_count(logical: bool = True) -> int: +def _cpu_count_impl(logical: bool = True) -> int: """Return the number of logical or physical CPUs in the system Parameters @@ -59,6 +61,22 @@ def cpu_count(logical: bool = True) -> int: return psutil.cpu_count(logical=logical) or 1 +def cpu_count(logical: bool = True) -> int: + """Return the number of logical or physical CPUs in the system + + Parameters + ---------- + logical : bool = True + If True, return the number of logical CPUs, otherwise return the number of physical CPUs + + Returns + ------- + cpu_count : int + The number of logical or physical CPUs in the system + """ + return _cpu_count_impl(logical) + + def get_global_func_with_default_on_worker( name: Union[None, str, Callable], default: Callable, @@ -207,6 +225,22 @@ def structural_hash(mod: IRModule) -> str: return str(shash) +def _get_hex_address(handle: ctypes.c_void_p) -> str: + """Get the hexadecimal address of a handle. + + Parameters + ---------- + handle : ctypes.c_void_p + The handle to be converted. + + Returns + ------- + result : str + The hexadecimal address of the handle. + """ + return hex(ctypes.cast(handle, ctypes.c_void_p).value) + + def check_override( derived_class: Any, base_class: Any, required: bool = True, func_name: str = None ) -> Callable: diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 09b847a3ba..2deb4f25d9 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -271,13 +271,17 @@ def _module_export(module, file_name): # fcompile, addons, kwargs? @register_func("tvm.relay.build") +def _build_module_no_factory_impl(mod, target, target_host, params, mod_name): + target, target_host = Target.check_and_update_host_consist(target, target_host) + return build(mod, target, params=params, mod_name=mod_name).module + + def _build_module_no_factory(mod, target=None, target_host=None, params=None, mod_name="default"): """A wrapper around build which discards the Python GraphFactoryRuntime. This wrapper is suitable to be used from other programming languages as the runtime::Module can be freely passed between language boundaries. 
""" - target, target_host = Target.check_and_update_host_consist(target, target_host) - return build(mod, target, params=params, mod_name=mod_name).module + return _build_module_no_factory_impl(mod, target, target_host, params, mod_name) def _reconstruct_from_deprecated_options(deprecated_params_target): diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 4750ad7626..e5615e6cc5 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -20,7 +20,7 @@ import synr import tvm.tir -from tvm.runtime import Object +from tvm.runtime import Object, String from tvm.ir import Span, Range from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind @@ -485,7 +485,7 @@ def create_loop_info( self.annotations: Mapping[str, Object] = {} if annotations is not None: self.annotations = { - key: tvm.tir.StringImm(val) if isinstance(val, str) else val + key: String(val) if isinstance(val, str) else val for key, val in annotations.items() } diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 3b69dd08d4..a0c30ef954 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -24,7 +24,7 @@ from tvm.ir.expr import PrimExpr, Range import tvm.tir -from tvm.runtime import Object +from tvm.runtime import Object, String from tvm import te from tvm.ir import Span from tvm.tir import IntImm, IterVar @@ -408,7 +408,7 @@ def block_attr(attrs: Mapping[str, Object], span: Span = None): span, ) attrs = { - key: tvm.tir.StringImm(val) if isinstance(val, str) else val + key: String(val) if isinstance(val, str) else val for key, val in attrs.items() } block_scope.annotations = attrs diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 428403a98f..80ee286f57 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -33,7 +33,7 @@ from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize -from .function import PrimFunc +from .function import PrimFunc, IndexMap, TensorIntrin from .op import call_packed, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index ecbcd837cb..b41eb97b59 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -16,18 +16,19 @@ # under the License. """Function data types.""" -from typing import Mapping, Union +import inspect +from typing import Callable, List, Mapping, Union -import tvm._ffi -import tvm.runtime -from tvm.runtime import Object +from tvm._ffi import get_global_func, register_object from tvm.ir import BaseFunc -from .buffer import Buffer -from .expr import Var, PrimExpr +from tvm.runtime import Object, convert + from . import _ffi_api +from .buffer import Buffer +from .expr import PrimExpr, Var -@tvm._ffi.register_object("tir.PrimFunc") +@register_object("tir.PrimFunc") class PrimFunc(BaseFunc): """A function declaration expression. 
@@ -56,7 +57,7 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa param_list = [] buffer_map = {} if buffer_map is None else buffer_map for x in params: - x = tvm.runtime.convert(x) if not isinstance(x, Object) else x + x = convert(x) if not isinstance(x, Object) else x if isinstance(x, Buffer): var = Var(x.name, dtype="handle") param_list.append(var) @@ -67,7 +68,13 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa raise TypeError("params can only contain Var or Buffer") self.__init_handle_by_constructor__( - _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span # type: ignore + _ffi_api.PrimFunc, # type: ignore # pylint: disable=no-member + param_list, + body, + ret_type, + buffer_map, + attrs, + span, ) def with_body(self, new_body, span=None): @@ -141,7 +148,7 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: func : PrimFunc The new function with parameter specialized """ - return _ffi_api.Specialize(self, param_map) # type: ignore + return _ffi_api.Specialize(self, param_map) # type: ignore # pylint: disable=no-member def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: """Print IRModule into TVMScript @@ -159,6 +166,95 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: script : str The TVM Script of the PrimFunc """ - return tvm._ffi.get_global_func("script.AsTVMScript")( - self, tir_prefix, show_meta - ) # type: ignore + return get_global_func("script.AsTVMScript")(self, tir_prefix, show_meta) # type: ignore + + +@register_object("tir.IndexMap") +class IndexMap(Object): + """A mapping from multi-dimensional indices to another set of multi-dimensional indices + + Parameters + ---------- + src_iters : list of Var + The source indices + tgt_iters : list of PrimExpr + The target indices + """ + + src_iters: List[Var] + """The source indices""" + + tgt_iters: List[PrimExpr] + """The target indices""" + + def __init__(self, src_iters: List[Var], tgt_iters: List[PrimExpr]): + self._init_handle_by_constructor( + _ffi_api.IndexMap, # type: ignore # pylint: disable=no-member + src_iters, + tgt_iters, + ) + + def apply(self, indices: List[PrimExpr]) -> List[PrimExpr]: + """Apply the index map to a set of indices + + Parameters + ---------- + indices : List[PriExpr] + The indices to be mapped + + Returns + ------- + result : List[PrimExpr] + The mapped indices + """ + return _ffi_api.IndexMapApply(self, indices) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_func(func: Callable) -> "IndexMap": + """Create an index map from a function + + Parameters + ---------- + func : Callable + The function to map from source indices to target indices + """ + + def wrap(args: List[Var]) -> List[PrimExpr]: + result = func(*args) + if isinstance(result, tuple): + return list(result) + if not isinstance(result, list): + result = [result] + return result + + ndim = len(inspect.signature(func).parameters) + return _ffi_api.IndexMapFromFunc(ndim, wrap) # type: ignore # pylint: disable=no-member + + +@register_object("tir.TensorIntrin") +class TensorIntrin(Object): + """A function declaration expression. 
+ + Parameters + ---------- + desc_func: PrimFunc + The function to describe the computation + + intrin_func: PrimFunc + The function for execution + """ + + def __init__(self, desc_func, intrin_func): + self.__init_handle_by_constructor__( + _ffi_api.TensorIntrin, desc_func, intrin_func # type: ignore # pylint: disable=no-member + ) + + @staticmethod + def register(name: str, desc_func: PrimFunc, intrin_func: PrimFunc): + return _ffi_api.TensorIntrinRegister( # pylint: disable=no-member + name, desc_func, intrin_func + ) + + @staticmethod + def get(name: str): + return _ffi_api.TensorIntrinGet(name) # pylint: disable=no-member diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index 5f0e169c43..66ac7b9d77 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -22,3 +22,5 @@ from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError from .state import ScheduleDebugMask, ScheduleState from .trace import Trace + +from . import analysis diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py new file mode 100644 index 0000000000..7c0c77a372 --- /dev/null +++ b/python/tvm/tir/schedule/analysis.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Analysis used in TensorIR scheduling""" +from typing import List, Optional + +from ..buffer import Buffer +from ..stmt import For +from ..expr import PrimExpr +from ..function import IndexMap + +from . import _ffi_api + + +def suggest_index_map( + buffer: Buffer, + indices: List[PrimExpr], + loops: List[For], + predicate: PrimExpr, +) -> Optional[IndexMap]: + """Provided the access pattern to a buffer, suggest one of the possible layout + transformation to minimize the locality of the access pattern. + + Parameters + ---------- + buffer : Buffer + The buffer to be transformed. + indices : List[PrimExpr] + The access pattern to the buffer. + loops : List[For] + The loops above the buffer. + predicate : PrimExpr + The predicate of the access. + + Returns + ------- + index_map : Optional[IndexMap] + The suggested index map. None if no transformation is suggested. 
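A short illustrative sketch (not part of the patch) of the new tir.IndexMap Python API introduced above; the 4x4 tiling lambda and the concrete indices are assumptions made for the example. suggest_index_map returns such an IndexMap (or None) once it has analyzed a concrete buffer access pattern.

    from tvm.tir import IndexMap

    # Build an index map that tiles a 2-D index space into 4x4 blocks.
    index_map = IndexMap.from_func(lambda i, j: (i // 4, j // 4, i % 4, j % 4))

    # Applying the map rewrites a set of indices into the new layout;
    # (8, 10) corresponds to tile (2, 2) with in-tile offset (0, 2).
    mapped = index_map.apply([8, 10])
    print(mapped)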
+ """ + return _ffi_api.SuggestIndexMap( # type: ignore # pylint: disable=no-member + buffer, + indices, + loops, + predicate, + ) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 884eeb7c61..c5871a53eb 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -20,8 +20,9 @@ from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr -from tvm.runtime import Object -from tvm.tir import Block, For, IntImm, PrimFunc +from tvm.runtime import Object, String +from tvm.tir import Block, For, IntImm, PrimFunc, TensorIntrin +from tvm.tir.expr import FloatImm from . import _ffi_api from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod @@ -358,6 +359,31 @@ def sample_perfect_tile( decision, ) + def sample_compute_location( + self, + block: BlockRV, + decision: Optional[int] = None, + ) -> LoopRV: + """Sample a compute-at location on a BlockRV so that its producer can compute at that loop + + Parameters + ---------- + block : BlockRV + The consumer block to be computed at + decision : Optional[int] + The sampling decision + + Returns + ------- + result : LoopRV + The sampled loop to be computed at + """ + return _ffi_api.ScheduleSampleComputeLocation( # pylint: disable=no-member + self, + block, + decision, + ) + ########## Schedule: Get blocks & loops ########## def get_block( self, @@ -1004,6 +1030,30 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: self, block, write_buffer_index, storage_scope ) + ########## Schedule: Data movement ########## + + def read_at( + self, + loop: LoopRV, + block: BlockRV, + read_buffer_index: int, + storage_scope: str, + ) -> BlockRV: + return _ffi_api.ScheduleReadAt( # type: ignore # pylint: disable=no-member + self, loop, block, read_buffer_index, storage_scope + ) + + def write_at( + self, + loop: LoopRV, + block: BlockRV, + write_buffer_index: int, + storage_scope: str, + ) -> BlockRV: + return _ffi_api.ScheduleWriteAt( # type: ignore # pylint: disable=no-member + self, loop, block, write_buffer_index, storage_scope + ) + ########## Schedule: Compute location ########## def compute_at( @@ -1630,8 +1680,55 @@ def after_storage_align(a: T.handle, c: T.handle) -> None: ########## Schedule: Blockize & Tensorize ########## + def blockize(self, loop: LoopRV) -> BlockRV: + return _ffi_api.ScheduleBlockize(self, loop) # pylint: disable=no-member + + def tensorize(self, loop: LoopRV, intrin: Union[str, TensorIntrin]) -> None: + if isinstance(intrin, str): + intrin = String(intrin) + _ffi_api.ScheduleTensorize(self, loop, intrin) # pylint: disable=no-member + ########## Schedule: Annotation ########## + def annotate( + self, + block_or_loop: Union[BlockRV, LoopRV], + ann_key: str, + ann_val: Union[str, int, float, ExprRV], + ) -> None: + """Annotate a block/loop with a key value pair + + Parameters + ---------- + block_or_loop: Union[BlockRV, LoopRV] + The block/loop to be annotated + ann_key : str + The annotation key + ann_val : Union[str, int, float, ExprRV] + The annotation value + """ + if isinstance(ann_val, str): + ann_val = String(ann_val) + elif isinstance(ann_val, int): + ann_val = IntImm("int32", ann_val) + elif isinstance(ann_val, float): + ann_val = FloatImm("float32", ann_val) + _ffi_api.ScheduleAnnotate( # pylint: disable=no-member + self, block_or_loop, ann_key, ann_val + ) + + def unannotate(self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str) -> None: + 
"""Unannotate a block/loop's annotation with key ann_key + + Parameters + ---------- + block_or_loop: Union[BlockRV, LoopRV] + The block/loop to be unannotated + ann_key : str + The annotation key + """ + _ffi_api.ScheduleUnannotate(self, block_or_loop, ann_key) # pylint: disable=no-member + ########## Schedule: Misc ########## def enter_postproc(self) -> None: diff --git a/src/meta_schedule/builder/builder.cc b/src/meta_schedule/builder/builder.cc index fb63b7e653..25adfbfdfa 100644 --- a/src/meta_schedule/builder/builder.cc +++ b/src/meta_schedule/builder/builder.cc @@ -23,10 +23,12 @@ namespace meta_schedule { /******** Constructors ********/ -BuilderInput::BuilderInput(IRModule mod, Target target) { +BuilderInput::BuilderInput(IRModule mod, Target target, + Optional> params) { ObjectPtr n = make_object(); n->mod = std::move(mod); n->target = std::move(target); + n->params = std::move(params); data_ = std::move(n); } @@ -51,8 +53,9 @@ TVM_REGISTER_OBJECT_TYPE(BuilderNode); TVM_REGISTER_NODE_TYPE(PyBuilderNode); TVM_REGISTER_GLOBAL("meta_schedule.BuilderInput") - .set_body_typed([](IRModule mod, Target target) -> BuilderInput { - return BuilderInput(mod, target); + .set_body_typed([](IRModule mod, Target target, + Optional> params) -> BuilderInput { + return BuilderInput(mod, target, params); }); TVM_REGISTER_GLOBAL("meta_schedule.BuilderResult") diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc new file mode 100644 index 0000000000..5cd32b097c --- /dev/null +++ b/src/meta_schedule/cost_model/cost_model.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +CostModel CostModel::PyCostModel(PyCostModelNode::FLoad f_load, // + PyCostModelNode::FSave f_save, // + PyCostModelNode::FUpdate f_update, // + PyCostModelNode::FPredict f_predict, // + PyCostModelNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_load = std::move(f_load); + n->f_save = std::move(f_save); + n->f_update = std::move(f_update); + n->f_predict = std::move(f_predict); + n->f_as_string = std::move(f_as_string); + return CostModel(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyCostModelNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyCostModel's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(CostModelNode); +TVM_REGISTER_NODE_TYPE(PyCostModelNode); + +TVM_REGISTER_GLOBAL("meta_schedule.CostModelLoad").set_body_method(&CostModelNode::Load); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelSave").set_body_method(&CostModelNode::Save); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelUpdate") + .set_body_method(&CostModelNode::Update); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict") + .set_body_typed([](CostModel model, // + const TuneContext& tune_context, // + Array candidates, // + void* p_addr) -> void { + std::vector result = model->Predict(tune_context, candidates); + std::copy(result.begin(), result.end(), static_cast(p_addr)); + }); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel").set_body_typed(CostModel::PyCostModel); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index e67b3d1ab9..fc7cc74de5 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -135,10 +135,12 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w /******** PyDatabase ********/ -Database Database::PyDatabase(PyDatabaseNode::FCommitWorkload f_commit_workload, +Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, + PyDatabaseNode::FCommitWorkload f_commit_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FSize f_size) { ObjectPtr n = make_object(); + n->f_has_workload = f_has_workload; n->f_commit_workload = f_commit_workload; n->f_commit_tuning_record = f_commit_tuning_record; n->f_get_top_k = f_get_top_k; @@ -166,6 +168,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord") TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON") .set_body_method(&TuningRecordNode::AsJSON); TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload") + .set_body_method(&DatabaseNode::HasWorkload); TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload") .set_body_method(&DatabaseNode::CommitWorkload); TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord") diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 3efb72e2fa..2e76940fee 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -69,6 +69,10 @@ class JSONDatabaseNode : public DatabaseNode { TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode); public: + bool HasWorkload(const 
IRModule& mod) { + return workloads2idx_.find(Workload(mod, tvm::StructuralHash()(mod))) != workloads2idx_.end(); + } + Workload CommitWorkload(const IRModule& mod) { // Try to insert `mod` into `workloads_` decltype(this->workloads2idx_)::iterator it; diff --git a/src/meta_schedule/feature_extractor/feature_extractor.cc b/src/meta_schedule/feature_extractor/feature_extractor.cc new file mode 100644 index 0000000000..84d22493aa --- /dev/null +++ b/src/meta_schedule/feature_extractor/feature_extractor.cc @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +FeatureExtractor FeatureExtractor::PyFeatureExtractor( + PyFeatureExtractorNode::FExtractFrom f_extract_from, // + PyFeatureExtractorNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_extract_from = std::move(f_extract_from); + n->f_as_string = std::move(f_as_string); + return FeatureExtractor(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyFeatureExtractorNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyFeatureExtractor's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(FeatureExtractorNode); +TVM_REGISTER_NODE_TYPE(PyFeatureExtractorNode); + +TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorExtractFrom") + .set_body_method(&FeatureExtractorNode::ExtractFrom); +TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPyFeatureExtractor") + .set_body_typed(FeatureExtractor::PyFeatureExtractor); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc new file mode 100644 index 0000000000..2081976d2c --- /dev/null +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -0,0 +1,1282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include +#include +#include +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +/*! \brief Type for multi-dimensional index */ +using MultiIndex = std::vector; +/*! \brief Vector of int64_t */ +using IntVec = std::vector; +/*! \brief Vector of for loops */ +using ForVec = std::vector; + +/*! + * \brief An unordered_map for (for, buffer) => V + * \tparam V The value type + */ +template +using ForBufferMap = std::unordered_map>; + +/*! \brief Given x, compute log2(|x| + 1) */ +inline double slog(double x) { return x >= 0 ? std::log2(x + 1) : std::log2(-x + 1); } + +namespace utils { + +/*! + * \brief Given a loop, return its `pragma_auto_unroll_max_step` annotation if it exists + * \param loop The loop to be checked + * \return The value of `pragma_auto_unroll_max_step` if it exists, or -1 if it does not exist + */ +int64_t GetPragmaAutoUnroll(const ForNode* loop) { + if (Optional auto_unroll = GetAnn(loop, tir::attr::pragma_auto_unroll_max_step)) { + return auto_unroll.value()->value; + } + return -1; +} + +/*! + * \brief Given a list of loops, return the extent of the first loop if the list is not empty, + * and the first loop has constant extent. Otherwise returns the default value given + * \param loops The list of loops to be checked + * \param default_value The default value to be returned if the list is empty or the first loop + * does not have constant extent + * \return The extent of the first loop if the list is not empty, or the first loop has constant + * extent. Otherwise returns the default value + */ +int64_t FirstLoopExtent(const ForVec& loops, int64_t default_value) { + if (!loops.empty()) { + if (const int64_t* extent = GetLoopIntExtent(loops[0])) { + return *extent; + } + } + return default_value; +} + +/*! + * \brief Relax each of the multi-indexing pattern according to the domains bound in the analyzer, + * and then union them into a single region + * \param multi_index_pattern A list of multi-index pattern to be relaxed + * \param numel The size of the single region after union + * \param analyzer The analyzer that contains the domain information + * \return The relaxed and unioned region + */ +IntVec RelaxAndUnion(const std::vector& multi_indices, int64_t* numel, + arith::Analyzer* analyzer) { + if (multi_indices.empty()) { + return {}; + } + int n_indices = multi_indices.size(); + int ndim = multi_indices[0].size(); + IntVec access_shape(ndim, 0); + for (int i = 0; i < ndim; ++i) { + int64_t minimum = arith::ConstIntBound::kPosInf; + int64_t maximum = arith::ConstIntBound::kNegInf; + for (int j = 0; j < n_indices; ++j) { + arith::ConstIntBound bound = analyzer->const_int_bound(multi_indices[j][i]); + minimum = std::min(minimum, bound->min_value); + maximum = std::max(maximum, bound->max_value); + } + *numel *= maximum - minimum + 1; + access_shape[i] = maximum - minimum + 1; + } + return access_shape; +} + +/*! 
+ * \brief Given a list of multi-index pattern, return the minimal stride of a variable on it + * \param multi_indices The list of multi-index pattern + * \param buffer_stride The stride of the buffer + * \param var The variable to be checked + * \return The minimal stride of the variable on the multi-index pattern + */ +int64_t GetVarStride(const std::vector& multi_indices, const IntVec& buffer_stride, + const Var& var) { + class CoefficientExtractor : private ExprVisitor { + public: + static int64_t Extract(const PrimExpr& expr, const Var& var) { + CoefficientExtractor extractor(var); + extractor.VisitExpr(expr); + return (extractor.visited_var && !extractor.visited_mul && !extractor.visited_add) + ? 1 + : (extractor.visited_var ? extractor.stride : 0); + } + + private: + explicit CoefficientExtractor(const Var& var) + : var(var), stride(0), visited_var(false), visited_add(false), visited_mul(false) {} + + void VisitExpr_(const MulNode* node) override { + ExprVisitor::VisitExpr_(node); + if (visited_var && !visited_add) { + if (const auto* a = node->a.as()) { + visited_mul = true; + stride = a->value; + } else if (const auto* b = node->b.as()) { + visited_mul = true; + stride = b->value; + } + } + } + + void VisitExpr_(const AddNode* node) override { + ExprVisitor::VisitExpr_(node); + if (visited_var && !visited_mul) { + visited_add = true; + stride = 1; + } + } + + void VisitExpr_(const VarNode* node) override { + if (node == var.get()) { + visited_var = true; + stride = 2; + } + } + + const Var& var; + int64_t stride; + bool visited_var; + bool visited_add; + bool visited_mul; + }; + + constexpr int64_t kNotFound = std::numeric_limits::max(); + int ndim = buffer_stride.size(); + // Calculate the min stride possible + int64_t result = kNotFound; + for (const MultiIndex& multi_index : multi_indices) { + ICHECK_EQ(multi_index.size(), buffer_stride.size()); + // Find the rightest dimension that contains the given variable + for (int i = ndim - 1; i >= 0; --i) { + int64_t coef = CoefficientExtractor::Extract(multi_index[i], var); + if (coef != 0) { + result = std::min(result, std::abs(coef) * buffer_stride[i]); + break; + } + } + } + return (result == kNotFound) ? 0 : result; +} + +/*! + * \brief Converts a 2-dimensional STL vector to a TVM NDArray + * \param src The source 2-dimensional STL vector + * \return The converted TVM NDArray + */ +runtime::NDArray AsNDArray(const std::vector>& src) { + ICHECK(!src.empty()); + int n = src.size(); + int m = src[0].size(); + runtime::NDArray tgt = runtime::NDArray::Empty( + /*shape=*/{n, m}, + /*dtype=*/DLDataType{kDLFloat, 64, 1}, + /*ctx=*/DLDevice{kDLCPU, 0}); + double* data = static_cast(tgt->data); + for (const std::vector& row : src) { + for (double v : row) { + *data++ = v; + } + } + return tgt; +} + +} // namespace utils + +namespace transform { + +/*! 
+ * \brief Create a pass that simplifies the IR for feature extraction + * \return The pass created + */ +Pass SimplifyForFeatureExtraction() { + class Simplifier : private StmtExprMutator { + public: + static Stmt Run(Stmt stmt) { return Simplifier()(std::move(stmt)); } + + private: + PrimExpr VisitExpr_(const SelectNode* node) final { return make_const(node->dtype, 1.0); } + + PrimExpr VisitExpr_(const VarNode* var) final { + if (unit_vars_.count(GetRef(var))) { + return make_const(var->dtype, 0.0); + } + return GetRef(var); + } + + Stmt VisitStmt_(const ForNode* loop) final { + if (is_zero(loop->min) && is_one(loop->extent) && loop->kind == ForKind::kSerial && + loop->annotations.empty()) { + unit_vars_.insert(loop->loop_var); + return VisitStmt(loop->body); + } else { + return StmtExprMutator::VisitStmt_(loop); + } + } + + std::unordered_set unit_vars_; + }; + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + PrimFuncNode* n = f.CopyOnWrite(); + n->body = Simplifier::Run(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.SimplifyConstMatrix", {}); +} + +/*! + * \brief Create a list of passes that preprocesses the IR for feature extraction + * \return The list of passes created + */ +Sequential PassListForPerStoreFeature() { + return Sequential({ + tir::transform::SimplifyForFeatureExtraction(), + tir::transform::LowerCrossThreadReduction(), + tir::transform::LowerInitBlock(), + tir::transform::PlanAndUpdateBufferAllocationLocation(), + tir::transform::ConvertBlocksToOpaque(), + tir::transform::UnifyThreadBinding(), + tir::transform::CompactBufferAllocation(), + tir::transform::LowerMatchBuffer(), + tir::transform::Simplify(), + }); +} + +} // namespace transform + +/*! \brief A data structure managing loop nests */ +struct LoopNest { + int64_t prod = 1; // The product of the extents of all the loops + ForVec loops; // All the loops + IntVec auto_unroll; // The loops with auto unroll pragma + ForVec parallel; // The loops whose ForKind are kParallel + ForVec vectorize; // The loops whose ForKind are kVectorized + ForVec unroll; // The loops whose ForKind are kUnrolled + ForVec blockIdx_x; // The loops whose ForKind are kThreadBinding to blockIdx.x + ForVec blockIdx_y; // The loops whose ForKind are kThreadBinding to blockIdx.y + ForVec blockIdx_z; // The loops whose ForKind are kThreadBinding to blockIdx.z + ForVec threadIdx_x; // The loops whose ForKind are kThreadBinding to threadIdx.x + ForVec threadIdx_y; // The loops whose ForKind are kThreadBinding to threadIdx.y + ForVec threadIdx_z; // The loops whose ForKind are kThreadBinding to threadIdx.z + ForVec vthread; // The loops whose ForKind are kThreadBinding to vthread.* + + /*! 
+   * \brief Push a new loop into the loop nest
+   * \param loop The loop to be pushed
+   * \param auto_unroll_attr The auto unroll attribute of the loop
+   * \return A list of for loops that the loop is bound to
+   */
+  ForVec* Push(const ForNode* loop, int64_t* auto_unroll_attr) {
+    if (const int64_t* extent = GetLoopIntExtent(loop)) {
+      this->prod *= *extent;
+    }
+    this->loops.push_back(loop);
+    if ((*auto_unroll_attr = utils::GetPragmaAutoUnroll(loop)) > 0) {
+      this->auto_unroll.push_back(*auto_unroll_attr);
+    }
+    ForVec* ref_loops = nullptr;
+    if (loop->kind == ForKind::kParallel) {
+      ref_loops = &parallel;
+    } else if (loop->kind == ForKind::kVectorized) {
+      ref_loops = &vectorize;
+    } else if (loop->kind == ForKind::kUnrolled) {
+      ref_loops = &unroll;
+    } else if (loop->kind == ForKind::kThreadBinding) {
+      std::string thread_tag = loop->thread_binding.value()->thread_tag;
+      if (thread_tag == "blockIdx.x") {
+        ref_loops = &blockIdx_x;
+      } else if (thread_tag == "blockIdx.y") {
+        ref_loops = &blockIdx_y;
+      } else if (thread_tag == "blockIdx.z") {
+        ref_loops = &blockIdx_z;
+      } else if (thread_tag == "threadIdx.x") {
+        ref_loops = &threadIdx_x;
+      } else if (thread_tag == "threadIdx.y") {
+        ref_loops = &threadIdx_y;
+      } else if (thread_tag == "threadIdx.z") {
+        ref_loops = &threadIdx_z;
+      } else if (support::StartsWith(thread_tag, "vthread")) {
+        ref_loops = &vthread;
+      } else {
+        LOG(FATAL) << "ValueError: Unable to recognize thread tag: " << thread_tag;
+      }
+    }
+    if (ref_loops != nullptr) {
+      ref_loops->push_back(loop);
+    }
+    return ref_loops;
+  }
+
+  /*!
+   * \brief Pop the last loop from the loop nest
+   * \param loop The loop to be popped
+   * \param ref_loops The list of for loops that the loop is bound to
+   * \param auto_unroll_attr The auto unroll attribute of the loop
+   */
+  void Pop(const ForNode* loop, ForVec* ref_loops, int auto_unroll_attr) {
+    if (ref_loops) {
+      ref_loops->pop_back();
+    }
+    if (auto_unroll_attr > 0) {
+      this->auto_unroll.pop_back();
+    }
+    if (const int64_t* extent = GetLoopIntExtent(loop)) {
+      this->prod /= *extent;
+    }
+    this->loops.pop_back();
+  }
+};
+
+/****** Group 1: Computation related features ******/
+
+namespace group1 {
+
+/*! \brief Group 1 features */
+struct Feature {
+  /*!
\brief Arithmetic features */ + struct ArithOps { + // Float-point arithmetic features + int64_t float_mad = 0; // The number of float MAD (Multiply–add) ops + int64_t float_add_sub = 0; // The number of float add and sub ops + int64_t float_mul = 0; // The number of float multiply ops + int64_t float_div_mod = 0; // The number of float div and mod ops + int64_t float_cmp = 0; // The number of float comparison ops + int64_t float_math_func = 0; // The number of float math func calls + int64_t float_other_func = 0; // The number of other float func calls + // Integer arithmetic features + int64_t int_mad = 0; // The number of integer MAD (Multiply–add) ops + int64_t int_add_sub = 0; // The number of integer add and sub ops + int64_t int_mul = 0; // The number of integer multiply ops + int64_t int_div_mod = 0; // The number of integer div and mod ops + int64_t int_cmp = 0; // The number of integer comparison ops + int64_t int_math_func = 0; // The number of integer math func calls + int64_t int_other_func = 0; // The number of other integer func calls + // Other arithmetic features + int64_t bool_op = 0; // The number of bool ops + int64_t select_op = 0; // The number of select ops + + static constexpr int64_t kCount = 16; + + ArithOps() = default; + ArithOps(const BufferStoreNode* store, int64_t prod_loop_extent); + + void Export(std::vector* v) const { + double vs[] = { + slog(float_mad), slog(float_add_sub), slog(float_mul), slog(float_div_mod), + slog(float_cmp), slog(float_math_func), slog(float_other_func), // + slog(int_mad), slog(int_add_sub), slog(int_mul), slog(int_div_mod), + slog(int_cmp), slog(int_math_func), slog(int_other_func), // + slog(bool_op), slog(select_op), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } + }; + + /*! 
\brief Loop binding features */ + struct ForKindFeature { + enum class Pos : int { + kPosNone = 0, // Does not have this kind of annotation + kPosInnerSpatial = 1, // The annotated iterator is the innermost spatial iterator + kPosMiddleSpatial = 2, // The annotated iterator is a middle spatial iterator + kPosOuterSpatial = 3, // The annotated iterator is the outermost spatial iterator + kPosInnerReduce = 4, // The annotated iterator is the innermost reduce iterator + kPosMiddleReduce = 5, // The annotated iterator is a middle reduce iterator + kPosOuterReduce = 6, // The annotated iterator is the outermost reduce iterator + kPosMixed = 7, // The annotated iterator is a mixed space and reduce iterator + kEnd = 8, + }; + int64_t num = 0; // The number of iterators with the annotation + int64_t prod = 0; // The product of the lengths of iterators with the annotation + int64_t len = 0; // The length of the innermost iterator with the annotation + Pos pos = Pos::kPosMixed; // The position of the iterators with the annotation + + static constexpr int64_t kCount = 11; + + explicit ForKindFeature(const ForVec& loops); + + void Export(std::vector* v) const { + double vs[] = { + slog(num), + slog(prod), + slog(len), + static_cast(static_cast(pos) == 0), + static_cast(static_cast(pos) == 1), + static_cast(static_cast(pos) == 2), + static_cast(static_cast(pos) == 3), + static_cast(static_cast(pos) == 4), + static_cast(static_cast(pos) == 5), + static_cast(static_cast(pos) == 6), + static_cast(static_cast(pos) == 7), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } + }; + + ArithOps arith_ops; // Arithmetic features + ForKindFeature vectorize; // Loop binding features: kVectorize + ForKindFeature unroll; // Loop binding features: kUnroll + ForKindFeature parallel; // Loop binding features: kParallel + bool is_gpu = false; // If the program is running on GPU + int64_t blockIdx_x_len = 1; // The length of blockIdx.x + int64_t blockIdx_y_len = 1; // The length of blockIdx.y + int64_t blockIdx_z_len = 1; // The length of blockIdx.z + int64_t threadIdx_x_len = 1; // The length of threadIdx.x + int64_t threadIdx_y_len = 1; // The length of threadIdx.y + int64_t threadIdx_z_len = 1; // The length of threadIdx.z + int64_t vthread_len = 1; // The length of virtual thread + + static constexpr int64_t kCount = ArithOps::kCount + ForKindFeature::kCount * 3 + 8; + + explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest, bool is_gpu) + : arith_ops(store, loop_nest.prod), + vectorize(loop_nest.vectorize), + unroll(loop_nest.unroll), + parallel(loop_nest.parallel) { + if (is_gpu) { + this->is_gpu = true; + this->blockIdx_x_len = utils::FirstLoopExtent(loop_nest.blockIdx_x, 1); + this->blockIdx_y_len = utils::FirstLoopExtent(loop_nest.blockIdx_y, 1); + this->blockIdx_z_len = utils::FirstLoopExtent(loop_nest.blockIdx_z, 1); + this->threadIdx_x_len = utils::FirstLoopExtent(loop_nest.threadIdx_x, 1); + this->threadIdx_y_len = utils::FirstLoopExtent(loop_nest.threadIdx_y, 1); + this->threadIdx_z_len = utils::FirstLoopExtent(loop_nest.threadIdx_z, 1); + this->vthread_len = utils::FirstLoopExtent(loop_nest.vthread, 1); + } + } + + void Export(std::vector* v) const { + this->arith_ops.Export(v); + this->vectorize.Export(v); + this->unroll.Export(v); + this->parallel.Export(v); + double vs[] = { + static_cast(is_gpu), // + slog(blockIdx_x_len), slog(blockIdx_y_len), slog(blockIdx_z_len), + slog(threadIdx_x_len), slog(threadIdx_y_len), slog(threadIdx_z_len), + slog(vthread_len), + }; + 
v->insert(v->end(), std::begin(vs), std::end(vs)); + } +}; + +Feature::ArithOps::ArithOps(const BufferStoreNode* store, int64_t prod_loop_extent) { + class ArithOpCounter : public ExprVisitor { + public: +#define TVM_FEATURE_SIMPLE(Type, Counter) \ + void VisitExpr_(const Type* op) final { \ + result_.Counter += this->prod_loop_extent_; \ + ExprVisitor::VisitExpr_(op); \ + } +#define TVM_FEATURE_BINARY(Type, FloatCounter, IntCounter) \ + void VisitExpr_(const Type* op) final { \ + if (op->dtype.is_float()) { \ + result_.FloatCounter += this->prod_loop_extent_; \ + } else { \ + result_.IntCounter += this->prod_loop_extent_; \ + } \ + ExprVisitor::VisitExpr_(op); \ + } + TVM_FEATURE_SIMPLE(AndNode, bool_op); + TVM_FEATURE_SIMPLE(OrNode, bool_op); + TVM_FEATURE_SIMPLE(NotNode, bool_op); + TVM_FEATURE_SIMPLE(SelectNode, select_op); + TVM_FEATURE_BINARY(AddNode, float_add_sub, int_add_sub); + TVM_FEATURE_BINARY(SubNode, float_add_sub, int_add_sub); + TVM_FEATURE_BINARY(MulNode, float_mul, int_mul); + TVM_FEATURE_BINARY(DivNode, float_div_mod, int_div_mod); + TVM_FEATURE_BINARY(ModNode, float_div_mod, int_div_mod); + TVM_FEATURE_BINARY(FloorDivNode, float_div_mod, int_div_mod); + TVM_FEATURE_BINARY(FloorModNode, float_div_mod, int_div_mod); + TVM_FEATURE_BINARY(MaxNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(MinNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(EQNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(NENode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(LTNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(LENode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(GTNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(GENode, float_cmp, int_cmp); +#undef TVM_FEATURE_BINARY +#undef TVM_FEATURE_SIMPLE + + void VisitExpr_(const CallNode* op) final { + static auto op_call_effect_ = Op::GetAttrMap("TCallEffectKind"); + TCallEffectKind effect_kind = op_call_effect_[Downcast(op->op)]; + bool is_pure = + effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation; + if (is_pure) { + if (op->dtype.is_float()) { + result_.float_math_func += prod_loop_extent_; + } else { + result_.int_math_func += prod_loop_extent_; + } + } else { + if (op->dtype.is_float()) { + result_.float_other_func += prod_loop_extent_; + } else { + result_.int_other_func += prod_loop_extent_; + } + } + ExprVisitor::VisitExpr_(op); + } + + int64_t prod_loop_extent_; + ArithOps result_; + }; + ArithOpCounter counter; + counter.prod_loop_extent_ = prod_loop_extent; + counter(store->value); + *this = counter.result_; +} + +Feature::ForKindFeature::ForKindFeature(const ForVec& loops) { + if (loops.empty()) { + this->num = 0; + this->prod = 0; + this->len = 0; + this->pos = ForKindFeature::Pos::kPosNone; + } else { + const int64_t* last_loop_extent = GetLoopIntExtent(loops.back()); + this->num = loops.size(); + this->len = last_loop_extent ? *last_loop_extent : 1; + this->pos = ForKindFeature::Pos::kPosMixed; + int64_t& prod = this->prod = 1; + for (const ForNode* loop : loops) { + if (const int64_t* extent = GetLoopIntExtent(loop)) { + prod *= *extent; + } + } + } +} + +} // namespace group1 + +namespace group2 { + +/*! 
\brief Group 2 features */ +struct Feature { + enum class AccessType : int { + kRead = 0, // The buffer is read but not written + kWrite = 1, // The buffer is written but not read + kReadWrite = 2, // The buffer is both read and written + kUnknownRW = 3, // Unknown type + kEnd = 4, + }; + enum class ReuseType : int { + kLoopMultipleRead = 0, // Buffer reuse because accessed on each iteration of a loop + kSerialMultipleReadWrite = 1, // Buffer reuse because it is serially accessed + kNoReuse = 2, // No buffer reuse + kEnd = 3, + }; + + struct SubFeature { + // + const BufferNode* buffer = nullptr; + AccessType access_type = AccessType::kUnknownRW; + std::vector multi_indices = {}; + // + /*! \brief loop_accessed_numel[i][...] means the number of elements accessed by loops[i] */ + std::vector> loop_accessed_numel = {}; + IntVec access_shape; + int64_t num_continuous_bytes = 1; + // Stride information + int64_t min_stride = 0; + int64_t innermost_stride = 0; + int64_t prod_non_strided_loop_extent = 0; + // Reuse information + ReuseType reuse_type = ReuseType::kNoReuse; + double reuse_dis_iter = 0.0; + double reuse_dis_bytes = 0.0; + int64_t reuse_ct = 0; + // Features + double bytes; // The touched memory in bytes + double unique_bytes; // The touched unique memory in bytes + double lines; // The number of touched cache lines + double unique_lines; // The number touched unique cache lines + double bytes_d_reuse_ct; // bytes / reuse_ct + double unique_bytes_d_reuse_ct; // unique_bytes / reuse_ct + double lines_d_reuse_ct; // lines / reuse_ct + double unique_lines_d_reuse_ct; // unique_lines / reuse_ct + double stride; // The stride in access + + static constexpr int64_t kCount = 18; + + void Export(std::vector* v) const { + double vs[] = { + static_cast(static_cast(access_type) == 0), + static_cast(static_cast(access_type) == 1), + static_cast(static_cast(access_type) == 2), + // FeatureSet::BufferAccess::AccessType::kUnknownRW is ignored + slog(bytes), + slog(unique_bytes), + slog(lines), + slog(unique_lines), + static_cast(static_cast(reuse_type) == 0), + static_cast(static_cast(reuse_type) == 1), + static_cast(static_cast(reuse_type) == 2), + slog(reuse_dis_iter), + slog(reuse_dis_bytes), + slog(reuse_ct), + slog(bytes_d_reuse_ct), + slog(unique_bytes_d_reuse_ct), + slog(lines_d_reuse_ct), + slog(unique_lines_d_reuse_ct), + slog(stride), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } + + static void Pad(std::vector* v) { v->insert(v->end(), 18, 0.0); } + + void SetStride(const LoopNest& loop_nest); + + void SetReuse(const LoopNest& loop_nest, // + int64_t top_loop_touch_bytes, // + const ForBufferMap& buffer_touched_under_loop); + + void SetFeature(const LoopNest& loop_nest, int64_t cache_line_bytes); + + explicit SubFeature(const BufferNode* buffer, AccessType access_type, + std::vector multi_indices, int n_loops) + : buffer(buffer), + access_type(access_type), + multi_indices(multi_indices), + loop_accessed_numel(n_loops) {} + }; + + void Export(std::vector* v, int buffers_per_store) const { + int n = sub_features.size(); + for (int i = 0; i < buffers_per_store; ++i) { + if (i < n) { + sub_features[i].Export(v); + } else { + SubFeature::Pad(v); + } + } + } + + explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest, + int64_t cache_line_bytes, IntVec* for_touched_bytes, + ForBufferMap* buffer_touched_under_loop, arith::Analyzer* analyzer); + + void Init(const BufferStoreNode* store, int n_loops); + + void SetRegion(const LoopNest& loop_nest, // + IntVec* 
for_touched_bytes, // + ForBufferMap* buffer_touched_under_loop, // + arith::Analyzer* analyzer); + + std::vector sub_features; +}; + +void Feature::Init(const BufferStoreNode* store, int n_loops) { + struct Info { + AccessType access_type = AccessType::kUnknownRW; + std::vector multi_indices; + }; + std::unordered_map buffer_info; + { + Info& info = buffer_info[store->buffer.get()]; + info.access_type = AccessType::kWrite; + info.multi_indices.push_back({store->indices.begin(), store->indices.end()}); + } + PostOrderVisit(store->value, [&buffer_info](const ObjectRef& obj) -> void { + if (const BufferLoadNode* load = obj.as()) { + const BufferNode* buffer = load->buffer.get(); + Info& info = buffer_info[buffer]; + switch (info.access_type) { + case AccessType::kRead: + break; + case AccessType::kWrite: + info.access_type = AccessType::kReadWrite; + break; + case AccessType::kReadWrite: + break; + case AccessType::kUnknownRW: + default: + info.access_type = AccessType::kRead; + break; + } + if (info.access_type != AccessType::kReadWrite) { + info.multi_indices.push_back({load->indices.begin(), load->indices.end()}); + } + } + }); + this->sub_features.reserve(buffer_info.size()); + for (const auto& kv : buffer_info) { + this->sub_features.emplace_back(kv.first, kv.second.access_type, + std::move(kv.second.multi_indices), n_loops); + } +} + +void Feature::SetRegion(const LoopNest& loop_nest, IntVec* for_touched_bytes, + ForBufferMap* buffer_touched_under_loop, + arith::Analyzer* analyzer) { + int n_loops = loop_nest.loops.size(); + const std::vector& loops = loop_nest.loops; + // Step 1. Initialize and bind all the loop variables to a constant + *for_touched_bytes = IntVec(n_loops, 0); + for (int i = 0; i < n_loops; ++i) { + const ForNode* loop = loops[i]; + analyzer->Bind(loop->loop_var, loop->min, /*allow_override=*/true); + } + // Step 2. Corner case: no loops + if (n_loops == 0) { + // In this case, the `access_shape` is not calculated + for (SubFeature& feature : sub_features) { + feature.access_shape = IntVec(feature.buffer->shape.size(), 1); + } + return; + } + // Step 3. 
Gradually bind the loops from inner to outer,
+  // calculate the area the loops touch on each buffer
+  for (int i = n_loops - 1; i >= 0; --i) {
+    const ForNode* loop = loops[i];
+    analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent),
+                   /*allow_override=*/true);
+    int64_t& touched_bytes = (*for_touched_bytes)[i] = 0;
+    for (SubFeature& feature : sub_features) {
+      const BufferNode* buffer = feature.buffer;
+      // Note: `feature.access_shape` for `i == 0` is the only one preserved,
+      // while others are discarded
+      int64_t numel = 1;
+      feature.access_shape = utils::RelaxAndUnion(feature.multi_indices, &numel, analyzer);
+      feature.loop_accessed_numel[i][buffer] = numel;
+      touched_bytes += numel * buffer->dtype.bytes();
+      (*buffer_touched_under_loop)[loop][buffer].push_back(numel);
+    }
+  }
+}
+
+void Feature::SubFeature::SetStride(const LoopNest& loop_nest) {
+  int n_loops = loop_nest.loops.size();
+  const std::vector<const ForNode*>& loops = loop_nest.loops;
+  // For each buffer, we find the loop stride on it
+  const BufferNode* buffer = this->buffer;
+  int ndim = this->buffer->shape.size();
+  IntVec buffer_shape = support::AsVector<PrimExpr, int64_t>(buffer->shape);
+  // Calculate the buffer's stride from its shape
+  IntVec buffer_stride(ndim);
+  if (ndim >= 1) {
+    buffer_stride[ndim - 1] = 1;
+    for (int i = ndim - 2; i >= 0; --i) {
+      buffer_stride[i] = buffer_stride[i + 1] * buffer_shape[i + 1];
+    }
+  }
+  // Calculate `num_continuous_bytes`
+  {
+    int64_t& num_continuous_bytes = this->num_continuous_bytes = 1;
+    const IntVec& access_shape = this->access_shape;
+    ICHECK_EQ(access_shape.size(), buffer_shape.size());
+    for (int i = ndim - 1; i >= 0; --i) {
+      if (access_shape[i] == buffer_shape[i]) {
+        // TODO
+        num_continuous_bytes = buffer_shape[i] * buffer->dtype.bytes();
+        break;
+      }
+    }
+  }
+  // Enumerate loops from inner to outer
+  int i = 0;
+  // Calculate this->min_stride
+  int64_t& stride = this->min_stride = 0;
+  for (i = n_loops - 1; i >= 0; --i) {
+    stride = utils::GetVarStride(this->multi_indices, buffer_stride, loops[i]->loop_var);
+    if (stride != 0) {
+      break;
+    }
+  }
+  // Calculate this->innermost_stride
+  this->innermost_stride = (i == n_loops - 1) ? stride : 0;
+  // Calculate this->prod_non_strided_loop_extent
+  int64_t& prod = this->prod_non_strided_loop_extent = 1;
+  for (int j = n_loops - 1; j > i; --j) {
+    if (const int64_t* extent = GetLoopIntExtent(loops[j])) {
+      prod *= *extent;
+    }
+  }
+}
+
+void Feature::SubFeature::SetReuse(const LoopNest& loop_nest, int64_t top_loop_touch_bytes,
+                                   const ForBufferMap<IntVec>& buffer_touched_under_loop) {
+  const BufferNode* buffer = this->buffer;
+  // Step 0. Collect all `Var`s that appear in the buffer region
+  std::unordered_set<const VarNode*> region_vars;
+  for (const MultiIndex& multi_index : this->multi_indices) {
+    for (const PrimExpr& index : multi_index) {
+      PostOrderVisit(index, [&region_vars](const ObjectRef& obj) -> void {
+        if (const auto* var = obj.as<VarNode>()) {
+          region_vars.insert(var);
+        }
+      });
+    }
+  }
+  // Default case: no reuse
+  ReuseType& reuse_type = this->reuse_type = ReuseType::kNoReuse;
+  double& reuse_dis_iter = this->reuse_dis_iter = 0;
+  double& reuse_dis_bytes = this->reuse_dis_bytes = 0;
+  int64_t& reuse_ct = this->reuse_ct = 0;
+
+  // Step 3.2. Enumerate loops from inner to outer, find the first loop with reuse
+  int n_loops = loop_nest.loops.size();
+  const std::vector<const ForNode*>& loops = loop_nest.loops;
+  for (int i = n_loops - 1; i >= 0; --i) {
+    const ForNode* loop = loops[i];
+    // Case 1. Find an invariant loop, i.e.
reuse with kLoopMultipleRead + if (!region_vars.count(loop->loop_var.get())) { + reuse_type = ReuseType::kLoopMultipleRead; + if (const int64_t* extent = GetLoopIntExtent(loop)) { + reuse_ct = *extent; + } else { + reuse_ct = 1; + } + reuse_dis_iter = 1; + for (int j = n_loops - 1; j > i; --j) { + if (const int64_t* extent = GetLoopIntExtent(loops[j])) { + reuse_dis_iter *= *extent; + } + } + reuse_dis_bytes = 0.0; + if (i == n_loops - 1) { + reuse_dis_bytes = top_loop_touch_bytes; + } else { + for (const auto& iter : buffer_touched_under_loop.at(loops[i + 1])) { + const BufferNode* buffer = iter.first; + const IntVec& numels = iter.second; + int64_t numel = std::accumulate(numels.begin(), numels.end(), int64_t(0)); + reuse_dis_bytes += numel * buffer->dtype.bytes(); + } + } + break; + } + // Case 2. Find serial reuse, i.e. reuse with kSerialMultipleReadWrite + const IntVec& touched = buffer_touched_under_loop.at(loop).at(buffer); + if (touched.size() >= 2) { + int64_t extent = 1; + if (const int64_t* ext = GetLoopIntExtent(loop)) { + extent = *ext; + } + reuse_type = ReuseType::kSerialMultipleReadWrite; + reuse_ct = touched.size() - 1; + reuse_dis_iter = *std::min_element(touched.begin(), touched.end()); + reuse_dis_bytes = 0.0; + for (const auto& iter : buffer_touched_under_loop.at(loop)) { + const BufferNode* buffer = iter.first; + const IntVec& numels = iter.second; + int64_t numel = std::accumulate(numels.begin(), numels.end(), int64_t(0)); + reuse_dis_bytes += numel * buffer->dtype.bytes(); + } + reuse_dis_iter /= extent; + reuse_dis_bytes /= extent; + break; + } + } +} + +void Feature::SubFeature::SetFeature(const LoopNest& loop_nest, int64_t cache_line_bytes) { + int64_t dtype_bytes = this->buffer->dtype.bytes(); + this->stride = this->innermost_stride; + this->bytes = dtype_bytes * loop_nest.prod; + if (loop_nest.loops.empty()) { + this->unique_bytes = 1; + this->lines = 1; + this->unique_lines = 1; + } else { + this->unique_bytes = this->loop_accessed_numel.front().at(buffer) * dtype_bytes; + this->lines = static_cast(loop_nest.prod) / this->prod_non_strided_loop_extent * + std::min(1.0, 1.0 * this->min_stride * dtype_bytes / cache_line_bytes); + this->lines = std::max(1.0, this->lines); + this->unique_lines = static_cast(this->unique_bytes) / + std::min(cache_line_bytes, this->num_continuous_bytes); + this->unique_lines = std::max(1.0, this->unique_lines); + } + double proxy_reuse_ct = this->reuse_ct > 0 ? this->reuse_ct : 0.5; + this->bytes_d_reuse_ct = this->bytes / proxy_reuse_ct; + this->unique_bytes_d_reuse_ct = this->unique_bytes / proxy_reuse_ct; + this->lines_d_reuse_ct = this->lines / proxy_reuse_ct; + this->unique_lines_d_reuse_ct = this->unique_lines / proxy_reuse_ct; +} + +Feature::Feature(const BufferStoreNode* store, const LoopNest& loop_nest, int64_t cache_line_bytes, + IntVec* for_touched_bytes, ForBufferMap* buffer_touched_under_loop, + arith::Analyzer* analyzer) { + int n_loops = loop_nest.loops.size(); + // Step 0. Initialize data structures + this->Init(store, n_loops); + // Step 1. Calculate region-related feature + this->SetRegion(loop_nest, for_touched_bytes, buffer_touched_under_loop, analyzer); + // Step 2. Calculate stride-related feature + for (auto& feature : sub_features) { + feature.SetStride(loop_nest); + } + // Step 3. 
Calculate reuse-related feature + int64_t top_loop_touch_bytes = 0.0; + if (n_loops > 0) { + for (const SubFeature& feature : sub_features) { + int64_t bytes = feature.buffer->dtype.bytes(); + int64_t n_buffer = feature.loop_accessed_numel[0].size(); + top_loop_touch_bytes += bytes * n_buffer; + } + } + for (auto& feature : sub_features) { + feature.SetReuse(loop_nest, top_loop_touch_bytes, *buffer_touched_under_loop); + } + // Step 4. Calculate rest of the features + for (auto& feature : sub_features) { + feature.SetFeature(loop_nest, cache_line_bytes); + } + // Step 5. Sort the features + std::sort(sub_features.begin(), sub_features.end(), [](const SubFeature& a, const SubFeature& b) { + if (a.lines != b.lines) { + return a.lines > b.lines; + } + if (a.bytes != b.bytes) { + return a.bytes > b.bytes; + } + return a.buffer->name < b.buffer->name; + }); +} + +} // namespace group2 + +namespace group3 { + +/*! \brief Group 3 feature */ +struct Feature { + std::vector arith_intensity_curve; + + void Export(std::vector* v) const { + v->insert(v->end(), arith_intensity_curve.begin(), arith_intensity_curve.end()); + } + + explicit Feature(int n_samples, const LoopNest& loop_nest, const IntVec& for_touched_bytes, + const group1::Feature::ArithOps& arith_ops) + : arith_intensity_curve(n_samples, 0.0) { + const std::vector& loops = loop_nest.loops; + ICHECK_EQ(loops.size(), for_touched_bytes.size()); + int n_loops = loops.size(); + // Calculate `memory_bytes` + std::vector memory_bytes; + memory_bytes.resize(n_loops); + for (int i = 0; i < n_loops; ++i) { + memory_bytes[n_loops - 1 - i] = std::log2(for_touched_bytes[i]); + } + // Calculate `compute_ops` and `cur_compute_ops` + std::vector compute_ops; + double total_compute_ops = arith_ops.float_mad + arith_ops.float_add_sub + arith_ops.float_mul + + arith_ops.float_div_mod + arith_ops.float_cmp + + arith_ops.float_math_func + arith_ops.float_other_func; + total_compute_ops /= loop_nest.prod; + for (int i = n_loops - 1; i >= 0; --i) { + if (const int64_t* extent = GetLoopIntExtent(loops[i])) { + total_compute_ops *= *extent; + } + compute_ops.push_back(std::log2(total_compute_ops)); + } + // Fill the feature set + if (total_compute_ops <= 0 || compute_ops.empty()) { + for (int i = 0; i < n_samples; ++i) { + arith_intensity_curve[i] = 0.0; + } + return; + } + total_compute_ops = compute_ops.back(); // i.e. total_compute_ops = log2(total_compute_ops) + int p = 0; + for (int i = 0; i < n_samples; ++i) { + double& result = arith_intensity_curve[i]; + double cur_compute_ops = static_cast(i + 1) / n_samples * total_compute_ops; + // Find the first `p` that `compute[p] >= total * (i + 1) / N` + for (; p < n_loops; ++p) { + if (compute_ops[p] >= cur_compute_ops - 1e-4) { + break; + } + } + CHECK_LT(p, n_loops); + if (p == 0) { + result = compute_ops[p] / memory_bytes[p]; + } else { + double base = compute_ops[p - 1] / memory_bytes[p - 1]; + double slope = + (compute_ops[p] / memory_bytes[p] - compute_ops[p - 1] / memory_bytes[p - 1]) / + (compute_ops[p] - compute_ops[p - 1]); + result = base + slope * (cur_compute_ops - compute_ops[p - 1]); + } + } + } +}; + +} // namespace group3 + +namespace group4 { + +/*! 
\brief Group 4 feature */ +struct Feature { + int64_t alloc_size = 0; // The size of allocated buffer in bytes + int64_t alloc_prod = 0; // alloc_outer_prod * alloc_inner_prod + int64_t alloc_outer_prod = 1; // The product of lengths of loops outside the scope of the alloc + + static constexpr int64_t kCount = 4; + + void Export(std::vector* v, int64_t outer_prod) const { + double vs[] = { + slog(alloc_size), + slog(alloc_prod), + slog(alloc_outer_prod), + slog(static_cast(outer_prod) / alloc_outer_prod), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } + + Feature() = default; + + explicit Feature(const LoopNest& loop_nest, const Buffer& buffer) { + int64_t numel = 1; + for (int64_t x : support::AsVector(buffer->shape)) { + numel *= x; + } + alloc_size = numel * buffer->dtype.bytes(); + alloc_prod = numel * loop_nest.prod; + alloc_outer_prod = loop_nest.prod; + } +}; + +} // namespace group4 + +namespace group5 { + +/*! \brief Group 5 feature */ +struct Feature { + int64_t outer_prod; // The product of lengths of outer loops + int num_loops; // The number of outer loops + int auto_unroll_max_step; // The value of pragma "auto_unroll_max_step" + + static constexpr int64_t kCount = 3; + + void Export(std::vector* v) const { + double vs[] = { + slog(outer_prod), + slog(num_loops), + slog(auto_unroll_max_step), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } + + explicit Feature(const LoopNest& loop_nest) { + this->outer_prod = loop_nest.prod; + this->num_loops = loop_nest.loops.size(); + this->auto_unroll_max_step = loop_nest.auto_unroll.empty() ? 0 : loop_nest.auto_unroll.back(); + } +}; + +} // namespace group5 + +/*! \brief The feature extracted */ +struct Feature { + const BufferNode* buffer = nullptr; + int buffer_order = -1; + std::unique_ptr group1 = nullptr; + std::unique_ptr group2 = nullptr; + std::unique_ptr group3 = nullptr; + std::unique_ptr group4 = nullptr; + std::unique_ptr group5 = nullptr; + + bool operator<(const Feature& other) const { return buffer_order < other.buffer_order; } +}; + +/*! 
\brief The main feature extractor */ +class PerStoreFeatureCollector : private StmtVisitor { + public: + static std::vector Collect(bool is_gpu, int64_t cache_line_bytes, + int64_t arith_intensity_curve_num_samples, + const IRModule& mod) { + PerStoreFeatureCollector collector(is_gpu, cache_line_bytes, arith_intensity_curve_num_samples); + for (const auto& kv : mod->functions) { + if (const PrimFuncNode* func = kv.second.as()) { + collector(func->body); + for (const auto& it : func->buffer_map) { + collector.HandleBufferAlloc(it.second); + } + } + } + std::vector result; + result.reserve(collector.buffer_features_.size()); + for (auto& it : collector.buffer_features_) { + Feature& feature = it.second; + if (feature.buffer != nullptr) { + ICHECK(feature.group1); + ICHECK(feature.group2); + ICHECK(feature.group3); + ICHECK(feature.group5); + if (feature.group4 == nullptr) { + feature.group4 = std::make_unique(); + } + result.push_back(std::move(feature)); + } + } + std::sort(result.begin(), result.end()); + return result; + } + + private: + void VisitStmt_(const ForNode* loop) final { + int64_t auto_unroll; + ForVec* for_vec = loop_nest_.Push(loop, &auto_unroll); + StmtVisitor::VisitStmt_(loop); + loop_nest_.Pop(loop, for_vec, auto_unroll); + } + + void VisitStmt_(const BufferStoreNode* store) final { + if (store->value->IsInstance() || store->value->IsInstance()) { + return; + } + const BufferNode* buffer = store->buffer.get(); + Feature& feature = buffer_features_[buffer]; + if (feature.buffer == nullptr) { + feature.buffer = buffer; + feature.buffer_order = buffer_features_.size(); + } + feature.group1 = std::make_unique(store, loop_nest_, is_gpu_); + feature.group2 = + std::make_unique(store, loop_nest_, cache_line_bytes_, &for_touched_bytes_, + &buffer_touched_under_loop_, &analyzer_); + feature.group3 = + std::make_unique(arith_intensity_curve_num_samples_, loop_nest_, + for_touched_bytes_, feature.group1->arith_ops); + feature.group5 = std::make_unique(loop_nest_); + } + + void VisitStmt_(const BlockNode* block) final { + StmtVisitor::VisitStmt_(block); + for (const Buffer& buffer : block->alloc_buffers) { + HandleBufferAlloc(buffer); + } + } + + void HandleBufferAlloc(const Buffer& buffer) { + Feature& feature = buffer_features_[buffer.get()]; + feature.group4 = std::make_unique(loop_nest_, buffer); + } + + explicit PerStoreFeatureCollector(bool is_gpu, int64_t cache_line_bytes, + int64_t arith_intensity_curve_num_samples) + : is_gpu_(is_gpu), + cache_line_bytes_(cache_line_bytes), + arith_intensity_curve_num_samples_(arith_intensity_curve_num_samples) {} + + bool is_gpu_; + int64_t cache_line_bytes_; + int64_t arith_intensity_curve_num_samples_; + arith::Analyzer analyzer_; + LoopNest loop_nest_ = {}; + IntVec for_touched_bytes_ = {}; + ForBufferMap buffer_touched_under_loop_ = {}; + std::unordered_map buffer_features_ = {}; +}; + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +class PerStoreFeatureNode : public FeatureExtractorNode { + public: + int buffers_per_store; + int arith_intensity_curve_num_samples; + int cache_line_bytes; + int feature_vector_length; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("buffers_per_store", &buffers_per_store); + v->Visit("arith_intensity_curve_num_samples", &arith_intensity_curve_num_samples); + v->Visit("cache_line_bytes", &cache_line_bytes); + v->Visit("feature_vector_length", &feature_vector_length); + } + + void ExtractSingle(IRModule mod, bool is_gpu, std::vector>* results) { + static 
transform::Sequential passes = tir::transform::PassListForPerStoreFeature(); + mod = passes(std::move(mod)); + std::vector features = tir::PerStoreFeatureCollector::Collect( + is_gpu, this->cache_line_bytes, this->arith_intensity_curve_num_samples, mod); + int n_features = features.size(); + results->resize(n_features); + for (int i = 0; i < n_features; ++i) { + const tir::Feature& feature = features[i]; + std::vector& result = (*results)[i]; + result.reserve(feature_vector_length); + feature.group1->Export(&result); + feature.group2->Export(&result, this->buffers_per_store); + feature.group3->Export(&result); + feature.group4->Export(&result, feature.group5->outer_prod); + feature.group5->Export(&result); + ICHECK_EQ(static_cast(result.size()), feature_vector_length); + } + } + + Array ExtractFrom(const TuneContext& tune_context, + const Array& candidates) { + bool is_gpu = tune_context->target.value()->kind->name == "cuda"; + std::vector results; + results.resize(candidates.size()); + auto f = [this, is_gpu, &candidates, &results](int, int task_id) -> void { + const auto& candidate = candidates[task_id]; + std::vector> features; + ExtractSingle(candidate->sch->mod(), is_gpu, &features); + results[task_id] = tir::utils::AsNDArray(features); + }; + support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f); + return results; + } + + static constexpr const char* _type_key = "meta_schedule.PerStoreFeature"; + TVM_DECLARE_FINAL_OBJECT_INFO(PerStoreFeatureNode, FeatureExtractorNode); +}; + +FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store, + int arith_intensity_curve_num_samples, + int cache_line_bytes) { + ObjectPtr n = make_object(); + n->buffers_per_store = buffers_per_store; + n->arith_intensity_curve_num_samples = arith_intensity_curve_num_samples; + n->cache_line_bytes = cache_line_bytes; + n->feature_vector_length = tir::group1::Feature::kCount + // + tir::group2::Feature::SubFeature::kCount * buffers_per_store + // + arith_intensity_curve_num_samples + // + tir::group4::Feature::kCount + // + tir::group5::Feature::kCount; + return FeatureExtractor(n); +} + +TVM_REGISTER_NODE_TYPE(PerStoreFeatureNode); +TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPerStoreFeature") + .set_body_typed(FeatureExtractor::PerStoreFeature); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc new file mode 100644 index 0000000000..b29405333d --- /dev/null +++ b/src/meta_schedule/measure_callback/add_to_database.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
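
// (Editorial sketch, not part of the patch.) Typical construction and use of the
// per-store feature extractor defined above; the argument values are illustrative
// defaults, and `tune_context` / `candidates` are assumed to come from the surrounding
// tuning loop, which this patch does not show.
FeatureExtractor extractor = FeatureExtractor::PerStoreFeature(
    /*buffers_per_store=*/5,
    /*arith_intensity_curve_num_samples=*/10,
    /*cache_line_bytes=*/64);
Array<runtime::NDArray> per_candidate_features = extractor->ExtractFrom(tune_context, candidates);
// One NDArray per measure candidate; each row is the feature vector of one buffer store.
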
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class AddToDatabaseNode : public MeasureCallbackNode { + public: + void Apply(const TaskScheduler& task_scheduler, int task_id, + const Array& measure_candidates, + const Array& builder_results, + const Array& runner_results) final { + TuneContext task = task_scheduler->tasks[task_id]; + Database database = task_scheduler->database; + Workload workload = database->CommitWorkload(task->mod.value()); + Target target = task->target.value(); + ICHECK_EQ(runner_results.size(), measure_candidates.size()); + int n = runner_results.size(); + for (int i = 0; i < n; ++i) { + RunnerResult result = runner_results[i]; + MeasureCandidate candidate = measure_candidates[i]; + if (result->error_msg.defined()) { + continue; + } + database->CommitTuningRecord(TuningRecord( + /*trace=*/candidate->sch->trace().value(), + /*run_secs=*/result->run_secs.value(), + /*workload=*/workload, + /*target=*/target, + /*args_info=*/candidate->args_info)); + } + } + + static constexpr const char* _type_key = "meta_schedule.AddToDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(AddToDatabaseNode, MeasureCallbackNode); +}; + +MeasureCallback MeasureCallback::AddToDatabase() { + ObjectPtr n = make_object(); + return MeasureCallback(n); +} + +TVM_REGISTER_NODE_TYPE(AddToDatabaseNode); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackAddToDatabase") + .set_body_typed(MeasureCallback::AddToDatabase); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/measure_callback/echo_statistics.cc b/src/meta_schedule/measure_callback/echo_statistics.cc new file mode 100644 index 0000000000..435d65afc2 --- /dev/null +++ b/src/meta_schedule/measure_callback/echo_statistics.cc @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
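
// (Editorial sketch, not part of the patch.) The measure callbacks introduced in this
// change are designed to be stacked into a single list by whichever component owns the
// task scheduler (that component is not shown here); a plausible composition:
Array<MeasureCallback> callbacks{
    MeasureCallback::AddToDatabase(),        // defined above
    MeasureCallback::RemoveBuildArtifact(),  // defined later in this patch
    MeasureCallback::UpdateCostModel(),      // defined later in this patch
    MeasureCallback::EchoStatistics(),       // defined later in this patch
};
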
+ */ +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { + +double CountFlop(const IRModule& mod) { + struct TResult { + using TTable = std::unordered_map; + + TResult() = default; + + explicit TResult(const tvm::DataType& dtype) { Add(dtype); } + + void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; } + + TResult operator+=(const TResult& rhs) { + for (const auto& kv : rhs.data_) { + data_[kv.first] += kv.second; + } + return *this; + } + + TResult operator*=(int64_t rhs) { + for (auto& kv : data_) { + kv.second *= rhs; + } + return *this; + } + + TResult MaxWith(const TResult& rhs) { + for (const auto& kv : rhs.data_) { + double& v = data_[kv.first]; + if (v < kv.second) { + v = kv.second; + } + } + return *this; + } + + struct DType { + uint8_t code : 8; + uint8_t bits : 8; + uint16_t lanes : 16; + }; + static_assert(sizeof(DType) == 4, "Incorrect size of DType"); + + static String Int2Str(int32_t dtype) { + union { + DType dst; + int32_t src; + } converter; + converter.src = dtype; + static std::string type_code_tab[] = {"int", "uint", "float", "handle", "bfloat"}; + std::ostringstream os; + os << type_code_tab[converter.dst.code]; + os << static_cast(converter.dst.bits); + if (converter.dst.lanes != 1) { + os << "x" << static_cast(converter.dst.lanes); + } + return os.str(); + } + + static int32_t DataType2Int(const tvm::DataType& dtype) { + union { + DType src; + int32_t dst; + } converter; + converter.src.code = dtype.code(); + converter.src.bits = dtype.bits(); + converter.src.lanes = dtype.lanes(); + return converter.dst; + } + + TTable data_; + }; + + class FlopCounter : public ExprFunctor, + public StmtFunctor { + public: + ~FlopCounter() {} + + TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); } + TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); } + + TResult VisitStmt_(const IfThenElseNode* branch) override { + TResult cond = VisitExpr(branch->condition); + cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case)); + return cond; + } + + TResult VisitStmt_(const BufferStoreNode* store) override { + TResult result = VisitExpr(store->value); + for (const PrimExpr& e : store->indices) { + result += VisitExpr(e); + } + return result; + } + + TResult VisitStmt_(const SeqStmtNode* seq) override { + TResult result; + for (const Stmt& stmt : seq->seq) { + result += VisitStmt(stmt); + } + return result; + } + + TResult VisitStmt_(const BlockRealizeNode* block) override { + return VisitStmt(block->block->body); + } + + TResult VisitStmt_(const BlockNode* block) override { + TResult result; + if (block->init.defined()) { + result += VisitStmt(block->init.value()); + } + result += VisitStmt(block->body); + return result; + } + + TResult VisitStmt_(const ForNode* loop) override { + TResult result = VisitStmt(loop->body); + const auto* int_imm = loop->extent.as(); + ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: " + << loop->extent->GetTypeKey(); + result *= int_imm->value; + return result; + } + +#define TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(Node) \ + TResult VisitExpr_(const Node* op) final { \ + TResult result(op->dtype); \ + result += VisitExpr(op->a); \ + result += VisitExpr(op->b); \ + return result; \ + } + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AddNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(SubNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MulNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(DivNode); + 
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(ModNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorDivNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorModNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MinNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MaxNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(EQNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(NENode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LTNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LENode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GTNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GENode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AndNode); + TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(OrNode); +#undef TVM_META_SCHEDULE_FLOP_COUNTER_BINARY + TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); } + TResult VisitExpr_(const VarNode* op) override { return TResult(); } + TResult VisitExpr_(const SizeVarNode* op) override { return TResult(); } + TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); } + TResult VisitExpr_(const IntImmNode* op) override { return TResult(); } + TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); } + TResult VisitExpr_(const NotNode* op) override { + TResult result(op->dtype); + result += VisitExpr(op->a); + return result; + } + TResult VisitExpr_(const SelectNode* op) override { + TResult cond = VisitExpr(op->condition); + cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value)); + return cond; + } + TResult VisitExpr_(const CallNode* op) override { + TResult ret; + for (const auto& x : op->args) { + ret += VisitExpr(x); + } + return ret; + } + }; + FlopCounter counter; + TResult result; + for (const auto& kv : mod->functions) { + const BaseFunc& base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + result += counter.VisitStmt(prim_func->body); + } + } + double cnt = 0.0; + int i32 = TResult::DataType2Int(tvm::DataType::Int(32)); + int i64 = TResult::DataType2Int(tvm::DataType::Int(64)); + int u1 = TResult::DataType2Int(tvm::DataType::UInt(1)); + for (const auto& kv : result.data_) { + if (kv.first != i32 && kv.first != i64 && kv.first != u1) { + cnt += kv.second; + } + } + return cnt; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +constexpr const double kMaxTime = 1e10; + +std::string GetTaskName(const TuneContext& task, int task_id) { + std::ostringstream os; + os << '#' << task_id << ": " << task->task_name; + return os.str(); +} + +double GetRunMs(const Array& run_secs) { + double total = 0.0; + for (const FloatImm& i : run_secs) { + total += i->value; + } + return total * 1e3 / run_secs.size(); +} + +struct TaskInfo { + std::string name; + double flop = 0.0; + int trials = 0; + int best_round = -1; + double best_ms = kMaxTime; + double best_gflops = 0.0; + int error_count = 0; + + explicit TaskInfo(const String& name) : name(name) {} + + void Update(double run_ms) { + ++trials; + if (run_ms < best_ms) { + best_ms = run_ms; + best_round = trials; + best_gflops = flop / run_ms / 1e6; + } + LOG(INFO) << "[" << name // + << "] Trial #" << trials // + << ": GFLOPs: " << (flop / run_ms / 1e6) // + << ". 
Best GFLOPs: " << best_gflops; + } + + void UpdateError(const String& err, const MeasureCandidate& candidate) { + ++error_count; + LOG(INFO) << "[" << name // + << "] Trial #" << trials // + << ": Error in building: " << err << "\n" + << tir::AsTVMScript(candidate->sch->mod()) << "\n" + << Concat(candidate->sch->trace().value()->AsPython(false), "\n"); + } +}; + +class EchoStatisticsNode : public MeasureCallbackNode { + public: + void Apply(const TaskScheduler& task_scheduler, int task_id, + const Array& measure_candidates, + const Array& builder_results, + const Array& runner_results) final { + if (this->task_info.empty()) { + SetupTaskInfo(task_scheduler->tasks); + } + ICHECK_EQ(measure_candidates.size(), builder_results.size()); + ICHECK_EQ(measure_candidates.size(), runner_results.size()); + int n = measure_candidates.size(); + TuneContext task = task_scheduler->tasks[task_id]; + TaskInfo& info = this->task_info[task_id]; + std::string task_name = GetTaskName(task, task_id); + for (int i = 0; i < n; ++i) { + MeasureCandidate candidate = measure_candidates[i]; + BuilderResult builder_result = builder_results[i]; + RunnerResult runner_result = runner_results[i]; + if (Optional err = builder_result->error_msg) { + info.UpdateError(err.value(), candidate); + } else if (Optional err = runner_result->error_msg) { + info.UpdateError(err.value(), candidate); + } else { + ICHECK(runner_result->run_secs.defined()); + info.Update(GetRunMs(runner_result->run_secs.value())); + } + } + } + + void SetupTaskInfo(const Array& tasks) { + task_info.reserve(tasks.size()); + int task_id = 0; + for (const TuneContext& task : tasks) { + task_info.push_back(TaskInfo(GetTaskName(task, task_id))); + TaskInfo& info = task_info.back(); + info.flop = tir::CountFlop(task->mod.value()); + ++task_id; + } + } + + std::vector task_info; + + static constexpr const char* _type_key = "meta_schedule.EchoStatistics"; + TVM_DECLARE_FINAL_OBJECT_INFO(EchoStatisticsNode, MeasureCallbackNode); +}; + +MeasureCallback MeasureCallback::EchoStatistics() { + ObjectPtr n = make_object(); + return MeasureCallback(n); +} + +TVM_REGISTER_NODE_TYPE(EchoStatisticsNode); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackEchoStatistics") + .set_body_typed(MeasureCallback::EchoStatistics); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc new file mode 100644 index 0000000000..733d118c73 --- /dev/null +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +MeasureCallback MeasureCallback::PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, // + PyMeasureCallbackNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return MeasureCallback(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyMeasureCallbackNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyMeasureCallback's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); +TVM_REGISTER_NODE_TYPE(PyMeasureCallbackNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackApply") + .set_body_method(&MeasureCallbackNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackPyMeasureCallback") + .set_body_typed(MeasureCallback::PyMeasureCallback); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/measure_callback/remove_build_artifact.cc b/src/meta_schedule/measure_callback/remove_build_artifact.cc new file mode 100644 index 0000000000..649636def1 --- /dev/null +++ b/src/meta_schedule/measure_callback/remove_build_artifact.cc @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class RemoveBuildArtifactNode : public MeasureCallbackNode { + public: + void Apply(const TaskScheduler& task_scheduler, int task_id, + const Array& measure_candidates, + const Array& builder_results, + const Array& runner_results) final { + static const PackedFunc* f_rm = runtime::Registry::Get("meta_schedule.remove_build_dir"); + for (const BuilderResult& build_result : builder_results) { + if (Optional path = build_result->artifact_path) { + (*f_rm)(path.value()); + } + } + } + + static constexpr const char* _type_key = "meta_schedule.RemoveBuildArtifact"; + TVM_DECLARE_FINAL_OBJECT_INFO(RemoveBuildArtifactNode, MeasureCallbackNode); +}; + +MeasureCallback MeasureCallback::RemoveBuildArtifact() { + ObjectPtr n = make_object(); + return MeasureCallback(n); +} + +TVM_REGISTER_NODE_TYPE(RemoveBuildArtifactNode); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackRemoveBuildArtifact") + .set_body_typed(MeasureCallback::RemoveBuildArtifact); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc new file mode 100644 index 0000000000..58c86abadf --- /dev/null +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
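
// (Editorial sketch, not part of the patch.) RemoveBuildArtifact above resolves the
// Python-side helper through the global registry; a slightly more defensive variant of
// that lookup, in case the Python module registering "meta_schedule.remove_build_dir"
// has not been imported, might look like this:
const PackedFunc* f_rm = runtime::Registry::Get("meta_schedule.remove_build_dir");
ICHECK(f_rm != nullptr) << "Global function meta_schedule.remove_build_dir is not registered";
(*f_rm)("/tmp/tvm_build_artifact_example");  // hypothetical artifact path
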
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class UpdateCostModelNode : public MeasureCallbackNode { + public: + void Apply(const TaskScheduler& task_scheduler, int task_id, + const Array& measure_candidates, + const Array& builder_results, + const Array& runner_results) final { + TuneContext task = task_scheduler->tasks[task_id]; + ICHECK(task_scheduler->cost_model.defined()) // + << "Cost model must be defined for the task scheduler!"; + ICHECK(task->measure_candidates.defined()) // + << "Task's measure candidates must be present!"; + CostModel cost_model = task_scheduler->cost_model.value(); + cost_model->Update(task, task->measure_candidates.value(), runner_results); + } + + static constexpr const char* _type_key = "meta_schedule.UpdateCostModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(UpdateCostModelNode, MeasureCallbackNode); +}; + +MeasureCallback MeasureCallback::UpdateCostModel() { + ObjectPtr n = make_object(); + return MeasureCallback(n); +} + +TVM_REGISTER_NODE_TYPE(UpdateCostModelNode); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackUpdateCostModel") + .set_body_typed(MeasureCallback::UpdateCostModel); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc new file mode 100644 index 0000000000..7c973879f2 --- /dev/null +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check if the instruction is annotation with `meta_schedule_parallel` + * \param inst The instruction to be checked + * \return Whether the instruction is annotation with `meta_schedule_parallel` + */ +bool IsAnnotateWithParallel(const Instruction& inst) { + static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_annotate)) { + return false; + } + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + return ann_key == attr::meta_schedule_parallel; +} + +/*! + * \brief Replace the annotation value + * \param inst The instruction to be replaced + * \param ann_val The new annotation value + * \return The replaced instruction + */ +Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) { + ICHECK_EQ(inst->inputs.size(), 2); + return Instruction(/*kind=*/inst->kind, // + /*inputs=*/{inst->inputs[0], Integer(ann_val)}, // + /*attrs=*/inst->attrs, + /*outputs=*/inst->outputs); +} + +/*! 
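
// (Editorial sketch, not part of the patch.) The trace instruction shape that
// IsAnnotateWithParallel and ReplaceAnnValue operate on, assuming the usual traced
// schedule API; `sch` and `block` are hypothetical locals:
sch->Annotate(block, tir::attr::meta_schedule_parallel, Integer(64));
// This records an "Annotate" instruction whose inputs are {block, 64} and whose single
// attr is the annotation key; ReplaceAnnValue(inst, 16) keeps the kind, attrs and
// outputs but swaps the second input for Integer(16).
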
+ * \brief Get the output of the instruction Get-Block + * \param inst The instruction to be checked + * \return The output of the instruction Get-Block + */ +const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) { + static const InstructionKind& inst_get_block = InstructionKind::Get("GetBlock"); + if (!inst->kind.same_as(inst_get_block)) { + return nullptr; + } + ICHECK_EQ(inst->outputs.size(), 1); + const BlockRVNode* block = TVM_TYPE_AS(block, inst->outputs[0], BlockRVNode); + return block; +} + +/*! + * \brief Analyze the parallel structure + * \param self The schedule state + * \param block_name The name of the root block + * \param func_name The name of the PrimFunc + * \param limit The uplimit of the parallelism + * \return The parallel structure + */ +std::vector> AnalyzeParallel(const ScheduleState& self, + const String& block_name, const String& func_name, + int64_t limit) { + Array block_srefs = tir::GetBlocks(self, block_name, func_name); + ICHECK_EQ(block_srefs.size(), 1); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]); + ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef(block)); + std::vector> results; + results.reserve(info.realizes.size()); + for (const BlockRealize& realize : info.realizes) { + // Step 1. Extract static loop extents for spatial loops + std::vector loop_extents; + const ForNode* loop = nullptr; + for (const StmtSRefNode* loop_sref = self->stmt2ref.at(realize->block.get())->parent; + (loop = loop_sref->StmtAs()) != nullptr; // + loop_sref = loop_sref->parent) { + int64_t loop_extent = -1; + if (const auto* ext = GetLoopIntExtent(loop)) { + if (!info.non_spatial_vars.count(loop->loop_var.get())) { + loop_extent = *ext; + } + } + if (loop_extent != -1) { + loop_extents.push_back(loop_extent); + } else { + loop_extents.clear(); + } + } + // Step 2. Take the prefix product of loop extents + if (!loop_extents.empty()) { + results.emplace_back(); + std::vector& result = results.back(); + result.reserve(loop_extents.size()); + int64_t prod_extent = 1; + for (auto it = loop_extents.rbegin(); it != loop_extents.rend(); ++it) { + result.push_back(prod_extent *= *it); + if (prod_extent >= limit) { + break; + } + } + } + } + return results; +} + +/*! + * \brief Get the number of parallelizable loops for each subtree + * \param loop_extent_prods The parallel structure for each subtree + * \param limit The uplimit of the parallelism + * \return The number of parallelizable loops for each subtree + */ +std::vector GetNumFusedLoops(const std::vector>& loop_extent_prods, + int64_t limit) { + std::vector results; + results.reserve(loop_extent_prods.size()); + for (const std::vector& prods : loop_extent_prods) { + int n = prods.size(); + int i = std::upper_bound(prods.begin(), prods.end(), limit) - prods.begin(); + if (i > 0 && prods[i - 1] == limit) { + --i; + } + if (i != n) { + ++i; + } + results.push_back(i); + } + return results; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Instruction; +using tir::Trace; + +/*! \brief Create a Mutator that mutates the parallel extent */ +class MutateParallelNode : public MutatorNode { + public: + /*! + * \brief The maximum number of jobs to be launched per CPU core. + * It sets the uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. + * Use -1 to disable parallelism. + */ + int64_t max_jobs_per_core; + /*! \brief The number of cores in CPU. */ + int max_parallel_extent_; + /*! 
\brief JSON representation of the workload */ + std::string json_mod_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("max_jobs_per_core", &max_jobs_per_core); + // `max_parallel_extent_` is not visited. + // `json_mod` is not visited. + } + + static constexpr const char* _type_key = "meta_schedule.MutateParallel"; + TVM_DECLARE_FINAL_OBJECT_INFO(MutateParallelNode, MutatorNode); + + public: + struct Candidate; + // Inherit from `MutatorNode` + void InitializeWithTuneContext(const TuneContext& context) final { + Target target = context->target.value(); + this->max_parallel_extent_ = GetTargetNumCores(target) * this->max_jobs_per_core; + this->json_mod_ = SaveJSON(context->mod.value()); + } + // Inherit from `MutatorNode` + Optional Apply(const Trace& trace, TRandState* rand_state) final; +}; + +/*! \brief The candidate to be mutated */ +struct MutateParallelNode::Candidate { + /*! \brief The annotation instruction */ + Instruction inst; + /*! \brief The current parallel extent */ + int64_t parallel_extent; + /*! \brief The name of the root block */ + String block_name; + /*! \brief The name of the PrimFunc */ + String func_name; +}; + +/*! + * \brief Get an instruction that annotates the maximum parallel extent + * \param trace The trace to be mutated + * \param rand_state The random state + * \param candidate The candidate to be mutated + * \return Whether a decision is found + */ +bool FindParallelDecision(const Trace& trace, TRandState* rand_state, + MutateParallelNode::Candidate* candidate) { + using tir::BlockRVNode; + using tir::InstructionNode; + std::unordered_map get_block_insts; + std::vector ann_insts; + get_block_insts.reserve(trace->insts.size()); + ann_insts.reserve(trace->insts.size()); + for (const Instruction& inst : trace->insts) { + if (tir::IsAnnotateWithParallel(inst)) { + ann_insts.push_back(inst.get()); + } + if (const BlockRVNode* block_rv = tir::GetInstGetBlockOutput(inst)) { + get_block_insts[block_rv] = inst.get(); + } + } + int n_ann_insts = ann_insts.size(); + if (n_ann_insts == 0) { + return false; + } + const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; + ICHECK_EQ(ann_inst->inputs.size(), 2); + const InstructionNode* get_block_inst = + get_block_insts.at(Downcast(ann_inst->inputs[0]).get()); + ICHECK_EQ(get_block_inst->attrs.size(), 2); + candidate->inst = GetRef(ann_inst); + candidate->parallel_extent = Downcast(ann_inst->inputs[1])->value; + candidate->block_name = Downcast(get_block_inst->attrs[0]); + candidate->func_name = Downcast(get_block_inst->attrs[1]); + return true; +} + +Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) { + // Step 1. Find a parallel decision. + Candidate candidate; + if (!FindParallelDecision(trace, rand_state, &candidate)) { + return NullOpt; + } + // Step 2. Replay the instructions to recover loop extents + tir::Schedule sch = tir::Schedule::Traced( // + /*mod=*/Downcast(LoadJSON(this->json_mod_)), // + /*rand_state=*/ForkSeed(rand_state), // + /*debug_mode=*/0, + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + trace->ApplyToSchedule(sch, /*remove_postproc=*/true); + // Step 3. 
Find all possible parallel plans + std::vector> loop_extent_prods = tir::AnalyzeParallel( + sch->state(), candidate.block_name, candidate.func_name, this->max_parallel_extent_); + std::unordered_map> limit2plan; + std::map, int64_t> plan2limit; + for (const std::vector& prods : loop_extent_prods) { + for (int64_t limit : prods) { + if (limit <= this->max_parallel_extent_ && !limit2plan.count(limit)) { + std::vector plan = tir::GetNumFusedLoops(loop_extent_prods, limit); + limit2plan[limit] = plan; + plan2limit[plan] = limit; + } + } + } + // Step 4. Remove the original plan and remove it + std::vector original_plan = + tir::GetNumFusedLoops(loop_extent_prods, candidate.parallel_extent); + auto it = plan2limit.find(original_plan); + if (it != plan2limit.end()) { + plan2limit.erase(it); + } + // Step 5. Pick a new plan + int n_plans = plan2limit.size(); + if (n_plans == 0) { + return NullOpt; + } + it = plan2limit.begin(); + for (int i = 0, n = tir::SampleInt(rand_state, 0, n_plans); i < n; ++i) { + ++it; + } + int64_t limit = it->second; + // Step 6. Assemble a new trace + Array insts; + insts.reserve(trace->insts.size()); + for (const Instruction& inst : trace->insts) { + if (inst.same_as(candidate.inst)) { + insts.push_back(tir::ReplaceAnnValue(candidate.inst, limit)); + } else if (inst->kind->IsPostproc()) { + break; + } else { + insts.push_back(inst); + } + } + return Trace(insts, trace->decisions); +} + +Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) { + ObjectPtr n = make_object(); + n->max_jobs_per_core = max_jobs_per_core; + return Mutator(n); +} + +TVM_REGISTER_NODE_TYPE(MutateParallelNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel").set_body_typed(Mutator::MutateParallel); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc new file mode 100644 index 0000000000..be5bc1544b --- /dev/null +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +using tir::Instruction; +using tir::InstructionKind; +using tir::Trace; + +/*! + * \brief Downcast the decision of Sample-Perfect-Tile to an array of integers + * \param decision The decision of Sample-Perfect-Tile + * \return The result of downcast + */ +std::vector DowncastDecision(const ObjectRef& decision) { + const auto* arr = TVM_TYPE_AS(arr, decision, runtime::ArrayNode); + return support::AsVector(GetRef>(arr)); +} + +/*! 
+ * \brief Calculate the product of elements in an array + * \param array The array + * \return The product of elements in the array + */ +int64_t Product(const std::vector& array) { + int64_t result = 1; + for (int64_t x : array) { + result *= x; + } + return result; +} + +/*! \brief A mutator that mutates the tile size */ +class MutateTileSizeNode : public MutatorNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "meta_schedule.MutateTileSize"; + TVM_DECLARE_FINAL_OBJECT_INFO(MutateTileSizeNode, MutatorNode); + + public: + // Inherit from `MutatorNode` + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherit from `MutatorNode` + Optional Apply(const Trace& trace, TRandState* rand_state) final; +}; + +/*! + * \brief Find a sample-perfect-tile decision in the trace + * \param trace The trace + * \param rand_state The random state + * \param inst The instruction selected + * \param decision The decision selected + * \return Whether a decision is found + */ +bool FindSamplePerfectTile(const Trace& trace, TRandState* rand_state, Instruction* inst, + std::vector* decision) { + static const InstructionKind& inst_sample_perfect_tile = + InstructionKind::Get("SamplePerfectTile"); + std::vector instructions; + std::vector> decisions; + instructions.reserve(trace->decisions.size()); + decisions.reserve(trace->decisions.size()); + for (const auto& kv : trace->decisions) { + const Instruction& inst = kv.first; + const ObjectRef& decision = kv.second; + if (!inst->kind.same_as(inst_sample_perfect_tile)) { + continue; + } + std::vector tiles = DowncastDecision(decision); + if (tiles.size() >= 2 || Product(tiles) >= 2) { + instructions.push_back(inst); + decisions.push_back(tiles); + } + } + int n = instructions.size(); + if (n > 0) { + int i = tir::SampleInt(rand_state, 0, n); + *inst = instructions[i]; + *decision = decisions[i]; + return true; + } + return false; +} + +Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) { + Instruction inst; + std::vector tiles; + if (!FindSamplePerfectTile(trace, rand_state, &inst, &tiles)) { + return NullOpt; + } + int n_splits = tiles.size(); + // Step 1. Choose two loops, `x` and `y` + int x = tir::SampleInt(rand_state, 0, n_splits); + int y; + if (tiles[x] == 1) { + // need to guarantee that tiles[x] * tiles[y] > 1 + std::vector idx; + idx.reserve(n_splits); + for (int i = 0; i < n_splits; ++i) { + if (tiles[i] > 1) { + idx.push_back(i); + } + } + y = idx[tir::SampleInt(rand_state, 0, idx.size())]; + } else { + // sample without replacement + y = tir::SampleInt(rand_state, 0, n_splits - 1); + if (y >= x) { + ++y; + } + } + // make sure x < y + CHECK_NE(x, y); + if (x > y) { + std::swap(x, y); + } + // Step 2. Choose the new tile size + int64_t len_x, len_y; + if (y != n_splits - 1) { + // Case 1. None of x and y are innermost loop + do { + std::vector result = tir::SamplePerfectTile(rand_state, tiles[x] * tiles[y], 2); + len_x = result[0]; + len_y = result[1]; + } while (len_y == tiles[y]); + } else { + // Case 2. 
y is the innermost loop + std::vector len_y_space; + int64_t limit = Downcast(inst->attrs[1])->value; + int64_t prod = tiles[x] * tiles[y]; + for (len_y = 1; len_y <= limit; ++len_y) { + if (len_y != tiles[y] && prod % len_y == 0) { + len_y_space.push_back(len_y); + } + } + if (len_y_space.empty()) { + return NullOpt; + } + len_y = len_y_space[tir::SampleInt(rand_state, 0, len_y_space.size())]; + len_x = prod / len_y; + } + tiles[x] = len_x; + tiles[y] = len_y; + return trace->WithDecision(inst, support::AsArray(tiles), + /*remove_postproc=*/true); +} + +Mutator Mutator::MutateTileSize() { return Mutator(make_object()); } + +TVM_REGISTER_NODE_TYPE(MutateTileSizeNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize").set_body_typed(Mutator::MutateTileSize); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc new file mode 100644 index 0000000000..94e8348858 --- /dev/null +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check if an instruction is annotate with + * `meta_schedule_unroll_explicit` or `meta_schedule_unroll_implicit` + * \param inst The instruction to be checked + * \return Whether the instruction is annotated + */ +bool IsAnnotateWithUnroll(const Instruction& inst) { + static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_annotate)) { + return false; + } + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + return ann_key == attr::meta_schedule_unroll_explicit || + ann_key == attr::meta_schedule_unroll_implicit; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Instruction; +using tir::Trace; + +/*! \brief Create a Mutator that mutates auto unroll step */ +class MutateUnrollNode : public MutatorNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "meta_schedule.MutateUnroll"; + TVM_DECLARE_FINAL_OBJECT_INFO(MutateUnrollNode, MutatorNode); + + public: + struct Candidate; + // Inherit from `MutatorNode` + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherit from `MutatorNode` + Optional Apply(const Trace& trace, TRandState* rand_state) final; +}; + +/*! \brief A candidate to be mutated */ +struct MutateUnrollNode::Candidate { + /*! \brief The sampling instruction to be mutated */ + Instruction inst; + /*! \brief The probability */ + std::vector probs; + /*! \brief The decision made */ + int decision; +}; + +/*! 
+ * \brief Find the Sample-Categorical instruction to be mutated that affects the maximal unroll step + * \param trace The trace to be mutated + * \param rand_state The random state + * \param candidates The mutation candidate + * \return Whether a decision is found + */ +bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, + MutateUnrollNode::Candidate* candidate) { + using tir::InstructionKind; + using tir::InstructionNode; + static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical"); + std::unordered_map sample_insts; + std::vector ann_insts; + sample_insts.reserve(trace->insts.size()); + ann_insts.reserve(trace->insts.size()); + for (const Instruction& inst : trace->insts) { + if (inst->kind.same_as(inst_sample_categorical)) { + ICHECK_EQ(inst->outputs.size(), 1); + const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode); + sample_insts[var_rv] = inst.get(); + } else if (IsAnnotateWithUnroll(inst)) { + ann_insts.push_back(inst.get()); + } + } + int n_ann_insts = ann_insts.size(); + if (n_ann_insts == 0) { + return false; + } + const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; + ICHECK_EQ(ann_inst->inputs.size(), 2); + const auto* var_rv = TVM_TYPE_AS(var_rv, ann_inst->inputs[1], PrimExprNode); + ICHECK(sample_insts.count(var_rv)); + const InstructionNode* sample_inst = sample_insts.at(var_rv); + ICHECK_EQ(sample_inst->attrs.size(), 2); + candidate->inst = GetRef(sample_inst); + candidate->decision = + Downcast(trace->decisions[GetRef(sample_inst)])->value; + candidate->probs = + support::AsVector(Downcast>(sample_inst->attrs[1])); + return true; +} + +Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_state) { + Candidate candidate; + if (!FindUnrollDecision(trace, rand_state, &candidate)) { + return NullOpt; + } + if (candidate.probs.size() == 0) { + return NullOpt; + } + candidate.probs.erase(candidate.probs.begin() + candidate.decision); + int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)(); + if (result >= candidate.decision) { + result += 1; + } + return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true); +} + +Mutator Mutator::MutateUnroll() { return Mutator(make_object()); } + +TVM_REGISTER_NODE_TYPE(MutateUnrollNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll").set_body_typed(Mutator::MutateUnroll); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc new file mode 100644 index 0000000000..27383adf84 --- /dev/null +++ b/src/meta_schedule/mutator/mutator.cc @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
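
// (Editorial note, not part of the patch.) A worked example of the resampling step in
// MutateUnrollNode::Apply above, with made-up numbers: the SampleCategorical decision
// has four weights and the current decision index is 1. The current choice is removed,
// a new index is drawn over the remaining entries, and indices at or above the removed
// one are shifted back up, so the mutated trace always picks a different unroll step.
std::vector<double> probs = {0.1, 0.4, 0.4, 0.1};  // hypothetical categorical weights
int decision = 1;
probs.erase(probs.begin() + decision);              // remaining weights: {0.1, 0.4, 0.1}
int result = 2;                                     // say the sampler returned index 2
if (result >= decision) result += 1;                // maps back to index 3 of the original table
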
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +Mutator Mutator::PyMutator( + PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyMutatorNode::FApply f_apply, // + PyMutatorNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return Mutator(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyMutatorNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyMutator's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(MutatorNode); +TVM_REGISTER_NODE_TYPE(PyMutatorNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext") + .set_body_method(&MutatorNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply") + .set_body_typed([](Mutator self, tir::Trace trace, TRandState seed) -> Optional { + TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom(); + return self->Apply(trace, &seed_); + }); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc new file mode 100644 index 0000000000..715815843a --- /dev/null +++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief Check if the loop is dynamic. */ +struct DynamicExtentFinder : private StmtVisitor { + public: + static bool Find(const IRModule& mod) { + DynamicExtentFinder finder; + for (const auto& kv : mod->functions) { + const BaseFunc& func = kv.second; + if (const auto* prim_func = func.as()) { + finder(prim_func->body); + if (finder.found_) { + return true; + } + } + } + return false; + } + + private: + void VisitStmt_(const ForNode* loop) final { + if (!loop->extent->IsInstance()) { + found_ = true; + } else { + StmtVisitor::VisitStmt_(loop); + } + } + + void VisitStmt(const Stmt& stmt) final { + if (!found_) { + StmtVisitor::VisitStmt(stmt); + } + } + + bool found_ = false; +}; + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Schedule; + +/*! \brief Check if the IRModule has any loop with non-constant extent. 
*/ +class DisallowDynamicLoopNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); } + + static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop"; + TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode); +}; + +Postproc Postproc::DisallowDynamicLoop() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(DisallowDynamicLoopNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop") + .set_body_typed(Postproc::DisallowDynamicLoop); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc new file mode 100644 index 0000000000..ff069e2c68 --- /dev/null +++ b/src/meta_schedule/postproc/postproc.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +Postproc Postproc::PyPostproc( + PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyPostprocNode::FApply f_apply, // + PyPostprocNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return Postproc(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyPostprocNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyPostproc's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(PostprocNode); +TVM_REGISTER_NODE_TYPE(PyPostprocNode); + +TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext") + .set_body_method(&PostprocNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method(&PostprocNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc new file mode 100644 index 0000000000..e256f4d0cd --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
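
// (Editorial sketch, not part of the patch.) How a search policy might drive these
// postprocessors; `postprocs` and `sch` are hypothetical locals, and the policy that
// owns them is not part of this patch:
bool ApplyAllPostprocs(const Array<Postproc>& postprocs, const tir::Schedule& sch) {
  for (const Postproc& proc : postprocs) {
    if (!proc->Apply(sch)) {
      return false;  // e.g. DisallowDynamicLoop saw a loop with a non-constant extent
    }
  }
  return true;  // the candidate schedule survives every postprocessor
}
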
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Parse instruction: sch.bind(..., "threadIdx.x") + * \param sch The schedule + * \param inst The instruction to be parsed + * \return NullOpt if parsing fails; Otherwise, the extent of thread axis + */ +Optional ParseThreadBinding(const Schedule& sch, const Instruction& inst) { + static InstructionKind inst_kind_bind = InstructionKind::Get("Bind"); + if (!inst->kind.same_as(inst_kind_bind)) { + return NullOpt; + } + ICHECK_EQ(inst->inputs.size(), 1); + ICHECK_EQ(inst->attrs.size(), 1); + String thread_axis = Downcast(inst->attrs[0]); + if (thread_axis != "threadIdx.x") { + return NullOpt; + } + return Downcast(sch->Get(Downcast(inst->inputs[0]))->extent); +} + +/*! + * \brief Parse instruction: sch.annotate(..., attr::meta_schedule_cooperative_fetch) + * \param sch The schedule + * \param inst The instruction to be parsed + * \param vector_lane The length of vector lane in vectorized cooperative fetching + * \return NullOpt if parsing fails; Otherwise, the annotated block + */ +Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, int* vector_lane) { + static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_kind_annotate)) { + return NullOpt; + } + ICHECK_EQ(inst->inputs.size(), 2); + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + if (ann_key != attr::meta_schedule_cooperative_fetch) { + return NullOpt; + } + *vector_lane = Downcast(sch->Get(Downcast(inst->inputs[1])))->value; + return Downcast(inst->inputs[0]); +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +/*! + * \brief Rewrite the cooperative fetch annotation to actual vectorized cooperative fetching + * in loop bindings. 
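
// (Editorial note, not part of the patch.) A concrete instance of the rewrite performed
// below, with made-up extents: if the cooperative-fetch block's fused loop has extent
// 2048, the bound "threadIdx.x" extent is 128, and the annotated vector lane is 4, then
//   sch->Split(fused, {NullOpt, Integer(128), Integer(4)})
// yields loops of extents {4, 128, 4}; the innermost loop is vectorized and the middle
// one is bound to threadIdx.x, so each of the 128 threads vector-loads 4 elements per
// iteration of the outer loop of extent 4.
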
+ */ +class RewriteCooperativeFetchNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode); +}; + +bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { + using tir::BlockRV; + using tir::Instruction; + using tir::LoopRV; + using tir::Schedule; + using tir::Trace; + Trace trace = sch->trace().value(); + int thread_extent = -1; + int vector_lane = -1; + std::vector> tasks; + for (const Instruction& inst : trace->insts) { + if (Optional new_thread_extent = tir::ParseThreadBinding(sch, inst)) { + thread_extent = new_thread_extent.value()->value; + } + if (Optional block_rv = tir::ParseAnnotate(sch, inst, &vector_lane)) { + ICHECK_NE(thread_extent, -1); + if (vector_lane > 1) { + tasks.push_back([thread_extent, vector_lane, sch, block = block_rv.value()]() -> void { + LoopRV fused = sch->GetLoops(block).back(); + Array split = sch->Split(fused, {NullOpt, // + Integer(thread_extent), // + Integer(vector_lane)}); + sch->Vectorize(split[2]); + sch->Bind(split[1], "threadIdx.x"); + }); + } else { + tasks.push_back([thread_extent, sch, block = block_rv.value()]() -> void { + LoopRV fused = sch->GetLoops(block).back(); + Array split = sch->Split(fused, {NullOpt, Integer(thread_extent)}); + sch->Bind(split[1], "threadIdx.x"); + }); + } + } + } + for (auto&& task : tasks) { + task(); + } + return true; +} + +Postproc Postproc::RewriteCooperativeFetch() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteCooperativeFetchNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch") + .set_body_typed(Postproc::RewriteCooperativeFetch); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc new file mode 100644 index 0000000000..4eca068e17 --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! 
+ * \brief Check whether the block/loop has any annotation + * \param sref The sref of block/loop + * \return Whether the block/loop has any annotation + */ +inline bool HasAnnOrBinding(const ForNode* loop) { + return loop->kind == ForKind::kThreadBinding || !loop->annotations.empty(); +} + +class StrideExtractor : public StmtExprVisitor { + public: + static int64_t Extract(const PrimExpr& expr, const Var& var) { + StrideExtractor extractor(var); + extractor.VisitExpr(expr); + return extractor.strides_[expr.get()]; + } + + private: + explicit StrideExtractor(const Var& var) : var_(var) {} + + void VisitExpr_(const MulNode* node) final { + StmtExprVisitor::VisitExpr_(node); + + if (const auto* a = node->a.as()) { + if (strides_.count(node->b.get())) { + strides_[node] = strides_[node->b.get()] * a->value; + } + } else if (const auto* b = node->b.as()) { + if (strides_.count(node->a.get())) { + strides_[node] = strides_[node->a.get()] * b->value; + } + } + } + + void VisitExpr_(const AddNode* node) final { + StmtExprVisitor::VisitExpr_(node); + int64_t stride_a, stride_b; + if (strides_.count(node->a.get())) { + stride_a = strides_[node->a.get()]; + } else { + stride_a = INT64_MAX; + } + if (strides_.count(node->b.get())) { + stride_b = strides_[node->b.get()]; + } else { + stride_b = INT64_MAX; + } + if (stride_a != INT64_MAX || stride_b != INT64_MAX) { + strides_[node] = std::min(stride_a, stride_b); + } + } + + void VisitExpr_(const VarNode* node) final { + if (node == var_.get()) { + strides_[node] = 1; + } + } + + const Var& var_; + std::unordered_map strides_; +}; + +struct ParsedAnnotation { + int max_parallel_extent; + int max_vectorize_extent; + int unroll_explicit; + int unroll_implicit; + int num_parallel_loops; + int num_vectorize_loops; +}; + +bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) { + bool found = false; + *parsed = ParsedAnnotation{-1, -1, -1, -1, -1, -1}; + for (const auto& ann : block->annotations) { + if (ann.first == attr::meta_schedule_parallel) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->max_parallel_extent = imm->value; + } + } else if (ann.first == attr::meta_schedule_vectorize) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->max_vectorize_extent = imm->value; + } + } else if (ann.first == attr::meta_schedule_unroll_explicit) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->unroll_explicit = imm->value; + } + } else if (ann.first == attr::meta_schedule_unroll_implicit) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->unroll_implicit = imm->value; + } + } + } + return found; +} + +void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedAnnotation& parsed) { + if (parsed.max_parallel_extent != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_parallel); + } + if (parsed.max_vectorize_extent != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_vectorize); + } + if (parsed.unroll_explicit != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_unroll_explicit); + } + if (parsed.unroll_implicit != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_unroll_implicit); + } +} + +void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, + const Array& loop_rvs, ParsedAnnotation* parsed) { + StmtSRef block_sref = sch->GetSRef(block_rv); + if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) { + return; + } + int n_loops = loop_rvs.size(); + if (n_loops == 0) { + 
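+    // No loops under the block: nothing can be parallelized or vectorized, so drop both requests.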
parsed->max_parallel_extent = -1; + parsed->max_vectorize_extent = -1; + return; + } + // Extract loop_srefs, and calculate the iterator types + Array loop_srefs; + std::vector loop_types; + { + loop_srefs.reserve(n_loops); + loop_types.reserve(n_loops); + for (const LoopRV& loop_rv : loop_rvs) { + loop_srefs.push_back(sch->GetSRef(loop_rv)); + loop_types.push_back(GetLoopIterType(loop_srefs.back())); + } + } + // check the maximal number of axes that are vectorizable (contiguous memory access) + BlockRealize realize = GetBlockRealize(sch->state(), block_sref); + Array buffer_access(realize->block->reads); + buffer_access.insert(buffer_access.end(), realize->block->writes.begin(), + realize->block->writes.end()); + std::unordered_map binding_map; + for (size_t i = 0; i < realize->iter_values.size(); i++) { + binding_map[realize->block->iter_vars[i]->var.get()] = realize->iter_values[i]; + } + int max_fusible = INT32_MAX; + // for each block read/write, get the strides of the loop vars and find the fusible + // (vectorizable) axes + for (const BufferRegion& access : buffer_access) { + int fusible = 0; + std::vector strides; + // get strides for each loop var + for (const StmtSRef& loop_sref : loop_srefs) { + int64_t stride = 0, buffer_stride = 1; + const auto* var = loop_sref->StmtAs(); + arith::Analyzer analyzer; + for (int i = access->region.size() - 1; i >= 0; i--) { + PrimExpr idx = analyzer.Simplify(Substitute(access->region[i]->min, binding_map)); + int64_t coef = StrideExtractor::Extract(idx, var->loop_var); + if (coef != 0) { + stride = coef * buffer_stride; + break; + } + buffer_stride *= access->buffer->shape[i].as()->value; + } + strides.push_back(stride); + } + int prev_used_iter = -1; + // check the number of fusible loops + for (int i = strides.size() - 1; i >= 0; i--) { + if (strides[i] == 0) { + // not used in the buffer access, safe to fuse + fusible++; + continue; + } else if (prev_used_iter == -1) { + // the stride of last axis is not 1 means the memory access is not contiguous + if (strides[i] != 1) { + break; + } + fusible++; + prev_used_iter = i; + } else { + // contiguous memory access + const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs(); + int64_t prev_used_iter_extent = prev_loop->extent.as()->value; + if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) { + fusible++; + prev_used_iter = i; + } else { + break; + } + } + } + max_fusible = std::min(max_fusible, fusible); + } + // Calculate the parallelize extent + if (parsed->max_parallel_extent != -1) { + int max_extent = parsed->max_parallel_extent; + int& num_fusible = parsed->num_parallel_loops = 0; + int64_t prod_extent = 1; + for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) { + const StmtSRef& loop_sref = loop_srefs[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (HasAnnOrBinding(loop)) { + break; + } + // Check if the loop extent is valid + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (extent == nullptr) { + break; + } + // Then we can fuse it in + ++num_fusible; + // Check if we need to break + prod_extent *= *extent; + if (prod_extent > max_extent || !IsSingleStmt(loop->body)) { + break; + } + } + if (prod_extent == 1) { + num_fusible = -1; + } + } + // Calculate the vectorize extent + if (parsed->max_vectorize_extent != -1) { + int max_extent = parsed->max_vectorize_extent; + int& num_fusible = parsed->num_vectorize_loops = 0; + int64_t prod_extent = 1; + for (int i = n_loops - 1; + i >= 0 && loop_types[i] == 
IterVarType::kDataPar && num_fusible < max_fusible; --i) { + const StmtSRef& loop_sref = loop_srefs[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (HasAnnOrBinding(loop)) { + break; + } + // Cannot vectorize reduce axis + if (GetLoopIterType(loop_sref) != IterVarType::kDataPar) { + break; + } + // Cannot fuse with a loop with multiple children + if (!IsSingleStmt(loop->body)) { + break; + } + // Check if the loop extent is valid + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (extent == nullptr) { + break; + } + // Check if the extent is still in a good range + prod_extent *= *extent; + if (prod_extent > max_extent) { + break; + } + ++num_fusible; + } + if (prod_extent == 1) { + num_fusible = -1; + } + } + // Prefer num_vectorize to num_parallel + if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) { + parsed->num_parallel_loops = std::min(parsed->num_parallel_loops, // + n_loops - parsed->num_vectorize_loops); + } +} + +bool FindAnnotateRootBlock(const Schedule& sch, ParsedAnnotation* parsed, BlockRV* root_rv) { + IRModule mod = sch->mod(); + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + Block block = Downcast(prim_func->body)->block; + if (ParseAnnotation(block, parsed)) { + *root_rv = sch->GetBlock(block->name_hint, g_var->name_hint); + RemoveParsedAnn(sch, *root_rv, *parsed); + return true; + } + } + } + return false; +} + +void RewriteParallel(const Schedule& sch, int n, Array* loop_rvs) { + LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n}); + sch->Parallel(fused); + for (int i = 0; i < n; ++i) { + loop_rvs->Set(i, fused); + } +} + +void RewriteVectorize(const Schedule& sch, int n, Array* loop_rvs) { + int n_loops = loop_rvs->size(); + LoopRV fused = sch->Fuse({loop_rvs->end() - n, loop_rvs->end()}); + sch->Vectorize(fused); + for (int i = n_loops - n; i < n_loops; ++i) { + loop_rvs->Set(i, fused); + } +} + +void RewriteUnroll(const Schedule& sch, int unroll_explicit, int max_step, const LoopRV& loop) { + if (max_step > 0) { + sch->Annotate(loop, attr::pragma_auto_unroll_max_step, IntImm(DataType::Int(32), max_step)); + sch->Annotate(loop, attr::pragma_unroll_explicit, IntImm(DataType::Int(32), unroll_explicit)); + } +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Schedule; + +class RewriteParallelVectorizeUnrollNode : public PostprocNode { + public: + void InitializeWithTuneContext(const TuneContext& context) final {} + + bool Apply(const Schedule& sch) final { + using tir::BlockRV; + using tir::LoopRV; + tir::ParsedAnnotation parsed_root; + BlockRV root_rv{nullptr}; + while (tir::FindAnnotateRootBlock(sch, &parsed_root, &root_rv)) { + for (BlockRV block_rv : sch->GetChildBlocks(root_rv)) { + Array loop_rvs = sch->GetLoops(block_rv); + if (loop_rvs.empty()) { + continue; + } + tir::ParsedAnnotation parsed = parsed_root; + tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); + // Parallel + if (parsed.num_parallel_loops > 0) { + tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); + } + // Vectorize + if (parsed.num_vectorize_loops > 0) { + tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); + } + // AutoUnroll + if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { + ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); + int unroll_explicit = parsed.unroll_explicit != 
-1; + int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1; + tir::RewriteUnroll(sch, unroll_explicit, max_step, loop_rvs[0]); + } + } + } + return true; + } + + static constexpr const char* _type_key = "meta_schedule.RewriteParallelVectorizeUnroll"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteParallelVectorizeUnrollNode, PostprocNode); +}; + +Postproc Postproc::RewriteParallelVectorizeUnroll() { + ObjectPtr n = + make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteParallelVectorizeUnrollNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteParallelVectorizeUnroll") + .set_body_typed(Postproc::RewriteParallelVectorizeUnroll); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc new file mode 100644 index 0000000000..d1a5492361 --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief The visitor that finds all the reduction block to be decomposed */ +struct ReductionBlockFinder : private StmtVisitor { + public: + /*! 
\brief Find all the reduction blocks that should be decomposed */ + static std::vector> Find(const ScheduleState& self) { + std::vector> results; + for (const auto& kv : self->mod->functions) { + GlobalVar g_var = kv.first; + BaseFunc base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + ReductionBlockFinder finder; + finder(prim_func->body); + for (const BlockNode* block : finder.results_) { + results.emplace_back(self->stmt2ref.at(block), g_var->name_hint); + } + } + } + return results; + } + + private: + void VisitStmt_(const ForNode* loop) final { + runtime::ThreadScope thread_scope = GetThreadScope(loop); + if (IsThreadIdx(thread_scope) || IsBlockIdx(thread_scope)) { + thread_bound_loop_vars_.insert(loop->loop_var.get()); + } + StmtVisitor::VisitStmt_(loop); + } + + void VisitStmt_(const BlockRealizeNode* realize) final { + if (realize->block->init.defined() && AllReductionIterVarAreUnbound(realize)) { + results_.push_back(realize->block.get()); + } + StmtVisitor::VisitStmt_(realize); + } + + bool AllReductionIterVarAreUnbound(const BlockRealizeNode* realize) const { + if (thread_bound_loop_vars_.empty()) { + return true; + } + auto f_find = [this](const VarNode* var) -> bool { return thread_bound_loop_vars_.count(var); }; + const BlockNode* block = realize->block.get(); + int n = block->iter_vars.size(); + for (int i = 0; i < n; ++i) { + IterVar iter_var = block->iter_vars[i]; + PrimExpr binding = realize->iter_values[i]; + if (iter_var->iter_type == tir::kCommReduce) { + if (UsesVar(binding, f_find)) { + return false; + } + } + } + return true; + } + + /*! \brief The results of the collection */ + std::vector results_; + /*! \brief Loop variables that are bound to threads */ + std::unordered_set thread_bound_loop_vars_; +}; + +/*! + * \brief Find the innermost loop that could be decomposed to + * \param block_sref The block to be decomposed + * \return The index of the innermost loop that could be decomposed + */ +int FindDecomposePoint(const StmtSRef& block_sref) { + Array loop_srefs = GetLoops(block_sref); + int n = loop_srefs.size(); + for (int i = 0; i < n; ++i) { + if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) { + return i; + } + } + return -1; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +/*! 
\brief Rewrite reduction block by moving the init block out */ +class RewriteReductionBlockNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteReductionBlockNode, PostprocNode); +}; + +bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { + for (;;) { + std::vector> results = + tir::ReductionBlockFinder::Find(sch->state()); + int rewritten = 0; + for (const auto& kv : results) { + const tir::StmtSRef& block_sref = kv.first; + const String& global_var_name = kv.second; + int decompose_point = tir::FindDecomposePoint(block_sref); + if (decompose_point == -1) { + continue; + } + tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); + Array loop_rvs = sch->GetLoops(block_rv); + sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); + ++rewritten; + } + if (rewritten == 0) { + break; + } + } + return true; +} + +Postproc Postproc::RewriteReductionBlock() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock") + .set_body_typed(Postproc::RewriteReductionBlock); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc new file mode 100644 index 0000000000..2608dce19a --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief The rewrite type for an unbound block */ +enum class BindType : int32_t { + /*! \brief No additional thread binding is needed */ + kNoBind = 0, + /*! \brief Need to bind to blockIdx */ + kBindBlock = 1, + /*! \brief Need to bind to both blockIdx and threadIdx */ + kBindBlockThread = 2, +}; + +/*! 
+ * \brief Check the combination of bindings to be added to the block + * \param block_sref The block to be checked + * \param fuse_first_num The number of loops to be fused + * \return The type of binding to be added to the block + */ +BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) { + Array loops = tir::GetLoops(block_sref); + int n = loops.size(); + if (n == 0) { + return BindType::kNoBind; + } + int i_block_idx = -1; + int i_thread_idx = -1; + int i_multi_child = -1; + for (int i = 0; i < n; ++i) { + const StmtSRef& loop_sref = loops[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + runtime::ThreadScope thread_scope = GetThreadScope(loop); + if (IsBlockIdx(thread_scope)) { + if (i_block_idx == -1) { + i_block_idx = i; + } + } + if (IsThreadIdx(thread_scope)) { + if (i_thread_idx == -1) { + i_thread_idx = i; + } + } + if (!IsSingleStmt(loop->body)) { + if (i_multi_child == -1) { + i_multi_child = i + 1; + } + } + } + if (i_multi_child == -1) { + i_multi_child = n; + } + if (i_block_idx != -1 && i_thread_idx != -1) { + return BindType::kNoBind; + } else if (i_block_idx != -1 && i_thread_idx == -1) { + ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not"; + throw; + } else if (i_block_idx == -1 && i_thread_idx != -1) { + *fuse_first_num = std::min(i_multi_child, i_thread_idx); + return BindType::kBindBlock; + } else { // i_block_idx == -1 && i_thread_idx == -1 + *fuse_first_num = i_multi_child; + return BindType::kBindBlockThread; + } +} + +/*! \brief Find all the blocks that are not bound */ +class UnboundBlockFinder : private StmtVisitor { + public: + static std::vector> Find(const ScheduleState& self) { + UnboundBlockFinder finder(self); + for (const auto& kv : self->mod->functions) { + GlobalVar g_var = kv.first; + BaseFunc base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + finder.global_var_name_ = g_var->name_hint; + finder(Downcast(prim_func->body)->block->body); + } + } + return std::move(finder.blocks_); + } + + private: + void VisitStmt_(const ForNode* loop) final { + runtime::ThreadScope thread_scope = GetThreadScope(loop); + if (IsBlockIdx(thread_scope)) { + ++n_block_idx_; + } else if (IsThreadIdx(thread_scope)) { + ++n_thread_idx_; + } + if (n_block_idx_ == 0 || n_thread_idx_ == 0) { + StmtVisitor::VisitStmt_(loop); + } + if (IsBlockIdx(thread_scope)) { + --n_block_idx_; + } else if (IsThreadIdx(thread_scope)) { + --n_thread_idx_; + } + } + + void VisitStmt_(const BlockNode* block) final { + blocks_.emplace_back(self_->stmt2ref.at(block), global_var_name_); + } + + explicit UnboundBlockFinder(const ScheduleState& self) + : self_{self}, blocks_{}, n_block_idx_{0}, n_thread_idx_{0} {} + + /*! \brief The schedule state */ + const ScheduleState& self_; + /*! \brief The list of unbound blocks */ + std::vector> blocks_; + /*! \brief The number of blockIdx above the current stmt */ + int n_block_idx_; + /*! \brief The number of threadIdx above the current stmt */ + int n_thread_idx_; + /*! \brief The name of the global var */ + String global_var_name_; +}; + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +/*! 
\brief Add thread binding to unbound blocks */ +class RewriteUnboundBlockNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final { + CHECK(context->target.defined()) << "ValueError: target is not defined"; + Optional warp_size = context->target.value()->GetAttr("thread_warp_size"); + CHECK(warp_size.defined()) << "ValueError: missing attribute `thread_warp_size` in the target"; + this->warp_size_ = warp_size.value(); + } + + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final; + + public: + /*! \brief The cached warp size from Target */ + int warp_size_ = -1; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `warp_size_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.RewriteUnboundBlock"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteUnboundBlockNode, PostprocNode); +}; + +bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) { + using tir::BlockRV; + using tir::LoopRV; + using tir::Schedule; + ICHECK_NE(this->warp_size_, -1); + std::vector> unbound_blocks = + tir::UnboundBlockFinder::Find(sch->state()); + for (const auto& kv : unbound_blocks) { + tir::StmtSRef block_sref = kv.first; + String global_var_name = kv.second; + int fuse_first_num = 0; + tir::BindType bind_type = tir::GetBindType(block_sref, &fuse_first_num); + if (bind_type == tir::BindType::kNoBind) { + continue; + } + BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); + Array loop_rvs = sch->GetLoops(block_rv); + LoopRV fused = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + fuse_first_num}); + if (bind_type == tir::BindType::kBindBlock) { + sch->Bind(fused, "blockIdx.x"); + } else if (bind_type == tir::BindType::kBindBlockThread) { + Array splits = sch->Split(fused, {NullOpt, Integer(this->warp_size_)}); + ICHECK_EQ(splits.size(), 2); + sch->Bind(splits[0], "blockIdx.x"); + sch->Bind(splits[1], "threadIdx.x"); + } + } + return true; +} + +Postproc Postproc::RewriteUnboundBlock() { + ObjectPtr n = make_object(); + n->warp_size_ = -1; + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteUnboundBlockNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock") + .set_body_typed(Postproc::RewriteUnboundBlock); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc new file mode 100644 index 0000000000..5a768d705e --- /dev/null +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief Verify the correctness of the generated GPU code. 
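+ * (The helper below only reads an integer attribute such as "max_threads_per_block" from the
+ * Target; the verification itself is done by VerifyGPUCodeNode further down.)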
*/ +Integer Extract(const Target& target, const char* name) { + ICHECK(target.defined()); + if (Optional v = target->GetAttr(name)) { + return v.value(); + } + LOG(FATAL) << "AttributedError: \"" << name << "\" is not defined in the target"; + throw; +} + +/*! \brief Verify the correctness of the generated GPU code. */ +class VerifyGPUCodeNode : public PostprocNode { + public: + Map target_constraints_{nullptr}; + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(context->target.defined()); + Target target = context->target.value(); + this->target_constraints_ = Map{ + {"max_shared_memory_per_block", Extract(target, "shared_memory_per_block")}, + {"max_local_memory_per_block", Extract(target, "registers_per_block")}, + {"max_threads_per_block", Extract(target, "max_threads_per_block")}, + {"max_vthread", Integer(8)}, + {"max_vector_bytes", Integer(16)}}; + } + + bool Verify(const IRModule& mod) const { + for (const auto& kv : mod->functions) { + if (const auto* prim_func = kv.second.as()) { + if (!tir::VerifyGPUCode(GetRef(prim_func), this->target_constraints_)) { + return false; + } + } + } + return true; + } + + bool Apply(const tir::Schedule& sch) final { + IRModule mod = sch->mod(); + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + IRModule lowered{nullptr}; + try { + lowered = LowerPrimFunc(GetRef(prim_func), g_var->name_hint); + } catch (const dmlc::Error& e) { + return false; + } + if (!Verify(lowered)) { + return false; + } + } + } + return true; + } + + static constexpr const char* _type_key = "meta_schedule.VerifyGPUCode"; + TVM_DECLARE_FINAL_OBJECT_INFO(VerifyGPUCodeNode, PostprocNode); +}; + +Postproc Postproc::VerifyGPUCode() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(VerifyGPUCodeNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyGPUCode").set_body_typed(Postproc::VerifyGPUCode); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc new file mode 100644 index 0000000000..ae8fa1f73c --- /dev/null +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief The type of inline to be performed on a specific block */ +enum class InlineType : int32_t { + /*! \brief No inline opportunity */ + kNoInline = 0, + /*! \brief Inline the block into its consumer */ + kInlineIntoConsumer = 1, + /*! \brief Inline the block into its producer */ + kInlineIntoProducer = 2, +}; + +/*! 
\brief The rule that inlines spatial blocks if it satisfies some conditions. */ +class AutoInlineNode : public ScheduleRuleNode { + public: + /*! \brief Checks if the specific block should be inlined */ + inline InlineType CheckInline(const tir::Schedule& sch, const tir::BlockRV& block_rv); + + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final {} + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + InlineType inline_type = CheckInline(sch, block_rv); + if (inline_type == InlineType::kInlineIntoConsumer) { + sch->ComputeInline(block_rv); + } else if (inline_type == InlineType::kInlineIntoProducer) { + sch->ReverseComputeInline(block_rv); + } + return {sch}; + } + + public: + /*! \brief If allows to inline a block into its producer */ + bool into_producer; + /*! \brief If allows to inline a block into its consumer */ + bool into_consumer; + /*! \brief If it only allows to inline into a block generated by cache_read/write */ + bool into_cache_only; + /*! \brief Always inline constant tensors */ + bool inline_const_tensor; + /*! \brief Always disallow if-then-else-like constructs */ + bool disallow_if_then_else; + /*! \brief Always require the read-to-write mapping to be injective to do auto inline */ + bool require_injective; + /*! \brief Always require the read-to-write mapping to be ordered to do auto inline */ + bool require_ordered; + /*! \brief The operators that are disallowed in auto inline */ + Array disallow_op; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("into_producer", &into_producer); + v->Visit("into_consumer", &into_consumer); + v->Visit("into_cache_only", &into_cache_only); + v->Visit("inline_const_tensor", &inline_const_tensor); + v->Visit("disallow_if_then_else", &disallow_if_then_else); + v->Visit("require_injective", &require_injective); + v->Visit("require_ordered", &require_ordered); + v->Visit("disallow_op", &disallow_op); + } + + static constexpr const char* _type_key = "meta_schedule.AutoInline"; + TVM_DECLARE_FINAL_OBJECT_INFO(AutoInlineNode, ScheduleRuleNode); +}; + +inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, + const tir::BlockRV& block_rv) { + using namespace tvm::tir; + StmtSRef block_sref = sch->GetSRef(block_rv); + ScheduleState state = sch->state(); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + BlockRealize realize = GetBlockRealize(state, block_sref); + // Cond 1. The block has only one write buffer + if (block->writes.size() != 1) { + return InlineType::kNoInline; + } + // Cond 2. The block is a spatial block + if (!IsSpatial(block_sref)) { + return InlineType::kNoInline; + } + // Cond 3. For a block that generates a constant tensor, ignore all other conditions + if (inline_const_tensor && block->reads.empty()) { + return InlineType::kInlineIntoConsumer; + } + // Cond 4. The block doesn't contain any disallowed operators + if (!disallow_op.empty() && HasOp(realize, disallow_op)) { + return InlineType::kNoInline; + } + // Cond 5. The block doesn't have any if-then-else-like constructs + if (disallow_if_then_else && HasIfThenElse(realize)) { + return InlineType::kNoInline; + } + // Cond 6. 
The mapping from read indices to write indices are injective and ordered + if (require_injective || require_ordered) { + const BufferRegion& write_region = block->writes[0]; + for (const BufferRegion& read_region : block->reads) { + bool injective, ordered; + constexpr auto _ = std::ignore; + std::tie(/*exists=*/_, /*surjective=*/_, injective, ordered, /*no_const_read=*/_, + /*no_shift_read=*/_) = AnalyzeReadWritePattern(read_region, write_region); + if (require_injective && injective == false) { + return InlineType::kNoInline; + } + if (require_ordered && ordered == false) { + return InlineType::kNoInline; + } + } + } + // Last cond: Check inline into the spatial consumer or the spatial producer + if (into_consumer) { + Array consumer_srefs = GetConsumers(state, block_sref); + if (consumer_srefs.size() == 1 && IsSpatial(consumer_srefs[0])) { + if (!into_cache_only || + tir::GetAnn(consumer_srefs[0], tir::attr::meta_schedule_cache_type).defined()) { + if (CanComputeInline(state, block_sref)) { + return InlineType::kInlineIntoConsumer; + } + } + } + } + if (into_producer) { + Array producer_srefs = GetProducers(state, block_sref); + if (producer_srefs.size() == 1 && IsSpatial(producer_srefs[0])) { + if (!into_cache_only || + tir::GetAnn(producer_srefs[0], tir::attr::meta_schedule_cache_type).defined()) { + if (CanReverseComputeInline(state, block_sref)) { + return InlineType::kInlineIntoProducer; + } + } + } + } + return InlineType::kNoInline; +} + +ScheduleRule ScheduleRule::AutoInline(bool into_producer, // + bool into_consumer, // + bool into_cache_only, // + bool inline_const_tensor, // + bool disallow_if_then_else, // + bool require_injective, // + bool require_ordered, // + Optional> disallow_op) { + ObjectPtr n = make_object(); + n->into_producer = into_producer; + n->into_consumer = into_consumer; + n->into_cache_only = into_cache_only; + n->inline_const_tensor = inline_const_tensor; + n->disallow_if_then_else = disallow_if_then_else; + n->require_injective = require_injective; + n->require_ordered = require_ordered; + n->disallow_op.clear(); + if (disallow_op.defined()) { + Array op_names = disallow_op.value(); + n->disallow_op.reserve(op_names.size()); + for (const String& op_name : op_names) { + n->disallow_op.push_back(Op::Get(op_name)); + } + } + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(AutoInlineNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline") + .set_body_typed(ScheduleRule::AutoInline); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc new file mode 100644 index 0000000000..a74c2e05cf --- /dev/null +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { +/*! + * \brief Get the buffer dimensions for all the read buffers of a block, but marks the reduction + * buffers' dimensions as -1 + * \param block_sref The block to be processed + * \return The buffer dimensions for all the read buffers of a block, except for reduction buffers + * \note The method is not designed for generic analysis and relies on assumptions in the scenario + * of multi-level tiling, so it's intentionally kept inside this file not in the analysis header + */ +std::vector GetReadBufferNDims(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BufferNode* write_buffer = block->writes[0]->buffer.get(); + int n = block->reads.size(); + std::vector results(n, -1); + for (int i = 0; i < n; ++i) { + const BufferNode* read_buffer = block->reads[i]->buffer.get(); + if (read_buffer != write_buffer) { + results[i] = read_buffer->shape.size(); + } + } + return results; +} +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::BlockRV; +using tir::ExprRV; +using tir::IterVarType; +using tir::LoopRV; +using tir::Schedule; + +/*! + * \brief Configuration of data reuse type: + * 0) kNoReuse: no reuse is allowed, then no cache_read/write is performed. + * 1) kMayReuse: reuse is allowed, but no reuse is explored. + * 2) kMustReuse: reuse is allowed and no reuse is not explored. + */ +enum class ReuseType : int32_t { + kNoReuse = 0, + kMayReuse = 1, + kMustReuse = 2, +}; + +/*! + * \brief Converts a string to ReuseType. + * \param str The string to be converted. + * \return The converted ReuseType. + */ +ReuseType Str2ReuseType(const String& str) { + if (str == "no") { + return ReuseType::kNoReuse; + } else if (str == "may") { + return ReuseType::kMayReuse; + } else if (str == "must") { + return ReuseType::kMustReuse; + } else { + LOG(FATAL) << "ValueError: Unknown ReuseType: " << str; + throw; + } +} + +/*! \brief Configuration of data reuse patterns */ +struct ReuseConfig { + /*! \brief Type of data reuse: no-reuse, may-reuse or must-reuse */ + ReuseType req; + /*! \brief Which levels are caching stage inserted at */ + std::vector levels; + /*! \brief The storage scope */ + String scope; + + /*! \brief Default constructor: no data reuse */ + ReuseConfig() : req(ReuseType::kNoReuse) {} + + /*! \brief Construct from a configuration dictionary */ + explicit ReuseConfig(const Map& config) + : req(Str2ReuseType(Downcast(config.at("req")))), + levels(support::AsVector(Downcast>(config.at("levels")))), + scope(Downcast(config.at("scope"))) { + ICHECK_EQ(config.size(), 3); + } +}; + +/*! \brief The state of auto scheduling for the multi-level tiling rule */ +struct State { + /*! \brief The schedule to date */ + Schedule sch; + /*! \brief The block to be tiled */ + BlockRV block_rv; + /*! \brief The write cache */ + Optional write_cache; + /*! \brief Indicating if the write cache is generated by cache_write */ + bool write_cache_is_added; + /*! \brief The loop tiles */ + Array> tiles; + + /*! \brief Default constructor */ + explicit State(Schedule sch, BlockRV block_rv, Optional write_cache = NullOpt, + bool write_cache_is_added = false, Array> tiles = {}) + : sch(sch), + block_rv(block_rv), + write_cache(write_cache), + write_cache_is_added(write_cache_is_added), + tiles(tiles) {} +}; + +/*! 
+ * \brief Helper to apply a sub-rule to a list of auto scheduling states + * \tparam FLambda The type of the sub-rule functor + * \param states The list of states to be applied + * \return The list of states after applying the sub-rule + */ +template +std::vector SubRule(std::vector states, FLambda sub_rule) { + std::vector results; + for (auto&& state : states) { + std::vector next = sub_rule(std::move(state)); + results.insert(results.end(), + std::make_move_iterator(next.begin()), // + std::make_move_iterator(next.end())); + } + return results; +} + +/*! + * \brief The mega rule: multi-level tiling with data reuse + */ +class MultiLevelTilingNode : public ScheduleRuleNode { + public: + // SubRule 1. add write cache + inline std::vector AddWriteReuse(State state) const; + // SubRule 2. tile the loop nest + inline std::vector TileLoopNest(State state) const; + // SubRule 3. add read cache + inline std::vector AddReadReuse(State state) const; + // SubRule 4. fuse write cache + inline std::vector FuseWriteReuse(State state) const; + // Do nothing; Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Entry of the mega rule; Inherited from ScheduleRuleNode + Array Apply(const Schedule& sch, const BlockRV& block_rv) final { + if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { + return {sch}; + } + std::vector states{State(sch, block_rv)}; + states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); + states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); + states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); }); + states = SubRule(std::move(states), [&](State state) { return FuseWriteReuse(state); }); + Array results; + for (auto&& state : states) { + results.push_back(std::move(state.sch)); + } + return results; + } + + public: + /*! + * \brief The tiling structure. Recommended: + * - 'SSRSRS' on CPU + * - 'SSSRRSRS' on GPU + */ + String structure; + /*! \brief For each level of tiles, which thread axis it is bound to */ + Array tile_binds; + /*! \brief The maximum size of the innermost factor */ + int max_innermost_factor; + /*! \brief The length of vector lane in vectorized cooperative fetching */ + int vector_load_max_len; + /*! \brief Data reuse configuration for reading */ + ReuseConfig reuse_read_; + /*! \brief Data reuse configuration for writing */ + ReuseConfig reuse_write_; + /*! \brief The indices of spatial tiles in `structure` */ + std::vector s_indices_; + /*! \brief The indices of reduction tiles in `structure` */ + std::vector r_indices_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("structure", &structure); + v->Visit("tile_binds", &tile_binds); + v->Visit("max_innermost_factor", &max_innermost_factor); + v->Visit("vector_load_max_len", &vector_load_max_len); + // `reuse_read_` is not visited + // `reuse_write_` is not visited + // `s_indices_` is not visited + // `r_indices_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling"; + TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode); +}; + +inline std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { + const ReuseConfig& config = this->reuse_write_; + if (config.req == ReuseType::kNoReuse) { + return {std::move(state)}; + } + // Case 1. If the write cache is already there, we don't need to add another. 
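+  // (i.e. the block's single consumer is already a write-cache block; under kMayReuse we simply
+  //  record it and return, which is what the check right below does.)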
+ if (config.req == ReuseType::kMayReuse) { + Array consumer_rvs = state.sch->GetConsumers(state.block_rv); + if (consumer_rvs.size() == 1 && IsWriteCache(state.sch->GetSRef(consumer_rvs[0]))) { + state.write_cache = consumer_rvs[0]; + state.write_cache_is_added = false; + return {std::move(state)}; + } + } + std::vector results; + results.reserve(2); + // Case 2. No write cache is added + if (config.req == ReuseType::kMayReuse) { + State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv, + /*write_cache=*/NullOpt, + /*write_cache_is_added=*/false); + new_state.sch->Seed(state.sch->ForkSeed()); + results.emplace_back(std::move(new_state)); + } + // Case 3. Add one write cache + BlockRV write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*read_buffer_index=*/0, + /*storage_scope=*/config.scope); + state.write_cache = write_cache; + { + tir::Annotate(state.sch->state(), state.sch->GetSRef(write_cache), // + tir::attr::meta_schedule_cache_type, // + Integer(tir::attr::meta_schedule_cache_type_write)); + } + + state.write_cache_is_added = true; + results.emplace_back(std::move(state)); + return results; +} + +inline std::vector MultiLevelTilingNode::TileLoopNest(State state) const { + Schedule& sch = state.sch; + const BlockRV& block_rv = state.block_rv; + // Step 1. Assuming trivial binding, pair the loops and their iter-var-types + Array loops = sch->GetLoops(block_rv); + std::vector iter_types = GetBlockVarTypes(sch->GetSRef(state.block_rv)); + ICHECK_EQ(loops.size(), iter_types.size()); + // Step 2. For each loop axis, tile it + std::vector> tiles(s_indices_.size() + r_indices_.size()); + for (int i = 0, n = loops.size(); i < n; ++i) { + const std::vector* idx = nullptr; + if (iter_types[i] == IterVarType::kDataPar) { + idx = &s_indices_; + } else if (iter_types[i] == IterVarType::kCommReduce) { + idx = &r_indices_; + } else { + continue; + } + // Do the split + int n_tiles = idx->size(); + LoopRV loop = loops[i]; + Array factors = sch->SamplePerfectTile( + /*loop=*/loop, + /*n=*/n_tiles, + /*max_innermost_factor=*/max_innermost_factor); + Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); + // Put every tile to its slot + for (int j = 0; j < n_tiles; ++j) { + tiles[idx->at(j)].push_back(splits[j]); + } + } + // Step 3. Reorder to organize the tiles + sch->Reorder(support::ConcatArrayList(tiles.begin(), tiles.end())); + // Step 4. 
Bind the tiles to threads + int n_binds = std::min(tile_binds.size(), tiles.size()); + for (int i = 0; i < n_binds; ++i) { + LoopRV fused = sch->Fuse(tiles[i]); + sch->Bind(fused, tile_binds[i]); + tiles[i] = {fused}; + } + state.tiles = Array>{tiles.begin(), tiles.end()}; + return {state}; +} + +inline std::vector MultiLevelTilingNode::AddReadReuse(State state) const { + const ReuseConfig& config = this->reuse_read_; + if (config.req == ReuseType::kNoReuse) { + return {std::move(state)}; + } + ICHECK(config.req != ReuseType::kMayReuse); + const BlockRV& block_rv = state.block_rv; + std::vector results; + results.reserve(config.levels.size()); + for (int level : config.levels) { + Schedule sch = state.sch->Copy(); + sch->Seed(state.sch->ForkSeed()); + const LoopRV& loop_rv = state.tiles[level - 1].back(); + // Enumerate all buffers that are read but not written + std::vector read_buffer_ndims = tir::GetReadBufferNDims(sch->GetSRef(block_rv)); + for (int i = 0, n_reads = read_buffer_ndims.size(); i < n_reads; ++i) { + int buffer_ndim = read_buffer_ndims[i]; + if (buffer_ndim == -1) { + continue; + } + // Do cache_read + BlockRV cache_read_block = sch->CacheRead(block_rv, i, config.scope); + { + tir::Annotate(sch->state(), sch->GetSRef(cache_read_block), // + tir::attr::meta_schedule_cache_type, + Integer(tir::attr::meta_schedule_cache_type_read)); + } + // Insert cache_read block to the proper place + sch->ComputeAt(cache_read_block, loop_rv, true); + // Fuse the iterators of the cache_read + Array buffer_loops = sch->GetLoops(cache_read_block); + LoopRV fused = sch->Fuse(Array{buffer_loops.end() - buffer_ndim, // + buffer_loops.end()}); + // Annotate cooperative fetching + if (vector_load_max_len > 0) { + // cooperative fetch + vectorized loading + // Split into inner and outer, vectorize the inner loop + Array factors = sch->SamplePerfectTile(fused, 2, vector_load_max_len); + // Add cooperative fetching + sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch, factors[1]); + } + } + State new_state = state; + new_state.sch = sch; + results.push_back(std::move(new_state)); + } + return results; +} + +inline std::vector MultiLevelTilingNode::FuseWriteReuse(State state) const { + const ReuseConfig& config = this->reuse_write_; + if (config.req == ReuseType::kNoReuse) { + return {std::move(state)}; + } + // If the only-consumer does not exist, or is not elementwise, then do not do fusion + if (!state.write_cache.defined()) { + return {std::move(state)}; + } + std::vector results; + // Special case. + // Stages added by `cache_write` must be fused at some level, otherwise it has no benefit. 
+ // On the other hand, If the consumer stage is not added by `cache_write`, + // we may choose not to fuse by setting `must_cache_write = False` + if (!state.write_cache_is_added && config.req != ReuseType::kMustReuse) { + results.push_back(state); + } + BlockRV consumer = state.write_cache.value(); + // Enumerate the level of tile to be fused at + for (int level : config.levels) { + Schedule sch = state.sch->Copy(); + sch->Seed(state.sch->ForkSeed()); + const LoopRV& loop_rv = state.tiles[level - 1].back(); + sch->ReverseComputeAt(consumer, loop_rv, true); + State new_state = state; + new_state.sch = sch; + results.push_back(std::move(new_state)); + } + return results; +} + +// Constructor + +ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional> tile_binds, + Optional max_innermost_factor, + Optional vector_load_max_len, + Optional> reuse_read, + Optional> reuse_write) { + ObjectPtr n = make_object(); + n->structure = structure; + n->tile_binds = tile_binds.value_or({}); + n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; + n->vector_load_max_len = vector_load_max_len.value_or(Integer(-1))->value; + n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig(); + n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig(); + for (int i = 0, len = structure.size(); i < len; ++i) { + char c = structure.data()[i]; + if (c == 'S') { + n->s_indices_.push_back(i); + } else if (c == 'R') { + n->r_indices_.push_back(i); + } else { + LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure; + } + } + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTiling") + .set_body_typed(ScheduleRule::MultiLevelTiling); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc new file mode 100644 index 0000000000..b7100be925 --- /dev/null +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +bool IsRootWithNoAnnotation(const Schedule& sch, const BlockRV& block_rv) { + StmtSRef block_sref = sch->GetSRef(block_rv); + if (block_sref->parent != nullptr) { + return false; + } + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + return block->annotations.empty(); +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { + public: + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(context->target.defined()); + if (this->max_jobs_per_core != -1) { + Target target = context->target.value(); + this->max_parallel_extent_ = GetTargetNumCores(target) * max_jobs_per_core; + } + } + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& root_rv) { + if (!tir::IsRootWithNoAnnotation(sch, root_rv)) { + return {sch}; + } + // Parallelization + if (max_jobs_per_core != -1) { + sch->Annotate(root_rv, tir::attr::meta_schedule_parallel, + Integer(this->max_parallel_extent_)); + } + // Vectorization + if (max_vectorize_extent != -1) { + sch->Annotate(root_rv, tir::attr::meta_schedule_vectorize, Integer(max_vectorize_extent)); + } + // Unroll + if (!unroll_max_steps.empty()) { + int n = unroll_max_steps.size(); + double prob = 1.0 / n; + Array probs(n, FloatImm(DataType::Float(64), prob)); + PrimExpr max_step = sch->SampleCategorical(unroll_max_steps, probs); + if (unroll_explicit) { + sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_explicit, max_step); + } else { + sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_implicit, max_step); + } + } + return {sch}; + } + + public: + /*! + * \brief The maximum number of jobs to be launched per CPU core. + * It sets the uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. + * Use -1 to disable parallelism. + */ + int64_t max_jobs_per_core; + /*! + * \brief The maximum extent to be vectorized. It sets the uplimit of the CPU vectorization. + * Use -1 to disable vectorization. + */ + int max_vectorize_extent; + /*! + * \brief brief description The maximum number of unroll steps to be done. + * Use an empty array to disable unroll. + */ + Array unroll_max_steps; + /*! \brief Whether to explicitly unroll the loop, or just add a unroll pragma. */ + bool unroll_explicit; + /*! \brief The number of cores in CPU. 
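+   * (Cached as GetTargetNumCores(target) * max_jobs_per_core by InitializeWithTuneContext;
+   *  it stays -1 until that is called.)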
*/ + int64_t max_parallel_extent_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("max_jobs_per_core", &max_jobs_per_core); + v->Visit("max_vectorize_extent", &max_vectorize_extent); + v->Visit("unroll_max_steps", &unroll_max_steps); + v->Visit("unroll_explicit", &unroll_explicit); + // `max_parallel_extent_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.ParallelizeVectorizeUnroll"; + TVM_DECLARE_FINAL_OBJECT_INFO(ParallelizeVectorizeUnrollNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core, + int max_vectorize_extent, + Array unroll_max_steps, + bool unroll_explicit) { + ObjectPtr n = make_object(); + n->max_jobs_per_core = max_jobs_per_core; + n->max_vectorize_extent = max_vectorize_extent; + n->unroll_max_steps = unroll_max_steps; + n->unroll_explicit = unroll_explicit; + n->max_parallel_extent_ = -1; + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(ParallelizeVectorizeUnrollNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleParallelizeVectorizeUnroll") + .set_body_typed(ScheduleRule::ParallelizeVectorizeUnroll); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc new file mode 100644 index 0000000000..1757f650aa --- /dev/null +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class RandomComputeLocationNode : public ScheduleRuleNode { + public: + bool IsFreeBlock(const tir::Schedule sch, const tir::StmtSRef& block_sref) const { + if (block_sref->parent == nullptr) { + return false; + } + if (!tir::IsSubrootBlock(sch->state(), block_sref)) { + return false; + } + tir::ScheduleState state = sch->state(); + if (!tir::IsCompleteBlock(state, block_sref, + tir::GetScopeRoot(state, block_sref, false, false))) { + return false; + } + Array loop_srefs = tir::GetLoops(block_sref); + for (const tir::StmtSRef& loop_sref : loop_srefs) { + if (!tir::HasSingleChild(loop_sref)) { + return false; + } + } + Array binds = tir::GetBlockRealize(state, block_sref)->iter_values; + for (const PrimExpr& bind : binds) { + if (!bind->IsInstance() && !bind->IsInstance()) { + return false; + } + } + return true; + } + + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final {} + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + tir::StmtSRef block_sref = sch->GetSRef(block_rv); + if (!IsFreeBlock(sch, block_sref)) { + return {sch}; + } + Array consumers = sch->GetConsumers(block_rv); + if (consumers.size() != 1) { + return {sch}; + } + tir::BlockRV consumer = consumers[0]; + // Try to compute `block_rv` at `consumer` + for (;;) { + tir::LoopRV compute_at_loc = sch->SampleComputeLocation(consumer); + try { + sch->ComputeAt(block_rv, compute_at_loc, true); + } catch (const dmlc::Error& e) { + // ComputeAt fails, cleanup the following before re-try: + // 1) trace: instruction & decisions + // 2) sym_tab + sch->trace().value()->Pop(); + sch->RemoveRV(compute_at_loc); + continue; + } + break; + } + return {sch}; + } + + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation"; + TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::RandomComputeLocation() { + ObjectPtr n = make_object(); + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation") + .set_body_typed(ScheduleRule::RandomComputeLocation); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc new file mode 100644 index 0000000000..f80f684daf --- /dev/null +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +ScheduleRule ScheduleRule::PyScheduleRule( + PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyScheduleRuleNode::FApply f_apply, // + PyScheduleRuleNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return ScheduleRule(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyScheduleRuleNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyScheduleRule's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(ScheduleRuleNode); +TVM_REGISTER_NODE_TYPE(PyScheduleRuleNode); + +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInitializeWithTuneContext") + .set_body_method(&ScheduleRuleNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApply") + .set_body_method(&ScheduleRuleNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule") + .set_body_typed(ScheduleRule::PyScheduleRule); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc new file mode 100644 index 0000000000..16c8bae9b4 --- /dev/null +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -0,0 +1,729 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../utils.h" + +#define TVM_META_SCHEDULE_CHECK_PROB_RANGE(p, name) \ + CHECK(0.0 <= (p) && (p) <= 1.0) << "ValueError: name should be within [0, 1], " \ + << "but get `" << #p << " = " << (p) << '\''; + +namespace tvm { +namespace meta_schedule { + +/**************** Data Structure ****************/ + +/*! + * \brief The struct to store schedule, trace and its score. + * \note The trace is available by visiting the schedule's trace method. + */ +struct CachedTrace { + /*! \brief The schedule the trace creates. */ + tir::Schedule sch{nullptr}; + /*! \brief The normalized score, the higher the better. */ + double score; + + /*! \brief Default constructor. */ + CachedTrace() = default; + /*! + * \brief Constructor from Schedule and score. + * \param sch The given Schedule, which can be used to obtain the trace. + * \param score The predicted normalized score, -1.0 if score is not assigned yet. + */ + explicit CachedTrace(const tir::Schedule& sch, double score) : sch(sch), score(score) {} + /*! \brief Reload the operator < for CachedTrace. 
*/
+  friend inline bool operator<(const CachedTrace& lhs, const CachedTrace& rhs) {
+    return lhs.score > rhs.score;
+  }
+};
+
+/*!
+ * \brief A heap with an upper limit on its size. If overflow happens, it evicts the worst item.
+ * \note It maintains a min heap in terms of `CachedTrace::score`. Therefore, when
+ * overflow happens, the element evicted is the one with the minimum `CachedTrace::score`.
+ * As time goes on, the elements in the heap become larger.
+ */
+class SizedHeap {
+ public:
+  struct IRModuleSHash {
+    IRModule mod;
+    size_t shash;
+  };
+
+  struct IRModuleSHashHash {
+    size_t operator()(const IRModuleSHash& hash) const { return hash.shash; }
+  };
+
+  struct IRModuleSHashEqual {
+    bool operator()(const IRModuleSHash& lhs, const IRModuleSHash& rhs) const {
+      return lhs.shash == rhs.shash && StructuralEqual()(lhs.mod, rhs.mod);
+    }
+  };
+  /*!
+   * \brief Constructor
+   * \param size_limit The upper limit of the heap size
+   */
+  explicit SizedHeap(int size_limit) : size_limit(size_limit) { heap.reserve(size_limit); }
+
+  /*!
+   * \brief Push the specific item to the heap if its key does not appear in the heap
+   * \param item The item to be pushed
+   */
+  void Push(const CachedTrace& item) {
+    if (!in_heap.insert(IRModuleSHash{item.sch->mod(), StructuralHash()(item.sch->mod())}).second) {
+      return;
+    }
+    int size = heap.size();
+    if (size < size_limit) {
+      // Heap is not full, just push
+      heap.emplace_back(item);
+      std::push_heap(heap.begin(), heap.end());
+    } else if (item < heap.front()) {
+      // If the item is better than the worst one in the heap, we can safely kick it out
+      std::pop_heap(heap.begin(), heap.end());
+      heap.back() = item;
+      std::push_heap(heap.begin(), heap.end());
+    }
+    // Otherwise, the item is worse than any other element in the heap
+  }
+
+  /*! \brief Upper limit of the heap size */
+  int size_limit;
+  /*! \brief The heap; the worse an item, the closer it is to the top */
+  std::vector<CachedTrace> heap;
+  /*! \brief The traces that are in the heap */
+  std::unordered_set<IRModuleSHash, IRModuleSHashHash, IRModuleSHashEqual> in_heap;
+};
+
+struct PerThreadData {
+  IRModule mod;
+  TRandState rand_state;
+  std::function<int()> trace_sampler;
+  std::function<Optional<Mutator>()> mutator_sampler;
+
+  /*! \brief Default constructor. */
+  PerThreadData() = default;
+  explicit PerThreadData(const IRModule& mod, TRandState* rand_state)
+      : mod(mod), rand_state(ForkSeed(rand_state)) {}
+
+  /*!
+ * \brief Create a sampler function that picks mutators according to the mass function + * \param rand_state The random state for sampling + * \return The sampler created + */ + inline std::function()> MakeMutatorSampler( + double p_mutate, const Map& mutator_probs, + support::LinearCongruentialEngine::TRandState* rand_state) { + std::vector> mutators; + std::vector masses; + mutators.push_back(NullOpt); + masses.push_back(1.0 - p_mutate); + double total_mass_mutator = 0.0; + if (p_mutate > 0) { + for (const auto& kv : mutator_probs) { + const Mutator& mutator = kv.first; + double mass = kv.second->value; + CHECK_GE(mass, 0.0) << "ValueError: Probability of mutator '" << mutator + << "' is ill-formed: " << mass; + total_mass_mutator += mass; + mutators.push_back(kv.first); + masses.push_back(mass * p_mutate); + } + } + // Normalize the sum to 1.0 + if (total_mass_mutator == 0.0) { + masses[0] = 1.0; + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] = 0.0; + } + } else if (total_mass_mutator != 1.0) { + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] /= total_mass_mutator; + } + } + return [idx_sampler = tir::MakeMultinomialSampler(rand_state, masses), + mutators = std::move(mutators)]() -> Optional { + int i = idx_sampler(); + return mutators[i]; + }; + } + + /*! + * \brief Set the value for the trace and mutator samplers per thread. + * \param scores The predicted score for the given samples. + * \param p_mutate The probability of mutation. + * \param mutator_probs The probability of each mutator as a dict. + */ + void Set(const std::vector& scores, double p_mutate, + const Map& mutator_probs) { + trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores); + mutator_sampler = MakeMutatorSampler(p_mutate, mutator_probs, &rand_state); + } +}; + +struct ConcurrentBitmask { + /*! The bit width. */ + static constexpr const int kBitWidth = 64; + /*! \brief The size of the concurrent bitmask. */ + int size; + /*! \brief The bitmasks. */ + std::vector bitmask; + /*! \brief The mutexes, one per kBitWidth(64 here) bitmasks. */ + std::vector mutexes; + + /*! + * \brief Constructor + * \param n The total slots managed by the concurrent bitmask. + */ + explicit ConcurrentBitmask(int n) + : size((n + kBitWidth - 1) / kBitWidth), bitmask(size, 0), mutexes(size) {} + /*! + * \brief Query and mark the given index if not visted before. + * \param x The index to concurrently check if used. If not, mark as used. + * \return Whether the index has been used before. + */ + bool QueryAndMark(int x) { + std::unique_lock lock(mutexes[x / kBitWidth]); + constexpr uint64_t one = 1; + if (bitmask[x / kBitWidth] & (one << (x % kBitWidth))) { + return false; + } else { + bitmask[x / kBitWidth] |= one << (x % kBitWidth); + return true; + } + } +}; + +/**************** Util Functions ****************/ + +/*! + * \brief Assemble measure candidates from the given candidate traces. + * \param traces The picked candidate traces. + * \return The assembled measure candidates. + */ +inline Array AssembleCandidates(const std::vector& picks, + const Array& args_info) { + Array measure_inputs; + measure_inputs.reserve(picks.size()); + for (const CachedTrace& pick : picks) { + measure_inputs.push_back(MeasureCandidate(pick.sch, args_info)); + } + return measure_inputs; +} + +/*! + * \brief Predict the normalized score of each candidate. 
 * \param cached_traces The candidates for prediction
+ * \param tune_context The tuning context
+ * \param cost_model The cost model used for prediction
+ * \param args_info The metadata of the function arguments
+ * \return The normalized scores of the candidates
+ */
+inline std::vector<double> PredictNormalizedScore(const std::vector<CachedTrace>& cached_traces,
+                                                  const TuneContext& tune_context,
+                                                  const CostModel& cost_model,
+                                                  const Array<ArgInfo>& args_info) {
+  ICHECK(cached_traces.size() > 0)
+      << "Candidates given for score prediction cannot be an empty list!";
+  std::vector<double> scores =
+      cost_model->Predict(tune_context, AssembleCandidates(cached_traces, args_info));
+  // Normalize the score
+  // TODO(@junrushao1994): use softmax + temperature to replace simple normalization to [0.0, +oo)
+  for (double& score : scores) {
+    score = std::max(0.0, score);
+  }
+  return scores;
+}
+
+/**************** Evolutionary Search ****************/
+
+// TODO(@zxybazh): Early stopping for small search space, including deduplication.
+/*!
+ * \brief A search strategy that generates measure candidates using evolutionary search.
+ * \note The algorithm:
+ *
+ * Loop until #measured >= total_measures:
+ *   init =
+ *     pick top `k = population * init_measured_ratio ` from measured
+ *     pick `k = population * (1 - init_measured_ratio)` randomly selected from the search space
+ *   best = generate `population` states with the cost model,
+ *     starting from `init`,
+ *     using mutators,
+ *     and return the top-n states during the search,
+ *     where `n = num_measures_per_iter`
+ *   chosen = pick top `k = num_measures_per_iter * (1 - eps_greedy)` from `best`
+ *            pick `k = num_measures_per_iter * eps_greedy ` from `init`
+ *   do the measurement on `chosen` & update the cost model
+ *
+ */
+class EvolutionarySearchNode : public SearchStrategyNode {
+ public:
+  /*! \brief The state of the search strategy. */
+  struct State {
+    /*! \brief The search strategy itself */
+    EvolutionarySearchNode* self;
+    /*! \brief The design spaces. Decisions are not used so traces only. */
+    Array<tir::Trace> design_spaces;
+    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
+    int st;
+    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
+    int ed;
+
+    explicit State(EvolutionarySearchNode* self, Array<tir::Trace> design_spaces)
+        : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {}
+
+    /*!
+     * \brief Pick up the best candidates from the database.
+     * \param num The number of traces to produce.
+     * \return The picked best candidates.
+     */
+    inline std::vector<CachedTrace> PickBestFromDatabase(int num);
+    /*!
+     * \brief Sample the initial population from previously measured results and randomly generated
+     * traces via trace replaying.
+     * \param num The number of traces to produce.
+     * \return The initial population of traces sampled.
+     */
+    inline std::vector<CachedTrace> SampleInitPopulation(int num);
+    /*!
+     * \brief Merge the measured samples and the unmeasured samples into one initial population.
+     * \param measured Measured samples from the database.
+     * \param unmeasured Unmeasured samples from replaying traces from the design space.
+     * \return The merged results, excluding undefined samples.
+     */
+    inline std::vector<CachedTrace> MergeSamples(const std::vector<CachedTrace>& measured,
+                                                 const std::vector<CachedTrace>& unmeasured);
+    /*!
+     * \brief Evolve the initial population using mutators and samplers.
+     * \param inits The initial population of traces sampled.
+     * \param num The number of traces to produce.
+     * \return The evolved traces from the initial population.
+     */
+    inline std::vector<CachedTrace> EvolveWithCostModel(const std::vector<CachedTrace>& inits,
+                                                        int num);
+    /*!
* \brief Pick the final candidates from the given initial population and the best of the evolved ones.
+     * \param inits The initial population of traces sampled.
+     * \param bests The best candidates predicted from evolved traces.
+     * \param num The number of traces to produce.
+     * \return The final picked candidates with a ratio of both.
+     */
+    inline std::vector<CachedTrace> PickWithEpsGreedy(const std::vector<CachedTrace>& inits,
+                                                      const std::vector<CachedTrace>& bests,
+                                                      int num);
+    inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
+    inline void NotifyRunnerResults(const Array<RunnerResult>& results);
+  };
+
+  /*! \brief The tuning context of the evolutionary search strategy. */
+  TuneContext tune_context_{nullptr};
+  /*! \brief The target for the workload. */
+  Target target_{nullptr};
+  /*! \brief The metadata of the function arguments. */
+  Array<ArgInfo> args_info_{nullptr};
+  /*! \brief A Database for selecting useful candidates. */
+  Database database_{nullptr};
+  /*! \brief A cost model helping to explore the search space. */
+  CostModel cost_model_{nullptr};
+  /*! \brief The postprocessors. */
+  Array<Postproc> postprocs_{nullptr};
+  /*! \brief Mutators and their probability mass. */
+  Map<Mutator, FloatImm> mutator_probs_{nullptr};
+  /*! \brief The number of threads to use. To be initialized with TuneContext. */
+  int num_threads_;
+  /*! \brief The random state. To be initialized with TuneContext. */
+  TRandState rand_state_;
+  /*! \brief Per-thread data, including the module to be tuned and the random state. */
+  std::vector<PerThreadData> per_thread_data_;
+  /*! \brief The state of the search strategy. */
+  std::unique_ptr<State> state_ = nullptr;
+  /*! \brief The token registered for the given workload in the database. */
+  Workload token_{nullptr};
+
+  /*** Configuration: global ***/
+  /*! \brief The number of trials per iteration. */
+  int num_trials_per_iter;
+  /*! \brief The number of total trials. */
+  int num_trials_total;
+
+  /*** Configuration: the initial population ***/
+  /*! \brief The population size in the evolutionary search. */
+  int population;
+  /*! \brief The ratio of measured states used in the initial population. */
+  double init_measured_ratio;
+  /*! \brief The maximum number of failures allowed when replaying a trace. */
+  int max_replay_fail_cnt;
+
+  /*** Configuration: evolution ***/
+  /*! \brief The number of iterations performed by the genetic algorithm. */
+  int genetic_algo_iters;
+  /*! \brief The maximum number of retries when evolving a given trace. */
+  int max_evolve_fail_cnt;
+  /*! \brief The probability of performing mutation. */
+  double p_mutate;
+
+  /*** Configuration: pick states for measurement ***/
+  /*! \brief The ratio of measurements allocated to randomly sampled states.
*/ + double eps_greedy; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `tune_context_` is not visited + // `target_` is not visited + // `args_info_` is not visited + // `database` is not visited + // `cost_model` is not visited + // `postprocs` is not visited + // `mutator_probs_` is not visited + // `num_threads` is not visited + // `rand_state_` is not visited + // `per_thread_data_` is not visited + // `state_` is not visited + + /*** Configuration: global ***/ + v->Visit("num_trials_total", &num_trials_total); + v->Visit("num_trials_per_iter", &num_trials_per_iter); + /*** Configuration: the initial population ***/ + v->Visit("population", &population); + v->Visit("init_measured_ratio", &init_measured_ratio); + v->Visit("max_replay_fail_cnt", &max_replay_fail_cnt); + /*** Configuration: evolution ***/ + v->Visit("genetic_algo_iters", &genetic_algo_iters); + v->Visit("max_evolve_fail_cnt", &max_evolve_fail_cnt); + v->Visit("p_mutate", &p_mutate); + /*** Configuration: pick states for measurement ***/ + v->Visit("eps_greedy", &eps_greedy); + } + + static constexpr const char* _type_key = "meta_schedule.EvolutionarySearch"; + TVM_DECLARE_FINAL_OBJECT_INFO(EvolutionarySearchNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + CHECK(tune_context.defined()) << "TuneContext must be defined!"; + CHECK(tune_context->num_threads > 0) << "Number of threads has to be larger than 0."; + CHECK(tune_context->target.defined()) << "Target must be defined!"; + + this->tune_context_ = tune_context; + this->target_ = tune_context->target.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); + this->mutator_probs_ = tune_context->mutator_probs; + this->postprocs_ = tune_context->postprocs; + this->num_threads_ = tune_context->num_threads; + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->token_ = this->database_->CommitWorkload(tune_context->mod.value()); + this->per_thread_data_.reserve(this->num_threads_); + for (int i = 0; i < this->num_threads_; i++) { + this->per_thread_data_.push_back( + PerThreadData(DeepCopyIRModule(tune_context->mod.value()), &this->rand_state_)); + } + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(!design_spaces.empty()); + ICHECK(this->state_ == nullptr); + // Change to traces + Array design_space_traces; + design_space_traces.reserve(design_spaces.size()); + for (const tir::Schedule& space : design_spaces) { + design_space_traces.push_back(space->trace().value()->Simplified(true)); + } + this->state_ = std::make_unique(this, design_space_traces); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(results); + } +}; + +inline std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int num) { + std::vector measured_traces; + measured_traces.reserve(num); + Array top_records = self->database_->GetTopK(self->token_, num); + for (TuningRecord record : top_records) { + measured_traces.push_back(record->trace); + } + int acutal_num = measured_traces.size(); + std::vector results(acutal_num); + auto f_proc_measured = [this, &measured_traces, &results](int thread_id, int trace_id) -> void { + TRandState& rand_state 
= self->per_thread_data_[thread_id].rand_state; + const IRModule& mod = self->per_thread_data_[thread_id].mod; + tir::Trace trace = measured_traces[trace_id]; + if (Optional opt_sch = + meta_schedule::ApplyTrace(mod, trace, &rand_state, self->postprocs_)) { + tir::Schedule sch = opt_sch.value(); + results[trace_id] = CachedTrace(sch, -1.0); + } else { + LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; + throw; + } + }; + support::parallel_for_dynamic(0, acutal_num, self->num_threads_, f_proc_measured); + return results; +} + +inline std::vector EvolutionarySearchNode::State::SampleInitPopulation(int num) { + // Pick unmeasured states + std::vector results(num); + auto f_proc_unmeasured = [this, &results](int thread_id, int trace_id) -> void { + TRandState& rand_state = self->per_thread_data_[thread_id].rand_state; + const IRModule& mod = self->per_thread_data_[thread_id].mod; + CachedTrace& result = results[trace_id]; + for (int fail_ct = 0; fail_ct < self->max_replay_fail_cnt; fail_ct++) { + int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); + tir::Trace trace = design_spaces[design_space_index]; + if (Optional opt_sch = + // replay trace, i.e., remove decisions + ApplyTrace(mod, tir::Trace(trace->insts, {}), &rand_state, self->postprocs_)) { + tir::Schedule sch = opt_sch.value(); + result = CachedTrace(sch, -1.0); + break; + } + } + if (!result.sch.defined()) { + LOG(FATAL) << "Sample-Init-Population failed over the maximum limit!"; + } + }; + support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured); + return results; +} + +inline std::vector EvolutionarySearchNode::State::MergeSamples( + const std::vector& measured, const std::vector& unmeasured) { + ICHECK(measured.size() + unmeasured.size() == self->population) + << "Num of total init samples does not equal to population size!"; + std::vector inits; + inits.reserve(self->population); + inits.insert(inits.end(), measured.begin(), measured.end()); + inits.insert(inits.end(), unmeasured.begin(), unmeasured.end()); + return inits; +} + +std::vector EvolutionarySearchNode::State::EvolveWithCostModel( + const std::vector& inits, int num) { + // The heap to record best schedule, we do not consider schedules that are already measured + // Also we use `in_heap` to make sure items in the heap are de-duplicated + SizedHeap heap(num); + + // Prepare search queues + std::vector sch_curr; + std::vector sch_next; + sch_curr.reserve(self->population); + sch_next.reserve(self->population); + for (const CachedTrace& ctrace : inits) { + sch_curr.push_back(ctrace); + } + // Main loop: (genetic_algo_iters + 1) times + for (int iter = 0;; ++iter) { + // Predict normalized score with the cost model, + std::vector scores = + PredictNormalizedScore(sch_curr, self->tune_context_, self->cost_model_, self->args_info_); + for (int i = 0, n = sch_curr.size(); i < n; ++i) { + CachedTrace& entry = sch_curr[i]; + entry.score = scores[i]; + if (!self->database_->HasWorkload(entry.sch->mod())) { + heap.Push(entry); + } + } + // Discontinue once it reaches end of search + if (iter == self->genetic_algo_iters) { + break; + } + // Set threaded samplers, with probability from predicated normalized throughputs + for (int i = 0; i < self->num_threads_; ++i) { + self->per_thread_data_[i].Set(scores, self->p_mutate, self->mutator_probs_); + } + ConcurrentBitmask cbmask(scores.size()); + // The worker function + auto f_find_candidate = [&cbmask, &sch_curr, &sch_next, this](int thread_id, int trace_id) { + // 
Prepare samplers
+      TRandState& rand_state = self->per_thread_data_[thread_id].rand_state;
+      const IRModule& mod = self->per_thread_data_[thread_id].mod;
+      const std::function<int()>& trace_sampler = self->per_thread_data_[thread_id].trace_sampler;
+      const std::function<Optional<Mutator>()>& mutator_sampler =
+          self->per_thread_data_[thread_id].mutator_sampler;
+      CachedTrace& result = sch_next[trace_id];
+      // Loop until success
+      for (int retry_cnt = 0; retry_cnt < self->max_evolve_fail_cnt; retry_cnt++) {
+        int sampled_trace_id = trace_sampler();
+        const CachedTrace& ctrace = sch_curr[sampled_trace_id];
+        if (Optional<Mutator> opt_mutator = mutator_sampler()) {
+          // Decision: mutate
+          Mutator mutator = opt_mutator.value();
+          if (Optional<tir::Trace> opt_new_trace =
+                  mutator->Apply(ctrace.sch->trace().value(), &rand_state)) {
+            tir::Trace new_trace = opt_new_trace.value();
+            if (Optional<tir::Schedule> opt_sch =
+                    ApplyTrace(mod, new_trace, &rand_state, self->postprocs_)) {
+              // Note that sch's trace is different from new_trace,
+              // because it contains post-processing information
+              result = CachedTrace(opt_sch.value(), -1.0);
+              break;
+            }
+          }
+        } else if (cbmask.QueryAndMark(sampled_trace_id)) {
+          // Decision: do not mutate
+          result = ctrace;
+          break;
+        }
+        // If the retry count exceeds the limit, the result should be just ctrace
+        if (retry_cnt + 1 == self->max_evolve_fail_cnt) {
+          sch_next[trace_id] = ctrace;
+        }
+      }
+    };
+    sch_next.clear();
+    sch_next.resize(self->population);
+    support::parallel_for_dynamic(0, self->population, self->num_threads_, f_find_candidate);
+    sch_curr.clear();
+    sch_curr.swap(sch_next);
+  }
+  // Return the best states from the heap, sorted from higher scores to lower ones
+  std::sort(heap.heap.begin(), heap.heap.end());
+  std::vector<CachedTrace> results;
+  results.reserve(num);
+  for (const CachedTrace& item : heap.heap) {
+    results.push_back(item);
+  }
+
+  constexpr int kNumScoresPerLine = 16;
+  std::ostringstream os;
+  int n = heap.heap.size();
+  for (int st = 0; st < n; st += kNumScoresPerLine) {
+    os << std::endl;
+    int ed = std::min(st + kNumScoresPerLine, n);
+    os << "[" << (st + 1) << " : " << ed << "]:\t";
+    for (int i = st; i < ed; ++i) {
+      if (i != st) {
+        os << " ";
+      }
+      os << std::fixed << std::setprecision(4) << heap.heap[i].score;
+    }
+  }
+  LOG(INFO) << "Scores of the best " << n << " candidates:" << os.str();
+  return results;
+}
+
+std::vector<CachedTrace> EvolutionarySearchNode::State::PickWithEpsGreedy(
+    const std::vector<CachedTrace>& unmeasured, const std::vector<CachedTrace>& bests, int num) {
+  int num_rands = num * self->eps_greedy;
+  int num_bests = num - num_rands;
+  std::vector<int> rands =
+      tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size());
+  std::vector<CachedTrace> results;
+  results.reserve(num);
+  for (int i = 0, i_bests = 0, i_rands = 0; i < num; ++i) {
+    bool has_best = i_bests < static_cast<int>(bests.size());
+    bool has_rand = i_rands < static_cast<int>(rands.size());
+    // Pick a schedule
+    CachedTrace ctrace;
+    // If `bests` are still needed, prefer `bests`
+    if (i < num_bests) {
+      if (has_best) {
+        ctrace = bests[i_bests++];
+      } else if (has_rand) {
+        ctrace = unmeasured[rands[i_rands++]];
+      } else {
+        break;
+      }
+    } else {
+      // Otherwise, prefer `rands`
+      if (has_rand) {
+        ctrace = unmeasured[rands[i_rands++]];
+      } else if (has_best) {
+        ctrace = bests[i_bests++];
+      } else {
+        break;
+      }
+    }
+    results.push_back(ctrace);
+  }
+  return results;
+}
+
+inline Optional<Array<MeasureCandidate>>
+EvolutionarySearchNode::State::GenerateMeasureCandidates() {
+  if (st >= self->num_trials_total) {
+    return NullOpt;
+  }
+  int sample_num = self->num_trials_per_iter;
+  if
(ed > self->num_trials_total) { + sample_num = self->num_trials_total - st; + ed = self->num_trials_total; + } + ICHECK_LT(st, ed); + + std::vector measured = + PickBestFromDatabase(self->population * self->init_measured_ratio); + std::vector unmeasured = SampleInitPopulation(self->population - measured.size()); + std::vector inits = MergeSamples(measured, unmeasured); + std::vector bests = EvolveWithCostModel(inits, sample_num); + std::vector picks = PickWithEpsGreedy(unmeasured, bests, sample_num); + return AssembleCandidates(picks, self->args_info_); +} + +inline void EvolutionarySearchNode::State::NotifyRunnerResults(const Array& results) { + st += results.size(); + ed += results.size(); + // Measure Callbacks done in TaskScheduler +} + +SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, // + int num_trials_total, // + int population, // + int max_replay_fail_cnt, // + double init_measured_ratio, // + int genetic_algo_iters, // + int max_evolve_fail_cnt, // + double p_mutate, // + double eps_greedy, // + Database database, // + CostModel cost_model) { + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + n->population = population; + n->max_replay_fail_cnt = max_replay_fail_cnt; + TVM_META_SCHEDULE_CHECK_PROB_RANGE(init_measured_ratio, "Initial measured ratio"); + n->init_measured_ratio = init_measured_ratio; + n->genetic_algo_iters = genetic_algo_iters; + n->max_evolve_fail_cnt = max_evolve_fail_cnt; + TVM_META_SCHEDULE_CHECK_PROB_RANGE(p_mutate, "Mutation probability"); + n->p_mutate = p_mutate; + TVM_META_SCHEDULE_CHECK_PROB_RANGE(eps_greedy, "Greedy pick probability"); + n->eps_greedy = eps_greedy; + n->database_ = database; + n->cost_model_ = cost_model; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") + .set_body_typed(SearchStrategy::EvolutionarySearch); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc new file mode 100644 index 0000000000..c9bc4e6b44 --- /dev/null +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief A search strategy that generates measure candidates using space generator. */ +class ReplayFuncNode : public SearchStrategyNode { + public: + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + ReplayFuncNode* self; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. 
*/ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {} + + inline Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const Array& results); + }; + + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + + /*! \brief The module to be tuned. */ + IRModule mod_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief The post processors */ + Array postprocs_{nullptr}; + /*! \brief The space generator for measure candidates generation. */ + SpaceGenerator space_generator_{nullptr}; + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The state of the search strategy. */ + std::unique_ptr state_ = nullptr; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("num_trials_total", &num_trials_total); + // `space_generator_` is not visited + // `mod_` is not visited + // `args_info_` is not visited + // `num_threads_` is not visited + // `rand_state_` is not visited + // `state_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.ReplayFunc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->space_generator_ = tune_context->space_generator.value(); + this->mod_ = tune_context->mod.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); + this->postprocs_ = tune_context->postprocs; + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(this->state_ == nullptr); + this->state_ = std::make_unique(this); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(results); + } +}; + +inline Optional> ReplayFuncNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + ed = std::min(ed, self->num_trials_total); + Array result; + for (int i = st; i < ed; i++) { + for (;;) { + Array schs = self->space_generator_->GenerateDesignSpace(self->mod_); + int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size()); + tir::Schedule sch = schs[design_space_index]; + sch->EnterPostproc(); + bool failed = false; + for (const Postproc& proc : self->postprocs_) { + if (!proc->Apply(sch)) { + failed = true; + break; + } + } + if (!failed) { + result.push_back(MeasureCandidate(sch, self->args_info_)); + break; + } + } + } + return result; +} + +inline void ReplayFuncNode::State::NotifyRunnerResults(const Array& results) { + st += self->num_trials_per_iter; + ed += self->num_trials_per_iter; +} + +SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int num_trials_total) { + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(ReplayFuncNode); 
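Editor's note: ReplayFunc above, like ReplayTrace and EvolutionarySearch in this patch, drives tuning through the same `[st, ed)` window: `GenerateMeasureCandidates` hands out at most one batch of trials and returns NullOpt once the total budget is exhausted, and `NotifyRunnerResults` slides the window forward after measurement. The following standalone sketch shows just that bookkeeping in isolation; the type and member names are illustrative (not TVM API), and plain integers stand in for measure candidates.

#include <algorithm>
#include <cstdio>
#include <optional>
#include <vector>

// Minimal stand-in for the [st, ed) window shared by the search strategies.
struct TrialWindow {
  int total, per_iter, st = 0, ed;
  TrialWindow(int total, int per_iter) : total(total), per_iter(per_iter), ed(per_iter) {}

  // Hand out the next batch of trial indices, or nothing once the budget is used up.
  std::optional<std::vector<int>> Generate() {
    if (st >= total) return std::nullopt;  // search budget exhausted
    ed = std::min(ed, total);              // clip the last (possibly partial) batch
    std::vector<int> batch;
    for (int i = st; i < ed; ++i) batch.push_back(i);
    return batch;
  }

  // Called after the runner reports results; advances the window.
  void NotifyResults(int num_results) {
    st += num_results;
    ed += num_results;
  }
};

int main() {
  TrialWindow window(/*total=*/10, /*per_iter=*/4);
  while (auto batch = window.Generate()) {
    std::printf("measuring trials [%d, %d)\n", batch->front(), batch->back() + 1);
    window.NotifyResults(static_cast<int>(batch->size()));
  }
  return 0;
}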
+TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc") + .set_body_typed(SearchStrategy::ReplayFunc); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 1c83aee8c0..d70dc7739a 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include "tvm/tir/schedule/schedule.h" namespace tvm { namespace meta_schedule { @@ -24,20 +25,18 @@ namespace meta_schedule { /*! \brief A search strategy that generates measure candidates using trace and random decisions. */ class ReplayTraceNode : public SearchStrategyNode { public: - using TRandState = support::LinearCongruentialEngine::TRandState; - /*! \brief The state of the search strategy. */ struct State { /*! \brief The search strategy itself */ ReplayTraceNode* self; /*! \brief The design spaces. */ - Array design_spaces; + Array design_spaces; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int st; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int ed; - explicit State(ReplayTraceNode* self, Array design_spaces) + explicit State(ReplayTraceNode* self, Array design_spaces) : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} inline Optional> GenerateMeasureCandidates(); @@ -50,9 +49,11 @@ class ReplayTraceNode : public SearchStrategyNode { int num_trials_total; /*! \brief The module to be tuned. */ - IRModule mod_{nullptr}; + Array per_thread_mod_{nullptr}; /*! \brief The metadata of the function arguments. */ Array args_info_{nullptr}; + /*! \brief The post processors */ + Array postprocs_{nullptr}; /*! \brief The number of threads to use. -1 means using logical cpu number. */ int num_threads_ = -1; /*! \brief The random state. -1 means using random number. 
*/ @@ -63,8 +64,9 @@ class ReplayTraceNode : public SearchStrategyNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("num_trials_per_iter", &num_trials_per_iter); v->Visit("num_trials_total", &num_trials_total); - // `mod_` is not visited + // `per_thread_mod_` is not visited // `args_info_` is not visited + // `postprocs_` is not visited // `num_threads_` is not visited // `rand_state_` is not visited // `state_` is not visited @@ -74,9 +76,16 @@ class ReplayTraceNode : public SearchStrategyNode { TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); void InitializeWithTuneContext(const TuneContext& tune_context) final { - this->mod_ = tune_context->mod.value(); - this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_)); + CHECK(tune_context->num_threads > 0) << "Number of threads has to be larger than 0."; this->num_threads_ = tune_context->num_threads; + + this->per_thread_mod_.reserve(this->num_threads_); + for (int i = 0; i < this->num_threads_; i++) { + this->per_thread_mod_.push_back(DeepCopyIRModule(tune_context->mod.value())); + } + + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); + this->postprocs_ = tune_context->postprocs; this->rand_state_ = ForkSeed(&tune_context->rand_state); this->state_.reset(); } @@ -84,7 +93,12 @@ class ReplayTraceNode : public SearchStrategyNode { void PreTuning(const Array& design_spaces) final { ICHECK(!design_spaces.empty()); ICHECK(this->state_ == nullptr); - this->state_ = std::make_unique(this, design_spaces); + Array design_space_traces; + design_space_traces.reserve(design_spaces.size()); + for (const tir::Schedule& space : design_spaces) { + design_space_traces.push_back(space->trace().value()->Simplified(true)); + } + this->state_ = std::make_unique(this, design_space_traces); } void PostTuning() final { @@ -114,16 +128,16 @@ inline Optional> ReplayTraceNode::State::GenerateMeasure auto f_worker = [this, &per_thread_rand_state, &per_task_result](int thread_id, int task_id) -> void { TRandState& rand_state = per_thread_rand_state[thread_id]; - int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); - tir::Trace trace = design_spaces[design_space_index]->trace().value(); - tir::Trace new_trace = tir::Trace(trace->insts, {}); - tir::Schedule sch = tir::Schedule::Traced( // - self->mod_, // - /*rand_state=*/ForkSeed(&rand_state), // - /*debug_mode=*/0, // - /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); - new_trace->ApplyToSchedule(sch, /*remove_postproc=*/true); - per_task_result.Set(task_id, MeasureCandidate(sch, self->args_info_)); + IRModule mod = self->per_thread_mod_[thread_id]; + for (;;) { + int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); + tir::Trace trace = design_spaces[design_space_index]; + tir::Trace new_trace = tir::Trace(trace->insts, {}); + if (Optional sch = ApplyTrace(mod, new_trace, &rand_state, self->postprocs_)) { + per_task_result.Set(task_id, MeasureCandidate(sch.value(), self->args_info_)); + break; + } + } }; support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker); return per_task_result; @@ -142,7 +156,8 @@ SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int num_tria } TVM_REGISTER_NODE_TYPE(ReplayTraceNode); -TVM_REGISTER_GLOBAL("meta_schedule.ReplayTrace").set_body_typed(SearchStrategy::ReplayTrace); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace") + .set_body_typed(SearchStrategy::ReplayTrace); } // namespace meta_schedule } // 
namespace tvm diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc new file mode 100644 index 0000000000..3f68540781 --- /dev/null +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief Collecting all the non-root blocks */ +class BlockCollector : public tir::StmtVisitor { + public: + static Array Collect(const tir::Schedule& sch) { // + return BlockCollector(sch).Run(); + } + + private: + /*! \brief Entry point */ + Array Run() { + for (const auto& kv : sch_->mod()->functions) { + const GlobalVar& gv = kv.first; // `gv->name_hint` is the name of the function + const BaseFunc& base_func = kv.second; // this can be PrimFunc or relay::Function + if (const auto* func = base_func.as()) { + func_name_ = gv->name_hint; + block_names_.clear(); + blocks_to_collect_.clear(); + VisitStmt(func->body); + for (const String& block_name : blocks_to_collect_) { + results_.push_back(sch_->GetBlock(block_name, func_name_)); + } + } + } + return results_; + } + /*! \brief Constructor */ + explicit BlockCollector(const tir::Schedule& sch) : sch_(sch) {} + /*! \brief Override the Stmt visiting behaviour */ + void VisitStmt_(const tir::BlockNode* block) override { + tir::StmtVisitor::VisitStmt_(block); + CHECK(block_names_.count(block->name_hint) == 0) + << "Duplicated block name " << block->name_hint << " in function " << func_name_ + << " not supported!"; + block_names_.insert(block->name_hint); + blocks_to_collect_.push_back(block->name_hint); + } + + /*! \brief The schedule to be collected */ + const tir::Schedule& sch_; + /*! \brief The set of func name and block name pair */ + std::unordered_set block_names_; + /* \brief The list of blocks to collect in order */ + Array blocks_to_collect_; + /*! \brief Function name & blocks of collection */ + Array results_; + /*! \brief Name of the current PrimFunc */ + String func_name_; +}; + +/*! + * \brief Design Space Generator that generates design spaces by applying schedule rules to blocks + * in post-DFS order. + * */ +class PostOrderApplyNode : public SpaceGeneratorNode { + public: + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The schedule rules to be applied in order. 
*/ + Array sch_rules_{nullptr}; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `rand_state_` is not visited + // `sch_rules_` is not visited + } + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->rand_state_ = ForkSeed(&tune_context->rand_state); + CHECK(tune_context->sch_rules.defined()) + << "ValueError: Schedules rules not given in PostOrderApply!"; + this->sch_rules_ = tune_context->sch_rules; + } + + Array GenerateDesignSpace(const IRModule& mod_) final { + using ScheduleAndUnvisitedBlocks = std::pair>; + tir::Schedule sch = tir::Schedule::Traced( // + /*mod=*/mod_, // + /*rand_state=*/ForkSeed(&this->rand_state_), // + /*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags, // + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); + + std::vector stack; + Array result{sch}; + // Enumerate the schedule rules first because you can + // always concat multiple schedule rules as one + Array all_blocks = BlockCollector::Collect(sch); + for (ScheduleRule sch_rule : sch_rules_) { + for (const tir::Schedule& sch : result) { + stack.emplace_back(sch, all_blocks); + } + result.clear(); + + while (!stack.empty()) { + // get the stack.top() + tir::Schedule sch; + Array blocks; + std::tie(sch, blocks) = stack.back(); + stack.pop_back(); + // if all blocks are visited + if (blocks.empty()) { + result.push_back(sch); + continue; + } + // otherwise, get the last block that is not visited + tir::BlockRV block_rv = blocks.back(); + blocks.pop_back(); + if (sch->HasBlock(block_rv)) { + Array applied = sch_rule->Apply(sch, /*block=*/block_rv); + for (const tir::Schedule& sch : applied) { + stack.emplace_back(sch, blocks); + } + } else { + stack.emplace_back(sch, blocks); + } + } + } + return result; + } + static constexpr const char* _type_key = "meta_schedule.PostOrderApply"; + TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); +}; + +SpaceGenerator SpaceGenerator::PostOrderApply() { + ObjectPtr n = make_object(); + return SpaceGenerator(n); +} + +TVM_REGISTER_NODE_TYPE(PostOrderApplyNode); +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPostOrderApply") + .set_body_typed(SpaceGenerator::PostOrderApply); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index 3ef5026cae..ec1fd5789c 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -52,15 +52,19 @@ class RoundRobinNode final : public TaskSchedulerNode { } }; -TaskScheduler TaskScheduler::RoundRobin(Array tasks, // - Builder builder, // - Runner runner, // - Database database) { +TaskScheduler TaskScheduler::RoundRobin(Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + Optional cost_model, // + Optional> measure_callbacks) { ObjectPtr n = make_object(); n->tasks = tasks; n->builder = builder; n->runner = runner; n->database = database; + n->cost_model = cost_model; + n->measure_callbacks = measure_callbacks.value_or({}); n->task_id = -1; return TaskScheduler(n); } diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 08f2b4f451..d2338d3aee 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. 
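Editor's note: the PostOrderApply space generator added above enumerates design spaces by keeping, for each schedule rule, a stack of (partial schedule, unvisited blocks) pairs, so a rule application that forks one schedule multiplies the number of resulting spaces. Below is a toy standalone sketch of that expansion only: strings stand in for `tir::Schedule`/`BlockRV`, and the rule that always forks into two variants is made up for illustration.

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

using State = std::pair<std::string, std::vector<std::string>>;  // (schedule, unvisited blocks)

// A made-up rule application that forks every schedule into two variants.
std::vector<std::string> ApplyRule(const std::string& sch, const std::string& rule,
                                   const std::string& block) {
  return {sch + "+" + rule + "(" + block + ")#0", sch + "+" + rule + "(" + block + ")#1"};
}

int main() {
  std::vector<std::string> rules = {"ruleA", "ruleB"};
  std::vector<std::string> blocks = {"C", "B", "A"};  // back() is visited first
  std::vector<std::string> result = {"init"};
  for (const std::string& rule : rules) {
    // Seed the stack with every schedule produced by the previous rule.
    std::vector<State> stack;
    for (const std::string& sch : result) stack.emplace_back(sch, blocks);
    result.clear();
    while (!stack.empty()) {
      auto [sch, remaining] = stack.back();  // copy the top entry
      stack.pop_back();
      if (remaining.empty()) {  // all blocks visited under this rule
        result.push_back(sch);
        continue;
      }
      std::string block = remaining.back();
      remaining.pop_back();
      for (const std::string& next : ApplyRule(sch, rule, block)) {
        stack.emplace_back(next, remaining);
      }
    }
  }
  std::printf("generated %zu design spaces\n", result.size());  // 2 rules x 3 blocks x fork(2) => 64
  return 0;
}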
*/ - #include "../utils.h" namespace tvm { @@ -29,9 +28,9 @@ namespace meta_schedule { * \param candidates The measure candidates. * \return An array of the builder results. */ -Array SendToBuilder(const Builder& builder, // - const TuneContext& context, +Array SendToBuilder(const Builder& builder, const TuneContext& context, const Array& candidates) { + LOG(INFO) << "Sending " << candidates.size() << " sample(s) to builder"; Target target = context->target.value(); Array inputs; inputs.reserve(candidates.size()); @@ -45,14 +44,14 @@ Array SendToBuilder(const Builder& builder, // * \brief Send the built measure candidates to runner. * \param runner The runner to send the candidates to. * \param context The tuning context. - * \param candidates The mesure candidates. + * \param candidates The measure candidates. * \param builder_results The builder results. * \return An array of the runner results. */ -Array SendToRunner(const Runner& runner, // - const TuneContext& context, +Array SendToRunner(const Runner& runner, const TuneContext& context, const Array& candidates, const Array& builder_results) { + LOG(INFO) << "Sending " << candidates.size() << " sample(s) to runner"; Target target = context->target.value(); ICHECK_EQ(candidates.size(), builder_results.size()); int n = candidates.size(); @@ -94,54 +93,50 @@ Array SendToRunner(const Runner& runner, // void TaskSchedulerNode::InitializeTask(int task_id) { TuneContext task = this->tasks[task_id]; - // Derive the values. - IRModule mod = task->mod.value(); - SpaceGenerator space = task->space_generator.value(); - SearchStrategy strategy = task->search_strategy.value(); - // Initialize Modules. - space->InitializeWithTuneContext(task); - strategy->InitializeWithTuneContext(task); + LOG(INFO) << "Initializing task " << task_id << ": " << task->task_name << ", mod =\n" + << tir::AsTVMScript(task->mod); + this->tasks[task_id]->Initialize(); } void TaskSchedulerNode::Tune() { for (int i = 0; i < static_cast(this->tasks.size()); i++) { + TuneContext task = tasks[i]; // Check Optional value validity. 
- CHECK(tasks[i]->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; - CHECK(tasks[i]->space_generator.defined()) + CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; + CHECK(task->space_generator.defined()) << "ValueError: Require `context.space_generator`, but it is not defined"; - CHECK(tasks[i]->search_strategy.defined()) + CHECK(task->search_strategy.defined()) << "ValueError: Require `context.search_strategy`, but it is not defined"; - InitializeTask(i); - - tasks[i]->search_strategy.value()->PreTuning( - tasks[i]->space_generator.value()->GenerateDesignSpace(tasks[i]->mod.value())); + task->search_strategy.value()->PreTuning( + task->space_generator.value()->GenerateDesignSpace(task->mod.value())); } int running_tasks = tasks.size(); - while (running_tasks > 0) { - for (int task_id; (task_id = NextTaskId()) != -1;) { - TuneContext task = tasks[task_id]; - ICHECK(!task->is_stopped); - ICHECK(!task->runner_futures.defined()); - SearchStrategy strategy = task->search_strategy.value(); - if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) { - Array builder_results = - SendToBuilder(this->builder, task, task->measure_candidates.value()); - task->runner_futures = - SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results); - } else { - SetTaskStopped(task_id); - --running_tasks; - } + for (int task_id; (task_id = NextTaskId()) != -1;) { + LOG(INFO) << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name; + TuneContext task = tasks[task_id]; + ICHECK(!task->is_stopped); + ICHECK(!task->runner_futures.defined()); + SearchStrategy strategy = task->search_strategy.value(); + if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) { + Array builder_results = + SendToBuilder(this->builder, task, task->measure_candidates.value()); + task->builder_results = builder_results; + task->runner_futures = + SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results); + } else { + SetTaskStopped(task_id); + --running_tasks; + LOG(INFO) << "Task #" << task_id << " has finished. 
Remaining task(s): " << running_tasks; } - int n_tasks = this->tasks.size(); - for (int task_id = 0; task_id < n_tasks; ++task_id) - if (IsTaskRunning(task_id)) { - TuneContext task = tasks[task_id]; - this->JoinRunningTask(task_id); - task->search_strategy.value()->PostTuning(); - } + } + ICHECK_EQ(running_tasks, 0) << "Not all tasks are finished"; + int n_tasks = this->tasks.size(); + for (int task_id = 0; task_id < n_tasks; ++task_id) { + ICHECK(!IsTaskRunning(task_id)) << "Task #" << task_id << " is still running"; + TuneContext task = tasks[task_id]; + task->search_strategy.value()->PostTuning(); } } @@ -176,24 +171,18 @@ void TaskSchedulerNode::JoinRunningTask(int task_id) { results.push_back(future->Result()); } task->search_strategy.value()->NotifyRunnerResults(results); - task->runner_futures = NullOpt; - // Add to database + // Invoke the callbacks ICHECK(task->measure_candidates.defined()); - ICHECK(results.size() == task->measure_candidates.value().size()); - int index = 0; - for (const RunnerResult& result : results) { - if (!result->error_msg.defined() && result->run_secs.defined()) { - Optional trace = task->measure_candidates.value()[index]->sch->trace(); - ICHECK(trace.defined()); - this->database->CommitTuningRecord(TuningRecord( - /*trace=*/trace.value(), - /*run_secs=*/result->run_secs.value(), - /*workload=*/this->database->CommitWorkload(task->mod.value()), - /*target=*/task->target.value(), - /*args_info=*/task->measure_candidates.value()[index]->args_info)); - } - index++; + ICHECK(task->builder_results.defined()); + ICHECK_EQ(results.size(), task->measure_candidates.value().size()); + ICHECK_EQ(results.size(), task->builder_results.value().size()); + for (const MeasureCallback& callback : this->measure_callbacks) { + callback->Apply(GetRef(this), task_id, task->measure_candidates.value(), + task->builder_results.value(), results); } + task->measure_candidates = NullOpt; + task->builder_results = NullOpt; + task->runner_futures = NullOpt; } TaskScheduler TaskScheduler::PyTaskScheduler( @@ -201,6 +190,8 @@ TaskScheduler TaskScheduler::PyTaskScheduler( Builder builder, // Runner runner, // Database database, // + Optional cost_model, // + Optional> measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // @@ -212,6 +203,12 @@ TaskScheduler TaskScheduler::PyTaskScheduler( n->builder = builder; n->runner = runner; n->database = database; + n->cost_model = cost_model; + if (measure_callbacks.defined()) { + n->measure_callbacks = measure_callbacks.value(); + } else { + n->measure_callbacks = {}; + } n->f_tune = f_tune; n->f_initialize_task = f_initialize_task; n->f_set_task_stopped = f_set_task_stopped; diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 9fc9272e33..c06cb9adc8 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -#include #include #include "./utils.h" @@ -24,20 +23,13 @@ namespace tvm { namespace meta_schedule { -/*! - * \brief Constructor function of TuneContext class. - * \param mod The mod to be optimized. - * \param target The target to be optimized for. - * \param space_generator The design space generator. - * \param task_name The name of the tuning task. - * \param rand_state The random state. - * \param num_threads The number of threads to be used. 
- * \param verbose The verbosity level. - */ TuneContext::TuneContext(Optional mod, // Optional target, // Optional space_generator, // Optional search_strategy, // + Optional> sch_rules, // + Optional> postprocs, // + Optional> mutator_probs, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) { @@ -46,9 +38,12 @@ TuneContext::TuneContext(Optional mod, n->target = target; n->space_generator = space_generator; n->search_strategy = search_strategy; - n->task_name = task_name; + n->sch_rules = sch_rules.value_or({}); + n->postprocs = postprocs.value_or({}); + n->mutator_probs = mutator_probs.value_or({}); + n->task_name = task_name.value_or("main"); if (rand_state == -1) { - rand_state = std::random_device()(); + rand_state = support::LinearCongruentialEngine::DeviceRandom(); } support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state); n->num_threads = num_threads; @@ -58,6 +53,26 @@ TuneContext::TuneContext(Optional mod, data_ = std::move(n); } +void TuneContextNode::Initialize() { + if (this->space_generator.defined()) { + this->space_generator.value()->InitializeWithTuneContext(GetRef(this)); + } + if (this->search_strategy.defined()) { + this->search_strategy.value()->InitializeWithTuneContext(GetRef(this)); + } + for (const ScheduleRule& sch_rule : sch_rules) { + sch_rule->InitializeWithTuneContext(GetRef(this)); + } + for (const Postproc& postproc : postprocs) { + postproc->InitializeWithTuneContext(GetRef(this)); + } + if (mutator_probs.defined()) { + for (const auto& kv : mutator_probs) { + kv.first->InitializeWithTuneContext(GetRef(this)); + } + } +} + TVM_REGISTER_NODE_TYPE(TuneContextNode); TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") @@ -65,11 +80,14 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") Optional target, // Optional space_generator, // Optional search_strategy, // + Optional> sch_rules, // + Optional> postprocs, // + Optional> mutator_probs, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) -> TuneContext { - return TuneContext(mod, target, space_generator, search_strategy, task_name, rand_state, - num_threads); + return TuneContext(mod, target, space_generator, search_strategy, sch_rules, postprocs, + mutator_probs, task_name, rand_state, num_threads); }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 83e65a5ced..d3b4450f19 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -20,30 +20,39 @@ #define TVM_META_SCHEDULE_UTILS_H_ #include +#include #include #include +#include #include +#include +#include +#include +#include #include +#include #include #include #include #include -#include -#include #include -#include +#include #include #include #include "../printer/text_printer.h" #include "../support/array.h" #include "../support/base64.h" -#include "../tir/schedule/primitive.h" +#include "../support/utils.h" +#include "../tir/schedule/utils.h" namespace tvm { namespace meta_schedule { +/*! \brief The type of the random state */ +using TRandState = support::LinearCongruentialEngine::TRandState; + /*! * \brief Read lines from a json file. * \param path The path to the json file. @@ -193,7 +202,7 @@ inline support::LinearCongruentialEngine::TRandState ForkSeed( /*! * \brief Fork a random state into another ones, i.e. PRNG splitting. - * The given random state is also mutated. + * The given random state is also mutated. 
 * \param rand_state The random state to be forked + * \param n The number of forks + * \return The forked random states @@ -208,6 +217,92 @@ inline std::vector ForkSeed( return results; } +/*! + * \brief Get a deep copy of an IRModule. + * \param mod The IRModule to make a deep copy. + * \return The deep copy of the IRModule. + */ +inline IRModule DeepCopyIRModule(IRModule mod) { + return Downcast(LoadJSON(SaveJSON(mod))); +} + +/*! + * \brief Concatenate strings + * \param strs The strings to concatenate + * \param delim The delimiter + * \return The concatenated string + */ +inline std::string Concat(const Array& strs, const std::string& delim) { + if (strs.empty()) { + return ""; + } + std::ostringstream os; + os << strs[0]; + for (int i = 1, n = strs.size(); i < n; ++i) { + os << delim << strs[i]; + } + return os.str(); + } + +/*! + * \brief Get the BlockRV from a block StmtSRef + * \param sch The schedule + * \param block_sref The block StmtSRef + * \param global_var_name The global variable name + * \return The BlockRV + */ +inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref, + const String& global_var_name) { + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + return sch->GetBlock(block->name_hint, global_var_name); + } + +/*! + * \brief Get the number of CPU cores + * \param target The target + * \return The number of cores. + */ +inline int GetTargetNumCores(const Target& target) { + int num_cores = target->GetAttr("num-cores").value_or(-1); + if (num_cores == -1) { + static const auto* f_cpu_count = runtime::Registry::Get("meta_schedule.cpu_count"); + ICHECK(f_cpu_count) + << "ValueError: Cannot find the packed function \"meta_schedule.cpu_count\""; + num_cores = (*f_cpu_count)(false); + LOG(FATAL) + << "Target does not have attribute \"num-cores\", physical core number must be " + "defined! For example, on the local machine, the target must be \"llvm -num-cores " + << num_cores << "\""; + } + return num_cores; + } + +/*! 
+ * \brief Apply the trace and postprocessors to an IRModule + * \param mod The IRModule to be applied + * \param trace The trace to apply to the IRModule + * \param rand_state The random seed + * \param postprocs The postprocessors to apply to the IRModule + * \return The schedule created, or NullOpt if any postprocessor fails + */ +inline Optional ApplyTrace(const IRModule& mod, const tir::Trace& trace, + TRandState* rand_state, + const Array& postprocs) { + tir::Schedule sch = + tir::Schedule::Traced(mod, + /*rand_state=*/ForkSeed(rand_state), + /*debug_mode=*/0, + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + trace->ApplyToSchedule(sch, /*remove_postproc=*/true); + sch->EnterPostproc(); + for (const Postproc& proc : postprocs) { + if (!proc->Apply(sch)) { + return NullOpt; + } + } + return sch; +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index ebd667ae2a..2f3b6db6f3 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -409,7 +409,7 @@ class TIRTextPrinter : public StmtFunctor, Doc PrintBody(const Stmt& body, bool indent = true); }; -String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "tir", bool show_meta = false); +String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "T", bool show_meta = false); String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, runtime::TypedPackedFunc annotate); diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index b5780975f4..44dc323186 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -79,7 +79,7 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& void InitContextFunctions(std::function fgetsymbol); /*! - * \brief Type alias for funcion to wrap a TVMBackendPackedCFunc. + * \brief Type alias for function to wrap a TVMBackendPackedCFunc. * \param The function address imported from a module. * \param mptr The module pointer node. * \return Packed function that wraps the invocation of the function at faddr. diff --git a/src/support/array.h b/src/support/array.h index 95b4f58a2e..218150f9db 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -100,6 +100,29 @@ inline Array AsArray(const ShapeTuple& shape) { return result; } +/*! + * \brief Concatenate a list of arrays into a single array + * \tparam T The type of elements in the arrays + * \tparam Iterator The type of the iterator into the list of arrays + * \param begin The begin iterator to the array list + * \param end The end iterator to the array list + * \return The concatenated array + */ +template +inline Array ConcatArrayList(Iterator begin, Iterator end) { + int size = 0; + for (Iterator it = begin; it != end; ++it) { + size += (*it).size(); + } + Array result; + result.reserve(size); + for (Iterator it = begin; it != end; ++it) { + const auto& item = *it; + result.insert(result.end(), item.begin(), item.end()); + } + return result; +} + /********** Implementation details of AsVector **********/ namespace details { diff --git a/src/support/nd_int_set.h b/src/support/nd_int_set.h index ae4a0386d4..46bbd2bceb 100644 --- a/src/support/nd_int_set.h +++ b/src/support/nd_int_set.h @@ -144,6 +144,29 @@ inline NDIntSet NDIntSetEval( return ret; } +/*! + * \brief Output the N-dimensional integer set to a stream. + * \param os The output stream. + * \param nd_int_set The N-dimensional integer set to be output. + * \return The output stream. 
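+ * \note Illustration: a 2-D set covering [0, 4) x [0, 8) prints as `[0:3, 0:7]`; each dimension is rendered as `min:max` with an inclusive upper bound.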
+ */ +inline std::ostream& operator<<(std::ostream& os, const NDIntSet& nd_int_set) { + os << '['; + bool is_first = true; + for (const arith::IntSet& int_set : nd_int_set) { + if (is_first) { + is_first = false; + } else { + os << ", "; + } + PrimExpr min = int_set.min(); + PrimExpr max = int_set.max(); + os << min << ":" << max; + } + os << ']'; + return os; +} + } // namespace support } // namespace tvm diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 9f7bc56b85..2be6f5035a 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -227,6 +227,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mabi") .add_attr_option("system-lib") .add_attr_option("runtime") + .add_attr_option("num-cores") .add_attr_option("link-params", Bool(false)) .add_attr_option("unpacked-api") .add_attr_option("interface-api") diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index dc1ed1c193..d01788e92c 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -92,7 +92,7 @@ class GPUCodeVerifier : public StmtExprVisitor { const auto* extent = op->value.as(); ICHECK(extent); - std::string name = var.get()->name_hint; + std::string name = op->node.as()->thread_tag; // record the number of threads in a block if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z" || name == "vthread") { @@ -151,6 +151,7 @@ class GPUCodeVerifier : public StmtExprVisitor { errors_.push_back(s.str()); } }; + err("threads per block", thread_per_block_, max_threads_per_block_); err("local memory per block", local_memory_per_block_, max_local_memory_per_block_); err("shared memory per block", shared_memory_per_block_, max_shared_memory_per_block_); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 101d80a52e..e02e474dab 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -24,6 +24,7 @@ #include #include #include +#include namespace tvm { namespace tir { @@ -64,6 +65,129 @@ FuncType PrimFuncNode::func_type_annotation() const { TVM_REGISTER_NODE_TYPE(PrimFuncNode); +Array IndexMapNode::Apply(const Array& inputs) const { + CHECK_EQ(inputs.size(), this->src_iters.size()); + int n = inputs.size(); + std::unordered_map var_map; + var_map.reserve(n); + for (int i = 0; i < n; ++i) { + var_map.emplace(this->src_iters[i].get(), inputs[i]); + } + Array results; + results.reserve(this->tgt_iters.size()); + for (PrimExpr result : this->tgt_iters) { + results.push_back(Substitute(std::move(result), var_map)); + } + return results; +} + +IndexMap::IndexMap(Array src_iters, Array tgt_iters) { + ObjectPtr n = make_object(); + n->src_iters = std::move(src_iters); + n->tgt_iters = std::move(tgt_iters); + data_ = std::move(n); +} + +IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc(Array)> func) { + Array src_iters; + src_iters.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + src_iters.push_back(Var("i" + std::to_string(i), DataType::Int(32))); + } + return IndexMap(src_iters, func(src_iters)); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + const auto* n = node.as(); + ICHECK(n); + p->stream << "IndexMap: ("; + for (int i = 0, total = n->src_iters.size(); i < total; ++i) { + if (i != 0) { + p->stream << ", "; + } + p->stream << n->src_iters[i]; + } + p->stream << ") => "; + p->stream << "("; + for (int i = 0, total = n->tgt_iters.size(); i < total; ++i) { + if (i != 0) { + p->stream << ", "; + } + 
p->stream << n->tgt_iters[i]; + } + p->stream << ")"; + }); + +TVM_REGISTER_NODE_TYPE(IndexMapNode); +TVM_REGISTER_GLOBAL("tir.IndexMap") + .set_body_typed([](Array src_iters, Array tgt_iters) { + return IndexMap(src_iters, tgt_iters); + }); +TVM_REGISTER_GLOBAL("tir.IndexMapFromFunc").set_body_typed(IndexMap::FromFunc); +TVM_REGISTER_GLOBAL("tir.IndexMapApply").set_body_method(&IndexMapNode::Apply); + +TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) { + // check the number of func var is equal + CHECK_EQ(desc_func->params.size(), intrin_func->params.size()); + CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size()); + + // check both functions' bodies are directly block + const auto* desc_realize = + Downcast(desc_func->body)->block->body.as(); + const auto* intrin_realize = + Downcast(intrin_func->body)->block->body.as(); + CHECK(desc_realize != nullptr) << "description function's body expect a directly block"; + CHECK(intrin_realize != nullptr) << "intrinsic function's body expect a directly block"; + + const Block& desc_block = desc_realize->block; + const Block& intrin_block = intrin_realize->block; + + // check block var number and iter type + CHECK_EQ(desc_block->iter_vars.size(), intrin_block->iter_vars.size()) + << "Two blocks should have the same number of block vars"; + for (size_t i = 0; i < desc_block->iter_vars.size(); i++) { + const IterVar& desc_var = desc_block->iter_vars[i]; + const IterVar& intrin_var = intrin_block->iter_vars[i]; + CHECK(desc_var->iter_type == intrin_var->iter_type) + << "Block iter_type mismatch between " << desc_var->iter_type << " and " + << intrin_var->iter_type; + } + + auto n = make_object(); + n->description = std::move(desc_func); + n->implementation = std::move(intrin_func); + data_ = std::move(n); +} + +class TensorIntrinManager { + public: + Map reg; + + static TensorIntrinManager* Global() { + static TensorIntrinManager* inst = new TensorIntrinManager(); + return inst; + } +}; + +TensorIntrin TensorIntrin::Register(String name, PrimFunc desc_func, PrimFunc intrin_func) { + TensorIntrinManager* manager = TensorIntrinManager::Global(); + ICHECK_EQ(manager->reg.count(name), 0) + << "ValueError: TensorIntrin '" << name << "' has already been registered"; + TensorIntrin intrin(desc_func, intrin_func); + manager->reg.Set(name, intrin); + return intrin; +} + +TensorIntrin TensorIntrin::Get(String name) { + const TensorIntrinManager* manager = TensorIntrinManager::Global(); + ICHECK_EQ(manager->reg.count(name), 1) + << "ValueError: TensorIntrin '" << name << "' is not registered"; + return manager->reg.at(name); +} + +TVM_REGISTER_NODE_TYPE(TensorIntrinNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { // TODO(tvm-team) redirect to Text printer once we have a good text format. 
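A minimal usage sketch of the `IndexMap` and `TensorIntrin` APIs introduced in this file (illustration only, not part of the patch; it assumes the declarations are exposed through `<tvm/tir/function.h>`, and `desc`/`impl` are hypothetical PrimFuncs that already satisfy the structural checks in the `TensorIntrin` constructor):

#include <tvm/tir/function.h>

namespace example {
using namespace tvm;
using namespace tvm::tir;

void Sketch(PrimFunc desc, PrimFunc impl) {
  // Register the intrinsic once under a global name; re-registering the same
  // name trips the ICHECK in TensorIntrin::Register. Look it up later, e.g.
  // when tensorizing a loop by intrinsic name.
  TensorIntrin::Register("example.intrin", desc, impl);
  TensorIntrin intrin = TensorIntrin::Get("example.intrin");
  (void)intrin;

  // Build a 2-D index map (i0, i1) -> (i1, i0) from a lambda, then apply it
  // to concrete indices; Apply substitutes the inputs into tgt_iters.
  IndexMap transpose = IndexMap::FromFunc(
      /*ndim=*/2, [](Array<Var> idx) -> Array<PrimExpr> { return {idx[1], idx[0]}; });
  Array<PrimExpr> mapped = transpose->Apply({PrimExpr(3), PrimExpr(5)});  // yields {5, 3}
  (void)mapped;
}
}  // namespace example

The name-based registry mirrors how `ConcreteScheduleNode::Tensorize(loop_rv, intrin_name)` later resolves intrinsics via `TensorIntrin::Get` in concrete_schedule.cc.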
@@ -85,5 +209,13 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc") return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }); +TVM_REGISTER_GLOBAL("tir.TensorIntrin") + .set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) { + return TensorIntrin(desc_func, intrin_func); + }); + +TVM_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register); +TVM_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 42e0e00995..bc4b01915a 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -20,6 +20,7 @@ #define TVM_TIR_SCHEDULE_ANALYSIS_H_ #include +#include #include #include @@ -69,6 +70,26 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl */ StmtSRef GetSRefTreeRoot(const StmtSRef& sref); +/*! + * \brief The information of a block scope, including the leaf blocks, + * as well as the loop types (spatial, reduction) for each loop in the scope. + */ +struct ScopeBlockLoopInfo { + /*! \brief A list of the leaf blocks, from left to right */ + std::vector realizes; + /*! \brief The loop vars bound to spatial block iters */ + std::unordered_set spatial_vars; + /*! \brief The loop vars bound to non-spatial block iters */ + std::unordered_set non_spatial_vars; +}; + +/*! + * \brief Inspect the scope of the given sref + * \param scope_block The root block of the scope + * \return The information of the scope + */ +ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block); + /******** Scope ********/ /*! * \brief Checks if scope the specified sref is in is a stage-pipeline and return it @@ -174,6 +195,27 @@ bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref); +/*! + * \brief Check if the block is a data parallel block, i.e. all the block vars are data parallel + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block is a data parallel block + */ +bool IsSpatial(const StmtSRef& block_sref); + +/*! + * \brief Extracts the types of the block vars + * \param block_sref The block to be checked + * \return A vector of types of the block vars + */ +std::vector GetBlockVarTypes(const StmtSRef& block_sref); + +/*! + * \brief Checks if a block could be considered as a "write cache" + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block is a write cache + */ +bool IsWriteCache(const StmtSRef& block_sref); + /******** Binding ********/ /*! * \brief Verifies if the block binding in a specific BlockRealize is an affine binding. @@ -195,6 +237,15 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va */ void CheckAffineBinding(const ScheduleState& self, Block block); +/*! + * \brief Check whether a block has a trivial binding, i.e. each block var is bound to a outer loop, + * from outer to inner. + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block has a trivial binding + */ +bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref); + /*! 
* \brief Extracts the ranges of loop variables in a path of the sref tree * \param low_inclusive The lowest node in the path @@ -266,6 +317,36 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self */ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); +/*! + * \brief Get the IterVarType of the specific loop, according to the blocks it's bound to + * \param loop_sref The loop to be checked + * \return The IterVarType of the specific loop + */ +IterVarType GetLoopIterType(const StmtSRef& loop_sref); + +/*! + * \brief Check whether the loop/block has only one child + * \param loop_or_block_sref The loop/block to be checked + * \return Whether the loop/block has only one child + */ +bool HasSingleChild(const StmtSRef& loop_or_block_sref); + +/*! + * \brief Check if a block is the direct children of the root block + * \param self The TIR schedule class + * \param block_sref The block to be analyzed + * \return A boolean flag indicating if the block is the subroot block + */ +bool IsSubrootBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref); + +/*! + * \brief Collect all the feasible compute locations among the loops above the block + * \param self The TIR schedule class + * \param block_sref The input block + * \return All the feasible compute locations among the loops above the block + */ +Array CollectComputeLocation(const ScheduleState& self, const StmtSRef& block_sref); + /******** Producer-consumer relation ********/ /*! @@ -393,6 +474,88 @@ std::vector> GetReducerGetters() bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs); +/******** Misc ********/ + +/*! + * \brief Given the read/write region, extract the pattern of their index correspondence + * namely, the mapping from read index to the write index. + * \param read_region The read region + * \param write_region The write region + * \return A tuple of booleans, the extracted pattern + * 0) exists: if the pattern is found + * 1) surjective: if the pattern is surjective, i.e. each write index is mapped at least once + * e.g. A[i, j] = B[i, i, j] + * 2) injective: if the pattern is injective, i.e. each write index is mapped at most once. + * e.g. A[i, j] = B[i] + * 3) ordered: if the mapping is ordered + * 4) no_const_read: if there is no constant indexing in the read indices, + * e.g. A[i, j] = B[0, i, j] + * 5) no_shift_read: if there is no constant shift in the read indices, + * e.g. A[i, j] = B[i + 1, j] + */ +std::tuple +AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region); + +/*! + * \brief Checks if the given block has data reuse opportunity and thus multi-level tiling is + * beneficial. + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean indicating whether the block has data reuse opportunity + */ +bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref); + +/*! + * \brief Checks if the given AST contains the specific operators + * \param stmt The AST to be checked + * \param ops The list of operators to be checked + * \return A boolean indicating whether the AST contains the specific operators + */ +bool HasOp(const Stmt& stmt, const Array& ops); + +/*! 
+ * \brief Checks if the given AST contains if-then-else, including + * 1) IfThenElse statement + * 2) Select expression + * 3) The operator `tir.if_then_else` + * 4) Block predicates + */ +bool HasIfThenElse(const Stmt& stmt); + +/*! + * \brief Checks if a block could be successfully computed inline into its consumer + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean indicating whether the block could be successfully computed inline + */ +bool CanComputeInline(const ScheduleState& self, const StmtSRef& block_sref); + +/*! + * \brief Checks if a block could be successfully computed inline into its producer + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean indicating whether the block could be successfully computed inline + */ +bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref); + +/*! + * \brief Provided the access pattern to a buffer, suggest one of the possible layout + * transformation to minimize the locality of the access pattern. + * \param buffer The buffer to be transformed + * \param indices The access pattern to the buffer + * \param loops The loops above the buffer + * \param predicate The predicate of the access + * \param analyzer Arithmetic analyzer + */ +Optional SuggestIndexMap(const Buffer& buffer, const Array& indices, + const Array& loops, const PrimExpr& predicate, + arith::Analyzer* analyzer); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 7e16bc92e4..5ee363d107 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "../utils.h" namespace tvm { @@ -47,6 +49,37 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl /******** Scope ********/ +ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) { + struct Collector : public StmtVisitor { + void VisitStmt_(const BlockRealizeNode* realize) final { + result.realizes.push_back(GetRef(realize)); + const Array& iter_vars = realize->block->iter_vars; + const Array& iter_values = realize->iter_values; + ICHECK_EQ(iter_vars.size(), iter_values.size()); + int n = realize->iter_values.size(); + for (int i = 0; i < n; ++i) { + const IterVar& iter_var = iter_vars[i]; + const PrimExpr& iter_value = iter_values[i]; + std::unordered_set* vars = nullptr; + if (iter_var->iter_type == IterVarType::kDataPar) { + vars = &result.spatial_vars; + } else { + vars = &result.non_spatial_vars; + } + PostOrderVisit(iter_value, [vars](const ObjectRef& obj) { + if (const VarNode* var = obj.as()) { + vars->insert(var); + } + }); + } + } + + ScopeBlockLoopInfo result; + } visitor; + visitor(scope_block->body); + return std::move(visitor.result); +} + StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, // bool require_stage_pipeline, // bool require_subtree_compact_dataflow) { @@ -408,6 +441,43 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, } } +bool IsSpatial(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type != IterVarType::kDataPar) { + return false; + } + } + return true; +} + +std::vector GetBlockVarTypes(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + std::vector results; + results.reserve(block->iter_vars.size()); + for (const IterVar& iter_var : block->iter_vars) { + results.push_back(iter_var->iter_type); + } + return results; +} + +bool IsWriteCache(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block->writes.size() != 1) { + return false; + } + const BufferRegion& write_region = block->writes[0]; + for (const BufferRegion& read_region : block->reads) { + bool exists, surjective, injective, ordered, no_const_read, no_shift_read; + std::tie(exists, surjective, injective, ordered, no_const_read, no_shift_read) = + AnalyzeReadWritePattern(read_region, write_region); + if (!(injective && ordered)) { + return false; + } + } + return true; +} + /******** Binding ********/ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, @@ -455,6 +525,22 @@ void CheckAffineBinding(const ScheduleState& self, Block block) { } } +bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Array loops = GetLoops(block_sref); + Array binds = GetBlockRealize(self, block_sref)->iter_values; + if (loops.size() != binds.size()) { + return false; + } + for (int i = 0, n = loops.size(); i < n; ++i) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]); + if (binds[i].get() != loop->loop_var.get()) { + return false; + } + } + return true; +} + Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, const Optional& high_exclusive, const runtime::StorageScope& extra_relax_scope) { @@ -644,6 +730,95 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } } +IterVarType GetLoopIterType(const StmtSRef& loop_sref) { + const 
ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const Var& loop_var = loop->loop_var; + int n_spatial = 0; + int n_reduce = 0; + int n_other = 0; + auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& obj) -> bool { + if (const auto* realize = obj.as()) { + const BlockNode* block = realize->block.get(); + // Number of block vars and their bindings + ICHECK_EQ(realize->iter_values.size(), block->iter_vars.size()); + int n = realize->iter_values.size(); + for (int i = 0; i < n; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& binding = realize->iter_values[i]; + // Categorize the current block var + int* ref = nullptr; + if (iter_var->iter_type == IterVarType::kDataPar) { + ref = &n_spatial; + } else if (iter_var->iter_type == IterVarType::kCommReduce) { + ref = &n_reduce; + } else { + ref = &n_other; + } + // Visit the binding to see if `loop_var` appears + PostOrderVisit(binding, [&ref, &loop_var](const ObjectRef& obj) -> void { + if (obj.same_as(loop_var)) { + (*ref) += 1; + } + }); + } + return false; + } + return true; + }; + PreOrderVisit(loop->body, f_visit); + if (n_other) { + return IterVarType::kOpaque; + } else if (n_spatial && n_reduce) { + return IterVarType::kOpaque; + } else if (n_reduce) { + return IterVarType::kCommReduce; + } else { + return IterVarType::kDataPar; + } +} + +bool HasSingleChild(const StmtSRef& loop_or_block_sref) { + const StmtNode* body = nullptr; + if (const auto* loop = loop_or_block_sref->StmtAs()) { + body = loop->body.get(); + } else if (const auto* block = loop_or_block_sref->StmtAs()) { + body = block->body.get(); + } else { + LOG(FATAL) << "TypeError: Unable to recognize the type of `loop_or_block_sref`: " + << loop_or_block_sref->stmt->GetTypeKey(); + } + if (body->IsInstance()) { + const auto* seq_stmt = static_cast(body); + return seq_stmt->seq.size() == 1; + } + return true; +} + +bool IsSubrootBlock(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) { + tir::StmtSRef parent_block_sref = GetScopeRoot(self, block_sref, false, false); + return parent_block_sref->parent == nullptr; +} + +Array CollectComputeLocation(const ScheduleState& self, const StmtSRef& block_sref) { + Array loop_srefs = GetLoops(block_sref); + Array result; + result.reserve(loop_srefs.size()); + bool visited_reduce = false; + for (const StmtSRef& loop_sref : loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + IterVarType iter_type = GetLoopIterType(loop_sref); + if (iter_type == IterVarType::kDataPar) { + if (visited_reduce) { + break; + } + } else { + visited_reduce = true; + } + result.push_back(loop_sref); + } + return result; +} + /******** Producer-consumer relation ********/ Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope) { @@ -1343,5 +1518,190 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { return GetRef(p); } +/******** Misc ********/ + +std::tuple +AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region) { + static constexpr const std::tuple kNotExist = { + false, false, false, false, false, false}; + // Step 1. Extract the write indices + int w_dim = write_region->buffer->shape.size(); + std::unordered_map var2idx; + var2idx.reserve(w_dim); + for (int i = 0; i < w_dim; ++i) { + const Range& dom = write_region->region[i]; + if (as_const_int(dom->extent) == nullptr) { + return kNotExist; + } + if (const auto* v = dom->min.as()) { + var2idx.emplace(v, i); + } else { + return kNotExist; + } + } + // Step 2. 
Map each read index to a write index + bool no_const_read = true; + bool no_shift_read = true; + int r_dim = read_region->buffer->shape.size(); + std::vector mapped(r_dim, -1); + for (int i = 0; i < r_dim; ++i) { + const Range& dom = read_region->region[i]; + if (as_const_int(dom->extent) == nullptr) { + return kNotExist; + } + // Case 1. Read index is a constant + if (as_const_int(dom->min) != nullptr) { + no_const_read = false; + continue; + } + // Case 2. Read index cannot be recognized as `var +/- const` + // where `var` is a write index and `const` is an optional constant shift + Optional opt_const = NullOpt; + const VarNode* var = + static_cast(AnalyzeVarWithShift(dom->min, &opt_const).get()); + if (var == nullptr || !var2idx.count(var)) { + return kNotExist; + } + // Case 3. Read index is `var +/- const` + mapped[i] = var2idx.at(var); + if (opt_const.defined()) { + no_shift_read = false; + } + } + // Step 3. Check if the mapping is ordered, and count how many times each var is mapped + std::vector mapped_counter(w_dim, 0); + bool ordered = true; + int last_mapped = -1; + for (int i : mapped) { + if (i != -1) { + ++mapped_counter[i]; + if (last_mapped != -1 && last_mapped > i) { + ordered = false; + } + last_mapped = i; + } + } + // Step 4. Check if the mapping is surjective or injective + // Surjective: each write index is mapped at least once + // Injective: each write index is mapped at most once + bool surjective = true; + bool injective = true; + for (int cnt : mapped_counter) { + if (cnt == 0) { + surjective = false; + } else if (cnt >= 2) { + injective = false; + } + } + return {/*exist=*/true, surjective, injective, ordered, no_const_read, no_shift_read}; +} + +bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) || + !IsTrivialBinding(self, block_sref)) { + return false; + } + const BufferNode* write_buffer = block->writes[0]->buffer.get(); + // Step 1. Sort out spatial block variables + std::vector spatial_block_vars; + spatial_block_vars.reserve(block->iter_vars.size()); + for (const IterVar& block_var : block->iter_vars) { + if (block_var->iter_type == IterVarType::kDataPar) { + spatial_block_vars.push_back(block_var->var.get()); + } + } + // Step 2. Enumerate each read region, check the number of block vars that are not used + // to index the read region + int total_unused_block_vars = 0; + std::unordered_set read_buffers; + read_buffers.reserve(block->reads.size()); + for (const BufferRegion& buffer_region : block->reads) { + const BufferNode* buffer = buffer_region->buffer.get(); + const Array& regions = buffer_region->region; + // Step 2.1. Duplication of read buffers are not allowed + if (read_buffers.insert(buffer).second == false) { + return false; + } + // Step 2.2. Skip the reduction buffer + if (buffer == write_buffer) { + continue; + } + // Step 2.3. Collect the block vars that are used to index the read region + std::unordered_set vars; + for (const Range& range : regions) { + if (as_const_int(range->extent) == nullptr) { + return false; + } + for (const Var& var : UndefinedVars(range->min)) { + vars.insert(var.get()); + } + } + // Step 2.4. 
Check if the block vars are not used to index the read region + int n_unused_block_vars = 0; + for (const VarNode* block_var : spatial_block_vars) { + if (vars.count(block_var) == 0) { + ++n_unused_block_vars; + } + } + total_unused_block_vars += n_unused_block_vars; + } + return total_unused_block_vars >= 1; +} + +bool HasOp(const Stmt& stmt, const Array& ops) { + std::unordered_set op_set; + op_set.reserve(ops.size()); + for (const Op& op : ops) { + op_set.insert(op.operator->()); + } + bool found = false; + PreOrderVisit(stmt, [&found, &op_set](const ObjectRef& obj) -> bool { + if (found) { + return false; + } + if (const auto* call = obj.as()) { + if (op_set.count(call->op.operator->())) { + found = true; + } + } + return !found; + }); + return found; +} + +bool HasIfThenElse(const Stmt& stmt) { + bool has_branch = false; + auto f_visit = [&has_branch](const ObjectRef& obj) -> bool { + if (has_branch) { + // stop visiting + return false; + } + if (const auto* realize = obj.as()) { + // Case 1: BlockRealize + if (!is_one(realize->predicate)) { + has_branch = true; + } + } else if (obj->IsInstance() || obj->IsInstance()) { + // Case 2: IfThenElse / Select + has_branch = true; + } else if (const auto* call = obj.as()) { + // Case 3: Call + static const Op& op_if_then_else = Op::Get("tir.if_then_else"); + if (call->op.same_as(op_if_then_else)) { + has_branch = true; + } + } + return !has_branch; + }; + PreOrderVisit(stmt, f_visit); + return has_branch; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc new file mode 100644 index 0000000000..993557f8be --- /dev/null +++ b/src/tir/schedule/analysis/layout.cc @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Calculate the strides of the buffer + * \param buffer The buffer + * \return The strides + */ +Array GetStrides(const Buffer& buffer) { + if (!buffer->strides.empty()) { + ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); + return buffer->strides; + } + int ndim = buffer->shape.size(); + if (ndim == 0) { + return {}; + } + Array strides(ndim, PrimExpr{nullptr}); + PrimExpr stride = make_const(buffer->DefaultIndexType(), 1); + for (int i = ndim - 1; i >= 0; --i) { + strides.Set(i, stride); + stride = stride * buffer->shape[i]; + } + return strides; +} + +/*! + * \brief Auxiliary class that collects the IterSplitExpr in the indexing pattern + * to help decision making in layout transformation + */ +class SplitExprCollector { + public: + /*! + * \brief The corresponding IterSplitExpr, simplified for our case + * The pattern is `source // lower_factor % extent * scale` + */ + struct SplitExpr { + /*! 
\brief The source variable */ + Var source; + /*! \brief The lower factor of the split expression */ + int64_t lower_factor; + /*! \brief The extent of the split expression */ + int64_t extent; + }; + + /*! + * \brief Collect the split expressions in the indexing pattern + * \param index The indexing pattern + * \param input_iters The input iterators' domain + * \param predicate The predicate of the affine map + * \param require_bijective Whether the affine map is required to be bijective + * \param analyzer The analyzer + * \return The collected split expressions + */ + static std::vector Collect(const PrimExpr& index, + const Map& input_iters, // + const PrimExpr& predicate, // + bool require_bijective, // + arith::Analyzer* analyzer) { + Array iter_sum_exprs = arith::DetectIterMap( + {analyzer->Simplify(index)}, input_iters, predicate, require_bijective, analyzer); + if (iter_sum_exprs.empty()) { + return {}; + } + ICHECK_EQ(iter_sum_exprs.size(), 1); + if (iter_sum_exprs[0]->args.size() == 0) { + return {}; + } + SplitExprCollector collector; + collector.Visit(iter_sum_exprs[0]); + if (collector.failed_) { + return {}; + } + return std::move(collector.exprs_); + } + + private: + void Visit(const arith::IterSplitExpr& expr) { + if (const auto* var = expr->source->source.as()) { + const int64_t* lower_factor = as_const_int(expr->lower_factor); + const int64_t* extent = as_const_int(expr->extent); + if (lower_factor == nullptr || extent == nullptr) { + failed_ = true; + return; + } + exprs_.push_back(SplitExpr{GetRef(var), *lower_factor, *extent}); + } else if (const auto* iter_sum_expr = expr->source->source.as()) { + Visit(GetRef(iter_sum_expr)); + } else { + ICHECK(false) << "Unexpected type: " << expr->source->source->GetTypeKey(); + } + } + + void Visit(const arith::IterSumExpr& expr) { + for (const arith::IterSplitExpr& arg : expr->args) { + Visit(arg); + } + } + + /*! \brief Whether the analysis failed */ + bool failed_ = false; + /*! \brief The collected split expressions */ + std::vector exprs_; +}; + +Optional SuggestIndexMap(const Buffer& buffer, const Array& indices, + const Array& loops, const PrimExpr& predicate, + arith::Analyzer* analyzer) { + int ndim = buffer->shape.size(); + int n_loops = loops.size(); + // Step 1. Collect the domains and indices of loop variables + Map input_iters; + std::unordered_map var2id; + var2id.reserve(n_loops); + for (int i = 0; i < n_loops; ++i) { + const For& loop = loops[i]; + input_iters.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + var2id.emplace(loop->loop_var.get(), i); + } + // Step 2. Calculate a functor that flattens a multi-dimensional index + auto f_flatten_index = [ndim, strides = GetStrides(buffer), dtype = buffer->DefaultIndexType()]( + const Array& indices) -> PrimExpr { + PrimExpr flatten_index = make_const(dtype, 0); + for (int i = 0; i < ndim; ++i) { + flatten_index = flatten_index + strides[i] * indices[i]; + } + return flatten_index; + }; + // Step 3. Detect the IterSplitExpr of the indexing pattern + std::vector split_exprs = SplitExprCollector::Collect( + /*index=*/f_flatten_index(indices), input_iters, predicate, + /*require_bijective=*/false, analyzer); + if (split_exprs.empty()) { + return NullOpt; + } + // Step 4. 
Sort the order of the split expressions + std::vector order(split_exprs.size(), 0); + std::generate(order.begin(), order.end(), [n = 0]() mutable { return n++; }); + std::sort(order.begin(), order.end(), [&split_exprs, &var2id](int _a, int _b) -> bool { + const SplitExprCollector::SplitExpr& a = split_exprs[_a]; + const SplitExprCollector::SplitExpr& b = split_exprs[_b]; + int a_var_id = var2id.at(a.source.get()); + int b_var_id = var2id.at(b.source.get()); + if (a_var_id != b_var_id) { + return a_var_id < b_var_id; + } + return a.lower_factor > b.lower_factor; + }); + // Step 5. Create the indexing mapping + auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index), // + split_exprs = std::move(split_exprs), // + order = std::move(order), // + shape = buffer->shape, // + analyzer // + ](Array indices) -> Array { + ICHECK_EQ(indices.size(), shape.size()); + for (int i = 0, n = indices.size(); i < n; ++i) { + analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i])); + } + PrimExpr index = f_flatten_index({indices.begin(), indices.end()}); + int ndim = split_exprs.size(); + // Step 5.1. Split the flattened index according to `split_exprs` + std::vector split; + split.reserve(ndim); + for (int i = ndim - 1; i >= 0; --i) { + index = analyzer->Simplify(index); + int64_t extent = split_exprs[i].extent; + split.push_back(analyzer->Simplify(floormod(index, extent))); + index = floordiv(index, extent); + } + std::reverse(split.begin(), split.end()); + // Step 5.2. Reorder the indexing pattern according to `order` + Array results; + results.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + results.push_back(split[order[i]]); + } + return results; + }; + return IndexMap::FromFunc(ndim, f_alter_layout); +} + +TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap") + .set_body_typed([](Buffer buffer, Array indices, Array loops, + PrimExpr predicate) { + arith::Analyzer analyzer; + return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer); + }); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 4db4cd4ba1..a8bf31e9d1 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -18,8 +18,6 @@ */ #include "./concrete_schedule.h" -#include - namespace tvm { namespace tir { @@ -30,7 +28,7 @@ Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRa n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); - support::LinearCongruentialEngine(&n->rand_state_).Seed(seed); + n->Seed(seed); return Schedule(std::move(n)); } @@ -214,7 +212,7 @@ Schedule ConcreteScheduleNode::Copy() const { void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) { if (seed == -1) { - seed = std::random_device()(); + seed = support::LinearCongruentialEngine::DeviceRandom(); } support::LinearCongruentialEngine(&rand_state_).Seed(seed); } @@ -242,6 +240,15 @@ Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int throw; } +LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv, + Optional decision) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV( + tir::SampleComputeLocation(state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); + TVM_TIR_SCHEDULE_END("sample-compute-location", this->error_render_level_); + throw; +} + /******** Schedule: Get blocks & loops ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& 
func_name) { @@ -477,6 +484,30 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff return CreateRV(result); } +/******** Schedule: Data movement ********/ + +BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int read_buffer_index, const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::ReadAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), read_buffer_index, + storage_scope); + TVM_TIR_SCHEDULE_END("read-at", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + +BlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int write_buffer_index, const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), write_buffer_index, + storage_scope); + TVM_TIR_SCHEDULE_END("write-at", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + /******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -562,7 +593,81 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { } /******** Schedule: Blockize & Tensorize ********/ -/******** Schedule: Annotation ********/ +BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::Blockize(state_, this->GetSRef(loop_rv)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_); + return CreateRV(result); +} + +void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Tensorize(state_, this->GetSRef(loop_rv), intrin); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); +} + +void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin_name) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin_name)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); +} + +void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, + const ObjectRef& ann_val) { + TVM_TIR_SCHEDULE_BEGIN(); + if (const auto* str = ann_val.as()) { + tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, GetRef(str)); + } else if (const auto* expr = ann_val.as()) { + ICHECK(!ann_val->IsInstance()) + << "TypeError: runtime::String is expected, but gets StringImm"; + tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->Get(GetRef(expr))); + } else { + LOG(FATAL) + << "TypeError: Only strings, integers, floats and ExprRVs are supported for now, but gets: " + << ann_val->GetTypeKey(); + throw; + } + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_); +} + +void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_); +} + +void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key, + const ObjectRef& ann_val) { + TVM_TIR_SCHEDULE_BEGIN(); + if (const auto* str = ann_val.as()) { + tir::Annotate(state_, this->GetSRef(block_rv), ann_key, 
GetRef(str)); + } else if (const auto* expr = ann_val.as()) { + ICHECK(!ann_val->IsInstance()) + << "TypeError: runtime::String is expected, but gets StringImm"; + tir::Annotate(state_, this->GetSRef(block_rv), ann_key, this->Get(GetRef(expr))); + } else { + LOG(FATAL) + << "TypeError: Only strings, integers, floats and ExprRVs are supported for now, but gets: " + << ann_val->GetTypeKey(); + throw; + } + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_); +} + +void ConcreteScheduleNode::Unannotate(const BlockRV& loop_rv, const String& ann_key) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_); +} + /******** Schedule: Misc ********/ } // namespace tir diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 035c16f506..b464e32fde 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -72,6 +72,7 @@ class ConcreteScheduleNode : public ScheduleNode { inline PrimExpr Get(const ExprRV& expr_rv) const final; inline StmtSRef GetSRef(const BlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; + inline bool HasBlock(const BlockRV& block_rv) const final; inline Array GetSRefs(const Array& rvs) const; inline Array GetSRefs(const Array& rvs) const; void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } @@ -85,6 +86,8 @@ class ConcreteScheduleNode : public ScheduleNode { Optional decision = NullOpt) override; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) override; + LoopRV SampleComputeLocation(const BlockRV& block_rv, + Optional decision = NullOpt) override; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; @@ -106,6 +109,11 @@ class ConcreteScheduleNode : public ScheduleNode { const String& storage_scope) override; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) override; + /******** Schedule: Data movement ********/ + BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) override; + BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) override; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -119,7 +127,15 @@ class ConcreteScheduleNode : public ScheduleNode { void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) override; /******** Schedule: Blockize & Tensorize ********/ + BlockRV Blockize(const LoopRV& loop_rv) override; + void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) override; + void Tensorize(const LoopRV& loop_rv, const String& intrin_name) override; + /******** Schedule: Annotation ********/ + void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; + void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; + void Annotate(const BlockRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; + void Unannotate(const 
BlockRV& loop_rv, const String& ann_key) override; /******** Schedule: Misc ********/ void EnterPostproc() override {} @@ -192,6 +208,19 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { return this->analyzer_->Simplify(transformed); } +inline bool ConcreteScheduleNode::HasBlock(const BlockRV& block_rv) const { + auto it = this->symbol_table_.find(block_rv); + if (it == this->symbol_table_.end()) { + return false; + } + const ObjectRef& obj = (*it).second; + const auto* sref = obj.as(); + if (sref == nullptr || sref->stmt == nullptr) { + return false; + } + return true; +} + inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { auto it = this->symbol_table_.find(block_rv); if (it == this->symbol_table_.end()) { diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc index af721767c3..cedba4b960 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/tir/schedule/instruction.cc @@ -21,6 +21,11 @@ namespace tvm { namespace tir { +bool InstructionKindNode::IsPostproc() const { + static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); + return this == inst_enter_postproc.get(); +} + Instruction::Instruction(InstructionKind kind, Array inputs, Array attrs, Array outputs) { ObjectPtr n = make_object(); diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 95d636467a..f842f75763 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -43,7 +43,7 @@ namespace tir { * * // Convertible to `InstructionKindNode::FInstructionApply` * static Array ApplyToSchedule( - * const tir::Schedule& sch, + * const Schedule& sch, * const Array& inputs, * const Array& attrs, * const Optional& decision); diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index cc7e44d4df..0cf33b2fc3 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -22,20 +22,40 @@ #include #include +#include #include namespace tvm { namespace tir { +/*! + * \brief Create a sampling function that does multinomial sampling. + * \param rand_state The random state. + * \param weights The weights for multinomial sampling. + * \return The multinomial sampling function. + */ +TVM_DLL std::function MakeMultinomialSampler( + support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights); + /******** Schedule: Sampling ********/ /*! * \brief Sample a random integer from a given range. + * \param rand_state The pointer to schedule's random state * \param min_inclusive The minimum value of the range, inclusive. * \param max_exclusive The maximum value of the range, exclusive. * \return The random integer sampled in the given range. */ TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, int32_t max_exclusive); +/*! + * \brief Sample k random integers from given range without replacement, i.e, no duplication. + * \param rand_state The pointer to schedule's random state + * \param n The range is defined as 0 to n-1. + * \param k The total number of samples. + * \return The randomly selected samples from the n candidates. + */ +std::vector SampleWithoutReplacement( + support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k); /*! * \brief Sample once category from candidates according to the probability weights. 
* \param rand_state The pointer to schedule's random state @@ -47,6 +67,14 @@ TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_st TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision); +/*! + * \brief Create a sampling function that does multinomial sampling. + * \param rand_state The random state. + * \param weights The weights for multinomial sampling. + * \return The multinomial sampling function. + */ +TVM_DLL std::function MakeMultinomialSampler( + support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights); /*! * \brief Sample the factors to perfect tile a specific loop * \param rand_state The random state @@ -81,6 +109,16 @@ TVM_DLL std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor, Optional>* decision); +/*! + * \brief Sample a compute-at location on a BlockRV so that its producer can compute at that loop + * \param self The schedule state + * \param rand_state The random state + * \param block_rv The consumer block to be computed at + * \return The sampled loop to be computed at + */ +TVM_DLL tir::StmtSRef SampleComputeLocation( + tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, + const tir::StmtSRef& block_sref, Optional* decision); /******** Schedule: Get blocks & loops ********/ /*! @@ -224,6 +262,15 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r */ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope); + +/******** Schedule: Data movement ********/ + +TVM_DLL StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope); + +TVM_DLL StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope); + /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the @@ -339,7 +386,28 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu int axis, int factor, int offset); /******** Schedule: Blockize & Tensorize ********/ + +TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref); +TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& loop_sref, + const TensorIntrin& intrinsic); + /******** Schedule: Annotation ********/ +/*! + * \brief Annotate a block/loop with a key value pair + * \param self The state of the schedule + * \param sref The block/loop sref to be annotated + * \param ann_key The annotation key + * \param ann_val The annotation value + */ +TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, + const ObjectRef& ann_val); +/*! 
+ * \brief Unannotate a block/loop's annotation with key ann_key + * \param self The state of the schedule + * \param sref The block/loop to be unannotated + * \param ann_key The annotation key + */ +TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key); /******** Schedule: Misc ********/ } // namespace tir diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc new file mode 100644 index 0000000000..09b7a47e8e --- /dev/null +++ b/src/tir/schedule/primitive/annotate.cc @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, + const ObjectRef& ann_val) { + // Extract annotation + const Map* annotations = nullptr; + if (const auto* loop = sref->StmtAs()) { + annotations = &loop->annotations; + } else if (const auto* block = sref->StmtAs()) { + annotations = &block->annotations; + } else { + LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + } + // Check if the annotation already exists + if (annotations->find(ann_key) != annotations->end()) { + return; + } + // Add the new annotation + Map new_ann(*annotations); + new_ann.Set(ann_key, ann_val); + // Create the new stmt + if (const auto* loop = sref->StmtAs()) { + ObjectPtr n = make_object(*loop); + n->annotations = std::move(new_ann); + self->Replace(sref, For(n), {}); + } else if (const auto* block = sref->StmtAs()) { + ObjectPtr n = make_object(*block); + n->annotations = std::move(new_ann); + Block p(n); + self->Replace(sref, p, {{GetRef(block), p}}); + } else { + LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + throw; + } +} + +void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key) { + // Extract annotation + const Map* annotations = nullptr; + if (const auto* loop = sref->StmtAs()) { + annotations = &loop->annotations; + } else if (const auto* block = sref->StmtAs()) { + annotations = &block->annotations; + } else { + LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + } + // Remove the annotation + ICHECK(annotations->find(ann_key) != annotations->end()) + << "IndexError: Cannot find annotation key: " << ann_key; + Map new_ann(*annotations); + new_ann.erase(ann_key); + // Create the new stmt + if (const auto* loop = sref->StmtAs()) { + ObjectPtr n = make_object(*loop); + n->annotations = std::move(new_ann); + self->Replace(sref, For(n), {}); + } else if (const auto* block = sref->StmtAs()) { + ObjectPtr n = make_object(*block); + n->annotations = std::move(new_ann); + Block p(n); + self->Replace(sref, p, {{GetRef(block), p}}); + } else { + 
LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + throw; + } +} + +struct AnnotateTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Annotate"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ObjectRef ann_val, + String ann_key) { + if (const auto* block = block_or_loop_rv.as()) { + return sch->Annotate(GetRef(block), ann_key, ann_val); + } + if (const auto* loop = block_or_loop_rv.as()) { + return sch->Annotate(GetRef(loop), ann_key, ann_val); + } + LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); + throw; + } + + static String UnpackedAsPython(Array outputs, ObjectRef block_or_loop_rv, + ObjectRef ann_val, String ann_key) { + PythonAPICall py("annotate"); + py.Input("block_or_loop", block_or_loop_rv); + py.Input("ann_key", ann_key); + if (const auto* int_imm = ann_val.as()) { + py.Input("ann_val", std::to_string(int_imm->value)); + } else if (const auto* str_imm = ann_val.as()) { + py.Input("ann_val", GetRef(str_imm)); + } else if (const auto* expr = ann_val.as()) { + std::ostringstream os; + os << GetRef(expr); + py.Input("ann_val", os.str()); + } else { + LOG(FATAL) << "TypeError: Cannot handle type: " << ann_val->GetTypeKey(); + throw; + } + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + +struct UnannotateTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Unannotate"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String ann_key) { + if (const auto* block = block_or_loop_rv.as()) { + return sch->Unannotate(GetRef(block), ann_key); + } + if (const auto* loop = block_or_loop_rv.as()) { + return sch->Unannotate(GetRef(loop), ann_key); + } + LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); + throw; + } + + static String UnpackedAsPython(Array outputs, ObjectRef block_or_loop_rv, + String ann_key) { + PythonAPICall py("unannotate"); + py.Input("block_or_loop", block_or_loop_rv); + py.Input("ann_key", ann_key); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(AnnotateTraits); +TVM_REGISTER_INST_KIND_TRAITS(UnannotateTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc new file mode 100644 index 0000000000..0cfcea06a4 --- /dev/null +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -0,0 +1,1078 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../../../arith/pattern_match.h" +#include "../../ir/functor_common.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +bool CheckOneLine(const Stmt& s) { + bool legal = true, meet_block = false; + PostOrderVisit(s, [&legal, &meet_block](const ObjectRef& obj) { + if (obj->IsInstance() && !meet_block) { + legal = false; + } else if (obj->IsInstance()) { + meet_block = true; + } + }); + return legal; +} + +Block GetRootBlock(StmtSRef sref) { + const StmtSRefNode* p_sref = sref.get(); + while (p_sref->parent != nullptr) { + p_sref = p_sref->parent; + } + const BlockNode* root_block = TVM_SREF_TO_BLOCK(root_block, GetRef(p_sref)); + return GetRef(root_block); +} + +void RecalculateCachedFlags(ScheduleStateNode* self) { + ScheduleState new_state(self->mod); + for (const auto& kv : new_state->stmt2ref) { + const StmtNode* stmt = kv.first; + const StmtSRef& new_sref = kv.second; + if (stmt->IsInstance() || !self->stmt2ref.count(stmt)) { + continue; + } + const BlockInfo& new_block_info = new_state->block_info.at(new_sref); + const StmtSRef& old_sref = self->stmt2ref.at(stmt); + BlockInfo& old_block_info = self->block_info.at(old_sref); + old_block_info.affine_binding = new_block_info.affine_binding; + old_block_info.region_cover = new_block_info.region_cover; + old_block_info.scope->stage_pipeline = new_block_info.scope->stage_pipeline; + } +} + +void UpdateScope(ScheduleState self, const StmtSRef& block_sref) { + BlockScope scope(tir::GetChildBlocks(self, block_sref)); + // The caller is responsible for correcting the flags + bool affine_binding = false; + bool region_cover = false; + // TODO(@Wuwei): stage_pipeline + self->block_info[block_sref] = BlockInfo(std::move(scope), affine_binding, region_cover); +} + +/* \brief Deep comparison to check if two IR graph are equivalent */ +using ExprComparator = ExprFunctor; +using StmtComparator = StmtFunctor; + +class TensorizeComparator : public ExprComparator, public StmtComparator { + public: + explicit TensorizeComparator(bool assert_mode = true) : assert_mode_(assert_mode) {} + + // Map from rhs buffer to lhs buffer + std::unordered_map rhs_buffer_map_; + // Buffer indices mapping + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_indices_; + std::vector extra_block_vars_; + // variable remap if any + std::unordered_map equal_map_; + + bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override; + bool VisitStmt(const Stmt& n, const Stmt& other) override; + + bool VisitStmt_(const ForNode* op, const Stmt& other) override; + bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override; + bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override; + bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override; + bool VisitStmt_(const BlockNode* op, const Stmt& other) override; + + bool VisitExpr_(const AddNode* op, const PrimExpr& other) override; + bool VisitExpr_(const SubNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MulNode* op, const PrimExpr& other) override; + bool VisitExpr_(const DivNode* op, const PrimExpr& other) override; + bool 
VisitExpr_(const ModNode* op, const PrimExpr& other) override; + bool VisitExpr_(const EQNode* op, const PrimExpr& other) override; + bool VisitExpr_(const NENode* op, const PrimExpr& other) override; + bool VisitExpr_(const LTNode* op, const PrimExpr& other) override; + bool VisitExpr_(const LENode* op, const PrimExpr& other) override; + bool VisitExpr_(const GTNode* op, const PrimExpr& other) override; + bool VisitExpr_(const GENode* op, const PrimExpr& other) override; + bool VisitExpr_(const AndNode* op, const PrimExpr& other) override; + bool VisitExpr_(const OrNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MinNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MaxNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloorDivNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloorModNode* op, const PrimExpr& other) override; + bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) override; + bool VisitExpr_(const CastNode* op, const PrimExpr& other) override; + bool VisitExpr_(const VarNode* op, const PrimExpr& other) override; + bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; + + bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs); + virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs); + bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs); + bool CompareAnnotation(const std::pair& lhs, + const std::pair& rhs); + bool CompareAnnotationMap(const Map& lhs, const Map& rhs); + template + bool CompareBufferAccess(const T* lhs, const T* rhs); + template + bool CompareArray(const Array& lhs, const Array& rhs, F cmp); + bool CompareRange(const Range& lhs, const Range& rhs); + bool CompareType(const DataType& lhs, const DataType& rhs); + + protected: + bool assert_mode_; + bool is_scope_block = true, is_inner_block = true; +}; + +bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) { + if (n.same_as(other)) return true; + if (n->type_index() != other->type_index()) return false; + bool equal = StmtComparator::VisitStmt(n, other); + if (!equal && assert_mode_) + LOG(FATAL) << "Stmts are not matching between:\n" << n << "\nand\n" << other; + return equal; +} + +bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) { + const auto* rhs = other.as(); + if (!DefEqual(op->loop_var, rhs->loop_var)) return false; + if (!VisitExpr(op->min, rhs->min)) return false; + if (!VisitExpr(op->extent, rhs->extent)) return false; + if (!VisitStmt(op->body, rhs->body)) return false; + if (op->kind != rhs->kind) return false; + if (op->thread_binding.defined() ^ rhs->thread_binding.defined()) return false; + if (op->thread_binding.defined() && + !VisitExpr(op->thread_binding.value(), rhs->thread_binding.value())) + return false; + return CompareAnnotationMap(op->annotations, rhs->annotations); +} + +bool TensorizeComparator::VisitStmt_(const SeqStmtNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareArray(op->seq, rhs->seq, &TensorizeComparator::VisitStmt); +} + +bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); +} + +bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& other) { + const auto* rhs = other.as(); + // Skip Compare binding values if the block is scope block (the 
outermost one). + if (!is_scope_block) { + size_t offset = op->iter_values.size() - rhs->iter_values.size(); + if (rhs->iter_values.size() > op->iter_values.size()) return false; + if (is_inner_block) { + // weak pattern matching for the inner block (the son of the scope block) + // where the pattern is v + iter <=> expr + iter + for (size_t i = 0; i < rhs->iter_values.size(); ++i) { + PrimExpr lhs_expr, rhs_expr; + Optional lhs_iter, rhs_iter; + auto detect = [](const PrimExpr& binding) -> std::pair> { + arith::PVar expr; + arith::PVar iter; + if (iter.Match(binding)) { + return std::make_pair(0, iter.Eval()); + } else if ((expr + iter).Match(binding)) { + return std::make_pair(expr.Eval(), iter.Eval()); + } else if ((iter + expr).Match(binding)) { + return std::make_pair(expr.Eval(), iter.Eval()); + } else { + return std::make_pair(expr.Eval(), NullOpt); + } + }; + std::tie(lhs_expr, lhs_iter) = detect(op->iter_values[i + offset]); + std::tie(rhs_expr, rhs_iter) = detect(rhs->iter_values[i]); + CHECK((lhs_iter && rhs_iter) || (!lhs_iter && !rhs_iter)) << "Incompatible binding"; + if (lhs_iter) VisitExpr(lhs_iter.value(), rhs_iter.value()); + if (is_zero(rhs_expr)) { + CHECK(is_zero(lhs_expr)) << "Incompatible binding"; + } else { + const auto* bv = rhs_expr.as(); + if (!bv) { + VisitExpr(lhs_expr, rhs_expr); + } else { + auto it = equal_map_.find(GetRef(bv)); + if (it == equal_map_.end()) { + equal_map_[GetRef(bv)] = lhs_expr; + } else { + CHECK(it->second->IsInstance()); + VisitExpr(lhs_expr, Downcast(it->second)); + } + } + } + } + } else { + for (size_t i = 0; i < rhs->iter_values.size(); ++i) { + if (!VisitExpr(op->iter_values[i + offset], rhs->iter_values[i])) return false; + } + const Block& block = op->block; + for (size_t i = 0; i < offset; ++i) { + Var block_var = Downcast(op->iter_values[i]); + auto it = equal_map_.find(block_var); + equal_map_[block->iter_vars[i]->var] = (it == equal_map_.end() ? 
block_var : it->second); + } + } + } + + return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block, rhs->block); +} + +bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { + const auto* rhs = other.as(); + // Check block equal + // All iter var and buffer region should matches including the order + + // Check iterVar + // need to use DefEqual to remap vars + // Note: + // We only compare the inner most several axis + if (op->iter_vars.size() < rhs->iter_vars.size()) return false; + + size_t offset = op->iter_vars.size() - rhs->iter_vars.size(); + for (size_t i = 0; i < rhs->iter_vars.size(); ++i) { + auto lhs_var = op->iter_vars[i + offset], rhs_var = rhs->iter_vars[i]; + // Skip iter dom + if (!DefEqual(lhs_var->var, rhs_var->var)) return false; + if (lhs_var->iter_type != rhs_var->iter_type) return false; + } + + for (size_t i = 0; i < offset; ++i) { + if (is_scope_block) { + extra_block_vars_.push_back(op->iter_vars[i]); + } + } + + if (!is_scope_block) { + if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) { + return false; + } + if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) { + return false; + } + if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { + return false; + } + if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) { + return false; + } + } + if (!is_scope_block) is_inner_block = false; + is_scope_block = false; + return VisitStmt(op->body, rhs->body); +} + +// Exprs +#define TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OpName) \ + bool TensorizeComparator::VisitExpr_(const OpName* op, const PrimExpr& other) { \ + const auto* rhs = other.as(); \ + return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \ + } + +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AddNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(SubNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MulNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(DivNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(ModNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(EQNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(NENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LTNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GTNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AndNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OrNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MinNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MaxNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorDivNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode); + +bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value; +} + +bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value; +} + +bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareType(op->dtype, rhs->dtype) && VisitExpr(op->value, rhs->value); +} + +bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + auto lhs = GetRef(op); + if (lhs.same_as(other)) return true; + if (!CompareType(op->dtype, rhs->dtype)) return false; + auto it = equal_map_.find(lhs); + return it != 
equal_map_.end() && it->second.same_as(other); +} + +bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs); +} + +bool TensorizeComparator::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { + if (lhs.same_as(rhs)) return true; + if (lhs->type_index() != rhs->type_index()) return false; + auto it = equal_map_.find(lhs); + // If there is already a mapping + if (it != equal_map_.end()) return it->second.same_as(rhs); + equal_map_[lhs] = rhs; + return true; +} + +bool TensorizeComparator::CompareAnnotation(const std::pair& lhs, + const std::pair& rhs) { + if (lhs.first != rhs.first) return false; + if (!lhs.second.same_as(rhs.second)) return false; + return VisitExpr(Downcast(lhs.second), Downcast(rhs.second)); +} + +bool TensorizeComparator::CompareAnnotationMap(const Map& lhs, + const Map& rhs) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + + auto sort_map = + [](const Map& map) -> std::vector> { + std::vector> ret; + ret.reserve(map.size()); + for (const auto& pair : map) { + ret.emplace_back(pair); + } + sort(ret.begin(), ret.end()); + return ret; + }; + + auto lhs_array = sort_map(lhs), rhs_array = sort_map(rhs); + + for (size_t i = 0; i < lhs.size(); ++i) { + if (!CompareAnnotation(lhs_array[i], rhs_array[i])) return false; + } + return true; +} + +bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { + if (lhs.same_as(rhs)) return true; + // Remap both buffer itself and buffer data + // Skip buffer shape + bool equal = DefEqual(lhs, rhs) && DefEqual(lhs->data, rhs->data) && + CompareType(lhs->dtype, rhs->dtype) && lhs.scope() == rhs.scope(); + if (equal) { + rhs_buffer_map_[rhs] = lhs; + } else if (assert_mode_) { + LOG(FATAL) << "Buffers are not matching between:" << lhs << " and " << rhs; + } + return equal; +} + +bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) { + // Only for block region declaration + if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; + // Number of indices in desc_block must be smaller than it in AST + if (rhs->region.size() > lhs->region.size()) return false; + + std::vector lhs_region; + for (const auto& range : lhs->region) { + lhs_region.push_back(Range::FromMinExtent(range->min, range->extent)); + } + // special judge size 1 buffer + if (rhs->region.size() == 1 && is_zero(rhs->region[0]->min) && is_one(rhs->region[0]->extent)) { + lhs_region.push_back(Range::FromMinExtent(0, 1)); + } + size_t offset = lhs_region.size() - rhs->region.size(); + // initialize buffer indices + bool need_update = false; + if (!buffer_indices_.count(lhs->buffer)) { + need_update = true; + buffer_indices_[lhs->buffer] = std::vector(); + } else { + if (offset != buffer_indices_[lhs->buffer].size()) return false; + } + std::vector& indices = buffer_indices_[lhs->buffer]; + for (size_t i = 0; i < offset; ++i) { + const Range& range = lhs_region[i]; + // High-dim region must be element-wise + if (!is_one(range->extent)) return false; + if (need_update) { + indices.push_back(range->min); + } else { + // The order matters since we only map inner block_var to outside block_var + if (!VisitExpr(range->min, indices[i])) return false; + } + } + for (size_t i = 0; i < rhs->region.size(); ++i) { + if (!CompareRange(lhs_region[i + offset], rhs->region[i])) return false; + } + return true; +} + +// Only for BufferStoreNode and BufferLoadNode +template +bool 
TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { + if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; + + if (rhs->indices.size() > lhs->indices.size()) return false; + // special judge size 1 buffer + if (rhs->indices.size() == 1 && is_zero(rhs->indices[0])) return true; + // otherwise + size_t offset = lhs->indices.size() - rhs->indices.size(); + for (size_t i = 0; i < rhs->indices.size(); ++i) { + if (!VisitExpr(lhs->indices[i + offset], rhs->indices[i])) return false; + } + return true; +} + +template +bool TensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F cmp) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!(this->*cmp)(lhs[i], rhs[i])) return false; + } + return true; +} + +bool TensorizeComparator::CompareRange(const Range& lhs, const Range& rhs) { + return VisitExpr(lhs->min, rhs->min) && VisitExpr(lhs->extent, rhs->extent); +} + +bool TensorizeComparator::CompareType(const DataType& lhs, const DataType& rhs) { + if (lhs == rhs) return true; + return lhs.code() == rhs.code() && lhs.bits() == rhs.bits() && lhs.lanes() == rhs.lanes(); +} + +// Deep comparison to check if two IR graph are equivalent +bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) { + bool equal = (n->type_index() == other->type_index()) && ExprComparator::VisitExpr(n, other); + if (!equal && assert_mode_) + LOG(FATAL) << "Exprs are not matching between:" << n << " and " << other; + return equal; +} + +Array> TrivialSubspaceDivision(const Array& iter_vars, + const Array& bindings, + const std::vector& outer_loops, + const std::vector& inner_loops, + const PrimExpr& predicate) { + if (!is_one(predicate)) return {}; + std::vector> res; + std::unordered_set outer_loop_vars; + std::unordered_set inner_loop_vars; + for (const Var& var : outer_loops) { + outer_loop_vars.insert(var.get()); + } + for (const Var& var : inner_loops) { + inner_loop_vars.insert(var.get()); + } + for (size_t i = 0; i < bindings.size(); ++i) { + bool outer = UsesVar( + bindings[i], [&outer_loop_vars](const VarNode* var) { return outer_loop_vars.count(var); }); + bool inner = UsesVar( + bindings[i], [&inner_loop_vars](const VarNode* var) { return inner_loop_vars.count(var); }); + bool is_var = bindings[i]->IsInstance(); + if (outer && !inner) { + arith::IterMark outer{nullptr}; + if (is_var) { + outer = arith::IterMark( + arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)), + iter_vars[i]->dom->extent); + } else { + outer = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); + } + arith::IterMark inner(arith::IterSumExpr({}, 0), 1); + res.push_back(Array({outer, inner})); + } else if (inner && !outer) { + arith::IterMark inner{nullptr}; + if (is_var) { + inner = arith::IterMark( + arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)), + iter_vars[i]->dom->extent); + } else { + inner = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); + } + arith::IterMark outer(arith::IterSumExpr({}, 0), 1); + res.push_back(Array({outer, inner})); + } else if (!outer && !inner) { + arith::IterMark outer(arith::IterSumExpr({}, 0), 1); + arith::IterMark inner(arith::IterSumExpr({}, 0), 1); + res.push_back(Array({outer, inner})); + } else { + return {}; + } + } + res.push_back({arith::IterMark(arith::IterSumExpr({}, 0), Bool(true)), + arith::IterMark(arith::IterSumExpr({}, 0), Bool(true))}); + 
return res;
+}
+
+StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) {
+  /*!
+   * Check:
+   * - The subtree under the loop is a single chain of nested statements with exactly one block
+   *
+   * Mutate:
+   * - Extract extra block vars from the only block
+   * - Update the block bindings
+   */
+  const auto* loop = loop_sref->StmtAs();
+  CHECK(loop) << "TypeError: Only support blockizing a loop for now, but got type: "
+              << loop_sref->stmt->GetTypeKey();
+  // check there exists no SeqStmt under the loop
+  CHECK(CheckOneLine(GetRef(loop))) << "ValueError: Only a one-line subtree can be blockized";
+  // get the inner Block, BlockRealize and StmtSRef
+  Array child_blocks = GetChildBlocks(self, loop_sref);
+  CHECK_EQ(child_blocks.size(), 1) << "ValueError: Only a one-line subtree can be blockized";
+  StmtSRef block_sref = child_blocks[0];
+  BlockRealize block_realize = GetBlockRealize(self, block_sref);
+  Block block = block_realize->block;
+  // collect loops inside/outside loop_sref
+  std::vector outer_loops, inner_loops;
+  std::vector outer_iters, inner_iters;
+  std::unordered_map iters;
+  bool inner = true;
+  for (StmtSRef current_sref = block_sref;;) {
+    current_sref = GetRef(current_sref->parent);
+    if (!current_sref.defined()) break;
+    const auto* current_loop = current_sref->StmtAs();
+    if (!current_loop) break;
+    if (inner) {
+      inner_loops.push_back(current_loop);
+      inner_iters.push_back(current_loop->loop_var);
+    } else {
+      outer_loops.push_back(current_loop);
+      outer_iters.push_back(current_loop->loop_var);
+    }
+    iters[current_loop->loop_var] = Range::FromMinExtent(current_loop->min, current_loop->extent);
+    if (current_sref == loop_sref) inner = false;
+  }
+  arith::Analyzer analyzer;
+  Array> division = arith::SubspaceDivide(
+      block_realize->iter_values, iters, inner_iters, block_realize->predicate, false, &analyzer);
+  if (division.empty()) {
+    // Even if perfect subspace division fails, blockize is still possible when every block var
+    // binding falls into one of two categories:
+    // 1. The binding covers no inner loop var
+    // 2. The binding covers only inner loop vars
+    division = TrivialSubspaceDivision(block->iter_vars, block_realize->iter_values, outer_iters,
+                                       inner_iters, block_realize->predicate);
+  }
+  CHECK(!division.empty()) << "ValueError: The bindings of the block below cannot be blockized";
+  // Generate a new inner block
+  Array inner_block_vars, outer_block_vars;
+  Array inner_bindings, outer_bindings;
+  std::unordered_map block_var_no;
+  std::unordered_map bv_iters;
+  for (size_t i = 0; i < block->iter_vars.size(); ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const arith::IterMapExprNode* outer_binding =
+        division[i][0]->source.as();
+    const arith::IterMapExprNode* inner_binding =
+        division[i][1]->source.as();
+    ICHECK(outer_binding);
+    ICHECK(inner_binding);
+    if (is_one(division[i][1]->extent)) {  // IsOuter
+      // extract this iter var to the outer block directly
+      outer_bindings.push_back(
+          arith::NormalizeIterMapToExpr(GetRef(outer_binding)));
+      outer_block_vars.push_back(iter_var);
+      // bv_iters[iter_var->var] = Range::FromMinExtent(0, division[i][0]->extent);
+    } else {
+      // generate a new iter var for the outer block
+      const IterVar outer_var(Range::FromMinExtent(0, division[i][0]->extent),
+                              iter_var->var.copy_with_suffix("o"), iter_var->iter_type);
+      outer_bindings.push_back(
+          arith::NormalizeIterMapToExpr(GetRef(outer_binding)));
+      outer_block_vars.push_back(outer_var);
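+      // Rebase the original block var (kept as the inner block var) onto the new outer var:
+      // conceptually, original_binding = outer_var * inner_extent + inner_part, so the inner
+      // iteration range recorded in bv_iters starts at `base` and spans division[i][1]->extent.
+      PrimExpr base = is_one(division[i][0]->extent) ?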
0 : outer_var * division[i][1]->extent; + if (const auto* op = division[i][1]->source.as()) { + base = base + op->base; + inner_bindings.push_back(base + + arith::NormalizeIterMapToExpr(arith::IterSumExpr(op->args, 0))); + } else { + inner_bindings.push_back( + base + arith::NormalizeIterMapToExpr(GetRef(inner_binding))); + } + inner_block_vars.push_back(iter_var); + bv_iters[iter_var->var] = Range::FromMinExtent(base, division[i][1]->extent); + } + block_var_no[iter_var->var] = i; + } + Block inner_block = block; + inner_block.CopyOnWrite()->iter_vars = inner_block_vars; + inner_block.CopyOnWrite()->init = NullOpt; + BlockRealize inner_br = block_realize; + inner_br.CopyOnWrite()->iter_values = inner_bindings; + inner_br.CopyOnWrite()->predicate = division.back()[1]->extent; + inner_br.CopyOnWrite()->block = inner_block; + // Regenerate inner_loops + Stmt body = inner_br; + for (const auto& inner_loop : inner_loops) { + auto loop_node = make_object(*inner_loop); + loop_node->body = body; + body = For(loop_node); + } + // Regenerate init for outer block + Optional new_init = NullOpt; + if (block->init.defined()) { + std::vector init_loops; + std::vector init_block_vars; + std::vector init_block_vars_copy; + std::vector init_bindings; + std::unordered_map binding_replace_map; + std::unordered_map bv_replace_map; + std::unordered_map new_block_vars2old_index; + for (size_t i = 0; i < inner_block_vars.size(); ++i) { + if (inner_block_vars[i]->iter_type == IterVarType::kDataPar && + UsesVar(block->init.value(), + [v = inner_block_vars[i]->var](const VarNode* var) { return var == v.get(); })) { + // copy init block vars and ignore reduce block vars + init_block_vars.push_back(i); + IterVar init_block_var = inner_block_vars[i]; + init_block_var.CopyOnWrite()->var = inner_block_vars[i]->var.copy_with_suffix("_init"); + init_block_vars_copy.push_back(init_block_var); + bv_replace_map[inner_block_vars[i]->var] = init_block_var->var; + new_block_vars2old_index[init_block_var.get()] = i; + } + } + for (const ForNode* inner_loop : inner_loops) { + for (size_t i = 0; i < init_block_vars.size(); ++i) { + if (UsesVar(inner_bindings[new_block_vars2old_index[init_block_vars_copy[i].get()]], + [v = inner_loop->loop_var](const VarNode* var) { return var == v.get(); })) { + // copy loops related to init block vars + For init_loop = GetRef(inner_loop); + init_loop.CopyOnWrite()->loop_var = inner_loop->loop_var.copy_with_suffix(""); + // replace loop vars with copied loop vars + binding_replace_map[inner_loop->loop_var] = init_loop->loop_var; + init_loops.push_back(init_loop); + break; + } + } + } + for (size_t i = 0; i < init_block_vars.size(); ++i) { + init_bindings.push_back(Substitute(inner_bindings[init_block_vars[i]], binding_replace_map)); + } + new_init = Substitute(Block(/*iter_vars=*/init_block_vars_copy, // + /*reads=*/{}, // + /*writes=*/block->writes, // + /*name_hint=*/block->name_hint + "_init", // + /*body=*/block->init.value(), // + /*init=*/NullOpt), + bv_replace_map); + new_init = + BlockRealize(init_bindings, division.back()[1]->extent, Downcast(new_init.value())); + for (const auto& init_loop : init_loops) { + For new_init_loop = init_loop; + new_init_loop.CopyOnWrite()->body = new_init.value(); + new_init = new_init_loop; + } + } + // Calculate outer block's IO region + auto rewrite_range = [&](const Range& range) -> Range { + const Array& res = + arith::DetectIterMap({range->min}, bv_iters, true, false, &analyzer); + ICHECK_EQ(res.size(), 1); + const arith::IterSumExpr& normalized_expr = 
res[0]; + PrimExpr extent = 1; + if (normalized_expr->args.size() == 1) { + CHECK(analyzer.CanProve(normalized_expr->args[0]->scale - range->extent == 0)); + extent = normalized_expr->args[0]->extent; + } + return Range::FromMinExtent(normalized_expr->base, extent * range->extent); + }; + std::vector reads, writes; + auto rewrite_region = [&](std::vector* regions, Array old_regions) { + for (auto buffer_region : old_regions) { + std::vector region; + for (const auto& range : buffer_region->region) { + region.push_back(rewrite_range(range)); + } + (*regions).emplace_back(buffer_region->buffer, region); + } + }; + rewrite_region(&reads, block->reads); + rewrite_region(&writes, block->writes); + // Generate a new outer block + auto outer_block = Block(/*iter_vars=*/outer_block_vars, // + /*reads=*/reads, // + /*writes=*/writes, // + /*name_hint=*/"blockized_" + block->name_hint, // + /*body=*/std::move(body), // + /*init=*/new_init); + auto outer_realize = BlockRealize(outer_bindings, division.back()[0]->extent, outer_block); + + self->Replace(loop_sref, outer_realize, {{block, inner_block}}); + { + StmtSRef block_sref = self->stmt2ref.at(outer_block.get()); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + /*require_compact_dataflow*/ false); + UpdateScope(self, scope_sref); + } + RecalculateCachedFlags(self.operator->()); + + // } + // TODO(@wuwei): fix affine flags + // self->Replace(loop_sref, outer_realize, {{block, inner_block}}); + // { + // StmtSRef block_sref = self->stmt2ref.at(inner_block.get()); + // UpdateAffineFlag(self, block_sref); + // } + // { + // StmtSRef block_sref = self->stmt2ref.at(outer_block.get()); + // StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + // /*require_compact_dataflow*/false); + // UpdateScope(self, scope_sref); + // UpdateAffineFlag(self, scope_sref); + // } + // { + // StmtSRef block_sref = self->stmt2ref.at(outer_block.get()); + // UpdateScope(self, block_sref); + // UpdateAffineFlag(self, block_sref); + // } + + // // Check loop binding + + // { + // struct BindingValidator : public StmtVisitor { + // void VisitStmt_(const BlockRealizeNode* realize) final { + // StmtSRef& sref = self->stmt2ref.at(realize->block.get()); + // UpdateAffineFlag(self, sref); + // VisitStmt(realize->block->body); + // } + // ScheduleState self; + // }; + // BindingValidator validator; + // validator.self = self; + // const PrimFuncNode* func = GetRootPrimFunc(self->mod, GetRootBlock(loop_sref).get(), + // nullptr); validator(func->body); + // } + return self->stmt2ref.at(outer_block.get()); +} + +// Stmts + +void BufferRemap(const TensorIntrin& intrinsic, + std::unordered_map* buffer_map) { + ICHECK_EQ(intrinsic->description->params.size(), intrinsic->implementation->params.size()); + for (size_t i = 0; i < intrinsic->description->params.size(); ++i) { + const auto& lhs_var = intrinsic->description->params[i]; + const auto& lhs_buffer = intrinsic->description->buffer_map[lhs_var]; + const auto& rhs_var = intrinsic->implementation->params[i]; + const auto& rhs_buffer = intrinsic->implementation->buffer_map[rhs_var]; + (*buffer_map)[rhs_buffer] = lhs_buffer; + } +} + +// Replace buffer with its data, element_offset +class BufferReplacer : public StmtExprMutator { + public: + explicit BufferReplacer( + const std::unordered_map& buffer_map, + const std::unordered_map& var_map, + std::vector&& extra_block_vars, + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& + buffer_indices) + : 
buffer_map_(buffer_map), + var_map_(var_map), + extra_block_vars_(std::move(extra_block_vars)), + buffer_indices_(buffer_indices) {} + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto s = StmtExprMutator::VisitStmt_(op); + op = s.as(); + CHECK(op); + auto it = buffer_map_.find(op->buffer); + if (it != buffer_map_.end()) { + auto n = CopyOnWrite(op); + n->buffer = it->second; + auto it2 = buffer_indices_.find(n->buffer); + CHECK(it2 != buffer_indices_.end()); + n->indices.insert(n->indices.begin(), it2->second.begin(), it2->second.end()); + return Stmt(n); + } else { + return GetRef(op); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto s = StmtExprMutator::VisitExpr_(op); + op = s.as(); + CHECK(op); + auto it = buffer_map_.find(op->buffer); + if (it != buffer_map_.end()) { + auto n = make_object(*op); + n->buffer = it->second; + auto it2 = buffer_indices_.find(n->buffer); + CHECK(it2 != buffer_indices_.end()); + n->indices.insert(n->indices.begin(), it2->second.begin(), it2->second.end()); + return PrimExpr(n); + } else { + return GetRef(op); + } + } + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = var_map_.find(op); + if (it != var_map_.end()) { + return GetRef(it->second); + } else { + auto it2 = block_var_map_.find(op); + if (it2 != block_var_map_.find(op)) { + return GetRef(it2->second); + } else { + return GetRef(op); + } + } + } + + Stmt VisitStmt_(const BlockNode* op) final { + std::vector extra_block_var; + std::unordered_map block_var_map; + for (const auto& iter_var : extra_block_vars_) { + auto n = runtime::make_object(*(iter_var.get())); + IterVar block_var(n); + extra_block_var.push_back(block_var); + block_var_map[iter_var->var.get()] = block_var->var.get(); + } + std::swap(block_var_map, block_var_map_); + auto s = StmtExprMutator::VisitStmt_(op); + op = s.as(); + CHECK(op); + + auto iter_vars = op->iter_vars; + iter_vars.insert(iter_vars.begin(), extra_block_var.begin(), extra_block_var.end()); + auto reads = UpdateBufferViaMap(op->reads); + auto writes = UpdateBufferViaMap(op->writes); + + std::swap(block_var_map, block_var_map_); + + if (reads.same_as(op->reads) && writes.same_as(op->writes) && + iter_vars.same_as(op->iter_vars)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->reads = std::move(reads); + n->writes = std::move(writes); + n->iter_vars = std::move(iter_vars); + return Block(n); + } + } + + private: + const std::unordered_map& buffer_map_; + const std::unordered_map& var_map_; + std::unordered_map block_var_map_; + const std::vector& extra_block_vars_; + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& + buffer_indices_; + + Array UpdateBufferViaMap(const Array& buffer_regions) { + auto f_mutate = [this](const BufferRegion& buffer_region) { + auto it = buffer_map_.find(buffer_region->buffer); + if (it != buffer_map_.end()) { + auto n = make_object(*buffer_region.get()); + n->buffer = it->second; + auto it2 = buffer_indices_.find(n->buffer); + if (it2 != buffer_indices_.end()) { + Region region; + for (const auto& min : it2->second) { + region.push_back(Range::FromMinExtent(VisitExpr(min), 1)); + } + n->region.insert(n->region.begin(), region.begin(), region.end()); + } + while (n->region.size() > n->buffer->shape.size()) { + const Range& range = n->region.back(); + ICHECK(is_one(range->extent) && is_zero(range->min)); + n->region.pop_back(); + } + return BufferRegion(n); + } else { + return buffer_region; + } + }; + return MutateArray(buffer_regions, f_mutate); + } +}; + +void 
Tensorize(ScheduleState self, const StmtSRef& loop_sref, const TensorIntrin& intrinsic) { + /*! + * Check: + * - Check buffer binding, including type, alignment, shape and etc. + * - Check the sub AST is equal to the description function. + * + * Mutate: + * - Blockize the sub AST (please refer blockize for details) + * - Bind buffers + * - Mutate implement function with buffer binding + * - Replace the sub tree with the mutated function. + */ + const auto* loop = loop_sref->StmtAs(); + CHECK(loop) << "Only support tensorize a loop for now"; + + const auto* desc_block_realize = + Downcast(intrinsic->description->body)->block->body.as(); + const Block& desc_block = desc_block_realize->block; + const auto* impl_block_realize = + Downcast(intrinsic->implementation->body)->block->body.as(); + Block impl_block = impl_block_realize->block; + + const StmtSRef& block_sref = Blockize(self, loop_sref); + const BlockRealize& block_realize = GetBlockRealize(self, block_sref); + + TensorizeComparator comparator; + bool equal = comparator.VisitStmt(block_realize, GetRef(desc_block_realize)); + CHECK(equal) << "The AST subtree does not match intrinsic description"; + // Map from intrinsic func buffer to description func buffer + std::unordered_map intrin_buffer_map; + BufferRemap(intrinsic, &intrin_buffer_map); + // Map form intrinsic func buffer to current AST buffer + std::unordered_map buffer_map; + for (const auto& pair : intrin_buffer_map) { + auto it = comparator.rhs_buffer_map_.find(pair.second); + CHECK(it != comparator.rhs_buffer_map_.end()); + buffer_map[pair.first] = it->second; + } + // Build Var map, which is the map from intrin buffer data to AST buffer data + std::unordered_map var_map; + auto update_var_map = [&var_map](const PrimExpr& lhs, const PrimExpr& rhs) { + if (const auto* var = lhs.as()) { + var_map[var] = rhs.get(); + } + }; + for (const auto& pair : buffer_map) { + update_var_map(pair.first->data, pair.second->data); + } + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_region_map; + for (const auto& read : impl_block->reads) { + buffer_region_map.emplace(read->buffer, read->region); + } + for (const auto& write : impl_block->writes) { + buffer_region_map.emplace(write->buffer, write->region); + } + + Array match_buffer_regions; + for (size_t i = 0; i < intrinsic->implementation->params.size(); ++i) { + const auto& param = intrinsic->implementation->params[i]; + const auto& buffer = intrinsic->implementation->buffer_map.at(param); + const auto& source = buffer_map.at(buffer); + Region region = buffer_region_map.at(buffer); + auto extra_indices = comparator.buffer_indices_.at(source); + std::vector extra_buffer_ranges; + std::transform(extra_indices.begin(), extra_indices.end(), + std::back_inserter(extra_buffer_ranges), + [](const PrimExpr& index) { return Range::FromMinExtent(index, 1); }); + region.insert(region.begin(), extra_buffer_ranges.begin(), extra_buffer_ranges.end()); + match_buffer_regions.push_back(MatchBufferRegion(buffer, BufferRegion(source, region))); + } + + impl_block.CopyOnWrite()->match_buffers = match_buffer_regions; + std::unordered_map bv_map; + for (size_t i = 0; i < desc_block->iter_vars.size(); ++i) { + auto it = comparator.equal_map_.find(desc_block->iter_vars[i]->var); + if (it != comparator.equal_map_.end()) { + bv_map[impl_block->iter_vars[i]->var] = Downcast(it->second); + } else { + bv_map[impl_block->iter_vars[i]->var] = 0; + } + } + Stmt new_body = SubstituteInScope(impl_block, [&](const VarNode* var) -> PrimExpr { + auto it = 
bv_map.find(GetRef(var)); + if (it == bv_map.end()) + return GetRef(var); + else + return it->second; + }); + // Replace + ObjectPtr new_block_ptr = make_object(*block_realize->block.get()); + new_block_ptr->body = Downcast(new_body)->body; + ICHECK(new_block_ptr->match_buffers.empty()); + new_block_ptr->match_buffers = Downcast(new_body)->match_buffers; + Block new_block(new_block_ptr); + self->Replace(self->stmt2ref.at(block_realize->block.get()), new_block, + {{block_realize->block, new_block}}); + RecalculateCachedFlags(self.operator->()); + // { + // struct BindingValidator : public StmtVisitor { + // void VisitStmt_(const BlockRealizeNode* realize) final { + // StmtSRef& sref = self->stmt2ref.at(realize->block.get()); + // UpdateAffineFlag(self, sref); + // VisitStmt(realize->block->body); + // } + // ScheduleState self; + // }; + // BindingValidator validator; + // StmtSRef block_sref = self->stmt2ref.at(new_block.get()); + // const PrimFuncNode* func = GetRootPrimFunc(self->mod, GetRootBlock(block_sref).get(), + // nullptr); validator.self = self; validator(func->body); + // } +} + +struct BlockizeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Blockize"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { + return sch->Blockize(loop_rv); + } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("blockize"); + py.Input("loop", loop_rv); + py.SingleOutput(outputs); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + +struct TensorizeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Tensorize"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, String intrin_name) { + return sch->Tensorize(loop_rv, intrin_name); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, String intrin_name) { + PythonAPICall py("tensorize"); + py.Input("loop", loop_rv); + py.Input("intrin", intrin_name); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(BlockizeTraits); +TVM_REGISTER_INST_KIND_TRAITS(TensorizeTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 12ae021a88..0c86f7b698 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -60,11 +60,27 @@ class NotSingleReadWriteBuffer : public ScheduleError { bool is_read_; Block block_; - static Buffer GetSingleRead(const ScheduleState& self, const Block& block) { - if (block->reads.size() != 1) { + static Buffer GetSingleRead(const ScheduleState& self, const Block& block, + const StmtSRef& scope_root_sref) { + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& + buffer_writers = self->block_info.at(scope_root_sref).scope->buffer_writers; + const BufferNode* read_buffer = nullptr; + for (const BufferRegion& read_region : block->reads) { + const BufferNode* buffer = read_region->buffer.get(); + if (buffer == read_buffer) { + continue; + } + if (buffer_writers.count(GetRef(buffer)) > 0) { + if (read_buffer != nullptr) { + 
throw NotSingleReadWriteBuffer(self->mod, true, block); + } + read_buffer = buffer; + } + } + if (read_buffer == nullptr) { throw NotSingleReadWriteBuffer(self->mod, true, block); } - return block->reads[0]->buffer; + return GetRef(read_buffer); } static Buffer GetSingleWrite(const ScheduleState& self, const Block& block) { @@ -167,7 +183,7 @@ class OpaqueAccessError : public ScheduleError { * \brief The base class of the inliner, which handles: * 1) Substitute a subtree with the specific block being inlined * 2) Update the block signature to reflect the changes of read/write/allocated buffers - * 3) Maintain a list of index variables and their substition of the buffer being inlined + * 3) Maintain a list of index variables and their substitution of the buffer being inlined */ class BaseInliner : public StmtExprMutator { protected: @@ -526,7 +542,7 @@ class ReverseComputeInliner : public BaseInliner { PrimExpr producer_rhs_{nullptr}; }; -void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { +std::function ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref) { const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref); Block producer_block = GetRef(_producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); @@ -535,6 +551,7 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { /*require_stage_pipeline=*/true, /*require_subtree_compact_dataflow=*/false); // Step 2. Check completeness + CheckNotOutputBlock(self, producer_block_sref, scope_root_sref); CheckCompleteBlock(self, producer_block_sref, scope_root_sref); // Step 3. Analyze the block body ComputeInliner inliner(inlined_buffer, producer_block, scope_root_sref); @@ -550,17 +567,32 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { throw OpaqueAccessError(self->mod, scope_root_sref); } // Step 6. Do the real mutation on the AST and the sref tree in the schedule state - self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); + return [=]() -> void { self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); }; } -void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) { +void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { + ComputeInlineImpl(self, producer_block_sref)(); +} + +bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_sref) { + try { + ComputeInlineImpl(self, producer_block_sref); + } catch (const tvm::runtime::Error& e) { + return false; + } + return true; +} + +std::function ReverseComputeInlineImpl(ScheduleState self, + const StmtSRef& consumer_block_sref) { const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref); Block consumer_block = GetRef(_consumer_block); - Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block); // Step 1. Get the scope block StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, // /*require_stage_pipeline=*/true, /*require_subtree_compact_dataflow=*/false); + Buffer inlined_buffer = + NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref); // Step 2. Check completeness CheckCompleteBlock(self, consumer_block_sref, scope_root_sref); // Step 3. 
Check if the consumer has a single complete producer @@ -579,7 +611,20 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre throw OpaqueAccessError(self->mod, scope_root_sref); } // Step 7. Do the real mutation on the AST and the sref tree in the schedule state - self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); + return [=]() -> void { self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); }; +} + +bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) { + try { + ReverseComputeInlineImpl(self, block_sref); + } catch (const tvm::runtime::Error& e) { + return false; + } + return true; +} + +void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) { + ReverseComputeInlineImpl(self, consumer_block_sref)(); } /******** InstructionKind Registration ********/ diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index 55869e12b6..acab85460a 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -83,7 +83,7 @@ void CheckLoopParallelizableInBlock(const ScheduleState& self, ForKind for_kind, const Block& block = block_realize->block; // Cond 1. The block is required to have affine bindings. - CheckAffineBinding(self, block); + /* CheckAffineBinding(self, block); */ // Cond 2. For each block iter whose binding contains `loop_var`, only two cases are allowed. ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); diff --git a/src/tir/schedule/primitive/read_write_at.cc b/src/tir/schedule/primitive/read_write_at.cc new file mode 100644 index 0000000000..2656fe7ba9 --- /dev/null +++ b/src/tir/schedule/primitive/read_write_at.cc @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +bool HasBuffer(const Array& buffer_regions, const Buffer& buffer) { + for (const BufferRegion& buffer_region : buffer_regions) { + if (buffer_region->buffer.same_as(buffer)) { + return true; + } + } + return false; +} + +void RelaxBufferRegions(const Array& buffer_regions, + const Buffer& buffer, // + const Map& var_dom, // + const Map& bindings, // + std::vector* relaxed_regions) { + for (const BufferRegion& buffer_region : buffer_regions) { + if (buffer_region->buffer.same_as(buffer)) { + Array relaxed_region = + arith::EvalSet(Substitute(buffer_region->region, bindings), var_dom); + relaxed_regions->push_back({relaxed_region.begin(), relaxed_region.end()}); + } + } +} + +class ScopeReplacer : public StmtMutator { + public: + static Block Replace(const BlockNode* scope_block, const Buffer& dst, const ForNode* old_loop, + const ForNode* new_loop) { + ObjectPtr new_scope_block = make_object(*scope_block); + new_scope_block->body = ScopeReplacer(old_loop, new_loop)(std::move(new_scope_block->body)); + new_scope_block->alloc_buffers.push_back(dst); + return Block(new_scope_block); + } + + private: + explicit ScopeReplacer(const ForNode* old_loop, const ForNode* new_loop) + : old_loop_(old_loop), new_loop_(new_loop), found_(false) {} + + Stmt VisitStmt(const Stmt& stmt) final { return found_ ? stmt : StmtMutator::VisitStmt(stmt); } + Stmt VisitStmt_(const BlockNode* block) final { return GetRef(block); } + Stmt VisitStmt_(const ForNode* loop) final { + if (loop == old_loop_) { + found_ = true; + return GetRef(new_loop_); + } + return StmtMutator::VisitStmt_(loop); + } + + const ForNode* old_loop_; + const ForNode* new_loop_; + bool found_; +}; + +class BufferReplacer : public StmtExprMutator { + public: + explicit BufferReplacer(const Buffer& src, const Buffer& dst, Map* block_sref_reuse) + : src_(src), dst_(dst), block_sref_reuse_(block_sref_reuse) {} + + private: + Stmt VisitStmt_(const BufferStoreNode* _store) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); + if (store->buffer.same_as(src_)) { + ObjectPtr new_store = make_object(*store.get()); + new_store->buffer = dst_; + return BufferStore(new_store); + } + return store; + } + + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + if (load->buffer.same_as(src_)) { + ObjectPtr new_load = make_object(*load.get()); + new_load->buffer = dst_; + return BufferLoad(new_load); + } + return load; + } + + Stmt VisitStmt_(const BlockNode* _block) final { + Block old_block = GetRef(_block); + Block block = Downcast(StmtExprMutator::VisitStmt_(_block)); + ObjectPtr new_block = make_object(*block.get()); + new_block->reads = ReplaceBuffer(new_block->reads, src_, dst_); + new_block->writes = ReplaceBuffer(new_block->writes, src_, dst_); + block_sref_reuse_->Set(old_block, Block(new_block)); + return Block(new_block); + } + + const Buffer& src_; + const Buffer& dst_; + Map* block_sref_reuse_; +}; + +struct ReadWriteAtImpl { + template + static StmtSRef Main(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int buffer_index, const String& storage_scope, + Map annotations) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Buffer src = + GetNthAccessBuffer(self, GetRef(block), buffer_index, /*is_write=*/!is_read); + Buffer dst = WithScope(src, storage_scope); + ReadWriteAtImpl impl(self, loop_sref, src, 
dst, annotations); + std::pair new_loop_block = + impl.MakeLoopAndBlock(src->name + "_" + storage_scope); + StmtSRef result_block_sref = + impl.ReplaceScopeBlock(new_loop_block.first.get(), new_loop_block.second->block.get()); + impl.UpdateBlockInfo(result_block_sref); + return result_block_sref; + } + + private: + static Map GetLoopDomain(const StmtSRefNode* loop_sref) { + Map result; + for (const ForNode* loop; (loop = loop_sref->StmtAs()) != nullptr; + loop_sref = loop_sref->parent) { + result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + return result; + } + + StmtSRef ReplaceScopeBlock(const ForNode* new_loop, const BlockNode* new_block) { + StmtSRef scope_root_sref = GetScopeRoot(self_, loop_sref_, + /*require_stage_pipeline=*/true, + /*require_subtree_compact_dataflow=*/false); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_root_sref); + Block new_scope_block = ScopeReplacer::Replace(scope_block, dst_, loop_, new_loop); + block_sref_reuse_.Set(GetRef(scope_block), new_scope_block); + self_->Replace(scope_root_sref, new_scope_block, block_sref_reuse_); + return self_->stmt2ref.at(new_block); + } + + void UpdateBlockInfo(const StmtSRef& new_block_sref) { + BlockInfo& block_info = self_->block_info[new_block_sref]; + block_info.affine_binding = true; + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + } + + template + std::pair MakeLoopAndBlock(const String& new_block_name_hint) { + Array subtrees = AsArray(loop_->body); + int n_subtrees = subtrees.size(); + runtime::StorageScope scope = runtime::StorageScope::Create(dst_.scope()); + std::vector relaxed_regions; + std::vector r_pos; + std::vector w_pos; + relaxed_regions.reserve(n_subtrees); + r_pos.reserve(n_subtrees); + w_pos.reserve(n_subtrees); + // Step 1. Iterate over all subtrees + for (int i = 0; i < n_subtrees; ++i) { + bool r_visited = false; + bool w_visited = false; + auto f_visit = [this, &relaxed_regions, &r_visited, &w_visited, + &scope](const ObjectRef& obj) -> bool { + const BlockRealizeNode* realize = obj.as(); + if (realize == nullptr) { + return true; + } + const BlockNode* block = realize->block.get(); + bool has_r = HasBuffer(block->reads, src_); + bool has_w = HasBuffer(block->writes, src_); + r_visited = r_visited || has_r; + w_visited = w_visited || has_w; + if (is_read ? has_r : has_w) { + RelaxBufferRegions( + /*buffer_regions=*/is_read ? block->reads : block->writes, + /*buffer=*/src_, + /*var_dom=*/ + AsIntSet(LoopDomainOfSRefTreePath( + /*low_inclusive=*/GetRef(self_->stmt2ref.at(block)->parent), + /*high_exclusive=*/loop_sref_, + /*extra_relax_scope=*/scope)), + /*bindings=*/GetBindings(GetRef(realize)), + /*relaxed_regions=*/&relaxed_regions); + } + return false; + }; + PreOrderVisit(subtrees[i], f_visit); + if (r_visited) { + r_pos.push_back(i); + } + if (w_visited) { + w_pos.push_back(i); + } + } + // Step 2. Calculate `insert_pos` and [st, ed) for buffer replacement + int insert_pos = -1, st = -1, ed = -1; + if (is_read) { + ICHECK(!r_pos.empty()); + // No write after the first read + ICHECK(w_pos.empty() || w_pos.back() < r_pos.front()); + // Can be inserted at [0, r_pos.front()], i.e. before the first read + insert_pos = r_pos.front(); + // Buffer reads in [insert_pos, +oo) is rewritten + st = insert_pos; + ed = n_subtrees; + } else { + ICHECK(!w_pos.empty()); + // No read after the last write + ICHECK(r_pos.empty() || r_pos.back() <= w_pos.back()); + // Can be inserted into (w_pos.back(), +oo), i.e. 
after the last write + insert_pos = w_pos.back() + 1; + st = 0; + ed = insert_pos; + } + // Step 3. Calculate `domain`, the domain of buffer access + NDIntSet relaxed = support::NDIntSetUnion(relaxed_regions); + int ndim = relaxed.size(); + Array domain; + domain.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + const arith::IntSet& int_set = relaxed[i]; + PrimExpr min = analyzer_->Simplify(int_set.min()); + PrimExpr extent = analyzer_->Simplify(int_set.max() + 1 - min); + domain.push_back(Range::FromMinExtent(min, extent)); + } + // Step 4. Insert the auto copy block and replace buffers + BufferReplacer replacer(src_, dst_, &block_sref_reuse_); + for (int i = st; i < ed; ++i) { + Stmt stmt = subtrees[i]; + subtrees.Set(i, Stmt(nullptr)); + subtrees.Set(i, replacer(std::move(stmt))); + } + BlockRealize realize = + is_read + ? MakeBlock(src_, dst_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain) + : MakeBlock(dst_, src_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain); + subtrees.insert(subtrees.begin() + insert_pos, realize); + ObjectPtr new_loop = make_object(*loop_); + new_loop->body = SeqStmt(std::move(subtrees)); + return {For(new_loop), realize}; + } + + BlockRealize MakeBlock(const Buffer& copy_from, const Buffer& copy_to, const String& name_hint, + const Map& loop_domain, Array domain) const { + int n = domain.size(); + std::vector loop_vars; + loop_vars.reserve(n); + for (int i = 0; i < n; ++i) { + loop_vars.push_back(Var("ax" + std::to_string(i))); + } + Map bindings; + Array iter_vars; + Array iter_values; + Array indices; + iter_vars.reserve(n); + iter_values.reserve(n); + indices.reserve(n); + for (int i = 0; i < n; ++i) { + auto f_substitute = [&loop_domain, &bindings, &iter_vars, + &iter_values](const Var& var) -> Optional { + auto it = bindings.find(var); + if (it != bindings.end()) { + return (*it).second; + } + Range range = loop_domain.at(var); + ObjectPtr v = make_object(*var.get()); + v->name_hint = "v" + std::to_string(iter_vars.size()); + bindings.Set(var, Var(v)); + iter_values.push_back(var); + iter_vars.push_back(IterVar(range, Var(v), IterVarType::kDataPar)); + return Var(v); + }; + ObjectPtr dom = make_object(*domain[i].get()); + dom->min = Substitute(std::move(dom->min), f_substitute); + dom->extent = Substitute(std::move(dom->extent), f_substitute); + domain.Set(i, Range(dom)); + } + for (int i = 0; i < n; ++i) { + indices.push_back(domain[i]->min + loop_vars[i]); + } + Stmt stmt = BufferStore(copy_to, /*value=*/BufferLoad(copy_from, indices), /*indices=*/indices); + for (int i = n - 1; i >= 0; --i) { + stmt = For(loop_vars[i], Integer(0), domain[i]->extent, ForKind::kSerial, stmt); + } + return BlockRealize( + /*values=*/iter_values, + /*predicate=*/const_true(), + Block(/*iter_vars=*/iter_vars, + /*reads=*/{BufferRegion(copy_from, domain)}, + /*writes=*/{BufferRegion(copy_to, domain)}, + /*name_hint=*/name_hint, // + /*body=*/std::move(stmt), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/annotations_)); + } + + explicit ReadWriteAtImpl(ScheduleState self, const StmtSRef& loop_sref, const Buffer& src, + const Buffer& dst, Map annotations) + : self_(self), + loop_sref_(loop_sref), + loop_(nullptr), + src_(src), + dst_(dst), + annotations_(annotations), + block_sref_reuse_(), + analyzer_(std::make_unique()) { + loop_ = TVM_SREF_TO_FOR(loop_, loop_sref); + } + + ScheduleState self_; + const StmtSRef& loop_sref_; + const ForNode* loop_; + const Buffer& src_; + const Buffer& dst_; + Map 
annotations_; + Map block_sref_reuse_; + std::unique_ptr analyzer_; +}; + +StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope) { + return ReadWriteAtImpl::Main(self, loop_sref, block_sref, read_buffer_index, storage_scope, + {{"auto_copy", Integer(1)}}); +} + +StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope) { + return ReadWriteAtImpl::Main(self, loop_sref, block_sref, write_buffer_index, + storage_scope, {{"auto_copy", Integer(1)}}); +} + +/******** Instruction Registration ********/ + +struct ReadAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReadAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int buffer_index, const String& storage_scope); + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, + Integer read_buffer_index, String storage_scope) { + return sch->ReadAt(loop, block, read_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String loop, String block, + Integer read_buffer_index, String storage_scope) { + PythonAPICall py("read_at"); + py.Input("loop", loop); + py.Input("block", block); + py.Input("read_buffer_index", read_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct WriteAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "WriteAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, + Integer write_buffer_index, String storage_scope) { + return sch->WriteAt(loop, block, write_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String loop, String block, + Integer write_buffer_index, String storage_scope) { + PythonAPICall py("write_at"); + py.Input("loop", loop); + py.Input("block", block); + py.Input("write_buffer_index", write_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ReadAtTraits); +TVM_REGISTER_INST_KIND_TRAITS(WriteAtTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 171838572d..40916a1a93 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -86,6 +86,7 @@ struct PrimeTable { pow_tab.emplace_back(std::move(tab)); } } + /*! 
* \brief Factorize a number n, and return in a cryptic format * \param n The number to be factorized @@ -187,6 +188,28 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st return candidates[i]; } +std::function MakeMultinomialSampler( + support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights) { + ICHECK(!weights.empty()); + std::vector sums; + sums.reserve(weights.size()); + double sum = 0.0; + for (double w : weights) { + sums.push_back(sum += w); + } + return [rng = support::LinearCongruentialEngine(rand_state).ForkSeed(), + dist = std::uniform_real_distribution(0.0, sum), + sums = std::move(sums)]() mutable -> int32_t { + support::LinearCongruentialEngine rand_(&rng); + double p = dist(rand_); + int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); + int32_t n = sums.size(); + CHECK_LE(0, idx); + CHECK_LE(idx, n); + return (idx == n) ? (n - 1) : idx; + }; +} + std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, int32_t extent, int32_t n_splits) { CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; @@ -297,12 +320,12 @@ std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandS std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // - const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, + const StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, Optional>* decision) { const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); - int64_t extent = GetLoopIntExtent(loop); + const int64_t* extent = GetLoopIntExtent(loop); std::vector result; - if (extent == -1) { + if (extent == nullptr) { // Case 1. Handle loops with non-constant length result = std::vector(n_splits, 1); result[0] = -1; @@ -311,7 +334,7 @@ std::vector SamplePerfectTile( result = support::AsVector(decision->value()); int n = result.size(); ICHECK_GE(n, 2); - int64_t len = extent; + int64_t len = *extent; for (int i = n - 1; i > 0; --i) { int64_t& l = result[i]; // A previous decision could become invalid because of the change of outer tiles @@ -325,13 +348,62 @@ std::vector SamplePerfectTile( result[0] = len; } else { // Case 3. 
Use fresh new sampling result - result = SamplePerfectTile(rand_state, extent, n_splits, max_innermost_factor); + result = SamplePerfectTile(rand_state, *extent, n_splits, max_innermost_factor); ICHECK_LE(result.back(), max_innermost_factor); } *decision = support::AsArray(result); return result; } +tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, + support::LinearCongruentialEngine::TRandState* rand_state, + const tir::StmtSRef& block_sref, Optional* decision) { + // Find all possible compute-at locations + Array loop_srefs = tir::CollectComputeLocation(self, block_sref); + int n = loop_srefs.size(); + // Extract non-unit loops + std::vector choices; + choices.reserve(n); + for (int i = 0; i < n; ++i) { + const int64_t* extent = tir::GetLoopIntExtent(loop_srefs[i]); + if (extent != nullptr) { + choices.push_back(i); + } + } + // The decision made, by default it is -1 + int i = -1; + if (decision->defined()) { + // Handle existing decision + const auto* int_imm = decision->as(); + int64_t decided = int_imm->value; + if (decided == -2 || decided == -1) { + i = decided; + } else { + for (int choice : choices) { + if (choice <= decided) { + i = choice; + } else { + break; + } + } + } + } else { + // Sample possible combinations + i = SampleInt(rand_state, -2, choices.size()); + if (i >= 0) { + i = choices[i]; + } + } + *decision = Integer(i); + if (i == -2) { + return tir::StmtSRef::InlineMark(); + } + if (i == -1) { + return tir::StmtSRef::RootMark(); + } + return loop_srefs[i]; +} + /******** InstructionKind Registration ********/ struct SampleCategoricalTraits : public UnpackedInstTraits { @@ -396,8 +468,37 @@ struct SamplePerfectTileTraits : public UnpackedInstTraits { + static constexpr const char* kName = "SampleComputeLocation"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 1; + + static LoopRV UnpackedApplyToSchedule(Schedule sch, // + BlockRV block_rv, // + Optional decision) { + return sch->SampleComputeLocation(block_rv, decision); + } + + static String UnpackedAsPython(Array outputs, // + String block_rv, // + Optional decision) { + PythonAPICall py("sample_compute_location"); + py.Input("block", block_rv); + py.Decision(decision); + py.SingleOutput(outputs); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); TVM_REGISTER_INST_KIND_TRAITS(SamplePerfectTileTraits); +TVM_REGISTER_INST_KIND_TRAITS(SampleComputeLocationTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index a411e40b13..292865f23e 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -125,6 +125,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") .set_body_method(&ScheduleNode::SampleCategorical); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile") .set_body_method(&ScheduleNode::SamplePerfectTile); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleComputeLocation") + .set_body_method(&ScheduleNode::SampleComputeLocation); /******** (FFI) Get blocks & loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); @@ -163,6 +165,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") .set_body_method(&ScheduleNode::CacheRead); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") .set_body_method(&ScheduleNode::CacheWrite); +/******** 
(FFI) Data movement ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReadAt").set_body_method<Schedule>(&ScheduleNode::ReadAt);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWriteAt")
+    .set_body_method<Schedule>(&ScheduleNode::WriteAt);
 /******** (FFI) Compute location ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt")
     .set_body_method<Schedule>(&ScheduleNode::ComputeAt);
@@ -181,7 +187,46 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor")
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
     .set_body_method<Schedule>(&ScheduleNode::StorageAlign);
 /******** (FFI) Blockize & Tensorize ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
+    .set_body_method<Schedule>(&ScheduleNode::Blockize);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize")
+    .set_body_typed([](Schedule self, LoopRV loop_rv, ObjectRef intrin) {
+      if (const auto* str = intrin.as<runtime::StringObj>()) {
+        return self->Tensorize(loop_rv, GetRef<String>(str));
+      }
+      if (const auto* p_intrin = intrin.as<TensorIntrinNode>()) {
+        return self->Tensorize(loop_rv, GetRef<TensorIntrin>(p_intrin));
+      }
+      LOG(FATAL) << "TypeError: Cannot handle type: " << intrin->GetTypeKey();
+      throw;
+    });
+
 /******** (FFI) Annotation ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate")
+    .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key,
+                       const ObjectRef& ann_val) {
+      if (const auto* block_rv = rv.as<BlockRVNode>()) {
+        return self->Annotate(GetRef<BlockRV>(block_rv), ann_key, ann_val);
+      }
+      if (const auto* loop_rv = rv.as<LoopRVNode>()) {
+        return self->Annotate(GetRef<LoopRV>(loop_rv), ann_key, ann_val);
+      }
+      LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
+                 << ". Its value is: " << rv;
+      throw;
+    });
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate")
+    .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key) {
+      if (const auto* block_rv = rv.as<BlockRVNode>()) {
+        return self->Unannotate(GetRef<BlockRV>(block_rv), ann_key);
+      }
+      if (const auto* loop_rv = rv.as<LoopRVNode>()) {
+        return self->Unannotate(GetRef<LoopRV>(loop_rv), ann_key);
+      }
+      LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
+                 << ". 
Its value is: " << rv; + throw; + }); /******** (FFI) Misc ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") .set_body_method(&ScheduleNode::EnterPostproc); diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index faeb0b9907..1be5ed06ac 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -897,7 +897,7 @@ class ChildReplacer : private StmtMutator { int seq_index_; }; -void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt, +void ScheduleStateNode::Replace(const StmtSRef& _src_sref, const Stmt& tgt_stmt, const Map& _block_sref_reuse) { if (this->debug_mask != 0) { const StmtNode* src_stmt = _src_sref->stmt; diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index d8c18f0de0..2af9076c7d 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -34,18 +34,13 @@ Trace::Trace(Array insts, Map decisions) { /**************** Utilities ****************/ -bool IsPostproc(const InstructionKind& inst_kind) { - static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); - return inst_kind.same_as(inst_enter_postproc); -} - int GetNumValidInstructions(const Array& insts, bool remove_postproc) { if (!remove_postproc) { return insts.size(); } int n_insts = 0; for (const Instruction& inst : insts) { - if (!IsPostproc(inst->kind)) { + if (!inst->kind->IsPostproc()) { ++n_insts; } else { break; @@ -242,7 +237,7 @@ void TraceNode::ApplyToSchedule( decision_provider) const { std::unordered_map rv_map; for (const Instruction& inst : this->insts) { - if (remove_postproc && IsPostproc(inst->kind)) { + if (remove_postproc && inst->kind->IsPostproc()) { break; } Array inputs = TranslateInputRVs(inst->inputs, rv_map); @@ -266,7 +261,7 @@ ObjectRef TraceNode::AsJSON(bool remove_postproc) const { int i = 0; for (const Instruction& inst : this->insts) { const InstructionKind& kind = inst->kind; - if (remove_postproc && IsPostproc(kind)) { + if (remove_postproc && kind->IsPostproc()) { break; } json_insts.push_back(Array{ @@ -295,7 +290,7 @@ Array TraceNode::AsPython(bool remove_postproc) const { Array py_trace; py_trace.reserve(this->insts.size()); for (const Instruction& inst : this->insts) { - if (remove_postproc && IsPostproc(inst->kind)) { + if (remove_postproc && inst->kind->IsPostproc()) { break; } Array attrs; diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 4a028d1dad..a5d044ca31 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -29,7 +29,7 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); n->trace_ = Trace(); - support::LinearCongruentialEngine(&n->rand_state_).Seed(seed); + n->Seed(seed); return Schedule(std::move(n)); } @@ -73,6 +73,20 @@ Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n return results; } +LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, + Optional decision) { + LoopRV result = CreateRV(tir::SampleComputeLocation(this->state_, &this->rand_state_, + this->GetSRef(block_rv), &decision)); + + static const InstructionKind& kind = InstructionKind::Get("SampleComputeLocation"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{result}), + /*decision=*/decision); + return result; +} + /******** Schedule: Get blocks & loops ********/ BlockRV TracedScheduleNode::GetBlock(const 
String& name, const String& func_name) { @@ -250,6 +264,31 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } +BlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int read_buffer_index, const String& storage_scope) { + BlockRV result = + ConcreteScheduleNode::ReadAt(loop_rv, block_rv, read_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("ReadAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} + +BlockRV TracedScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int write_buffer_index, const String& storage_scope) { + BlockRV result = + ConcreteScheduleNode::WriteAt(loop_rv, block_rv, write_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("WriteAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{Integer(write_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} /******** Schedule: Compute location ********/ void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -331,8 +370,67 @@ void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, /******** Schedule: Blockize & Tensorize ********/ +BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv) { + BlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv); + static const InstructionKind& kind = InstructionKind::Get("Blockize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{new_block})); + return new_block; +} + +void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin_name) { + ConcreteScheduleNode::Tensorize(loop_rv, intrin_name); + static const InstructionKind& kind = InstructionKind::Get("Tensorize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{intrin_name}, + /*outputs=*/{})); +} + /******** Schedule: Annotation ********/ +void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, + const ObjectRef& ann_val) { + ConcreteScheduleNode::Annotate(loop_rv, ann_key, ann_val); + static const InstructionKind& kind = InstructionKind::Get("Annotate"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, ann_val}, + /*attrs=*/{ann_key}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key, + const ObjectRef& ann_val) { + ConcreteScheduleNode::Annotate(block_rv, ann_key, ann_val); + static const InstructionKind& kind = InstructionKind::Get("Annotate"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv, ann_val}, + /*attrs=*/{ann_key}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) { + ConcreteScheduleNode::Unannotate(loop_rv, ann_key); + static const InstructionKind& kind = InstructionKind::Get("Unannotate"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{ann_key}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_key) { + ConcreteScheduleNode::Unannotate(block_rv, ann_key); + static const InstructionKind& kind = InstructionKind::Get("Unannotate"); + 
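+  // Record the applied primitive in the trace as well: the block RV is the only input,
+  // the annotation key is carried as an attribute, and the instruction produces no outputs,
+  // mirroring the loop-RV overload above.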
trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{ann_key}, + /*outputs=*/{})); +} + /******** Schedule: Misc ********/ void TracedScheduleNode::EnterPostproc() { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index ac36b9ca06..13cef8e3df 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -51,6 +51,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { Optional decision = NullOpt) final; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) final; + LoopRV SampleComputeLocation(const BlockRV& block_rv, Optional decision = NullOpt) final; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") final; Array GetLoops(const BlockRV& block_rv) final; @@ -72,6 +73,11 @@ class TracedScheduleNode : public ConcreteScheduleNode { const String& storage_scope) final; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) final; + /******** Schedule: Data movement ********/ + BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) final; + BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) final; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -85,7 +91,13 @@ class TracedScheduleNode : public ConcreteScheduleNode { void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) final; /******** Schedule: Blockize & Tensorize ********/ + BlockRV Blockize(const LoopRV& loop_rv) final; + void Tensorize(const LoopRV& loop_rv, const String& intrin_name) final; /******** Schedule: Annotation ********/ + void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; + void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; + void Annotate(const BlockRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; + void Unannotate(const BlockRV& loop_rv, const String& ann_key) override; /******** Schedule: Misc ********/ void EnterPostproc() final; }; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index ffb6b2d526..fb3829c59a 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -136,5 +136,98 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ throw OnlyLeafError(self->mod, GetRef(leaf_block), GetRef(scope_block)); } +/******** Utilities for tensorization ********/ + +class IRSubstituteInScope : public StmtExprMutator { + public: + explicit IRSubstituteInScope(std::function fmap) + : fmap_(std::move(fmap)) {} + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = fmap_(op); + if (it.defined()) { + return it; + } else { + return GetRef(op); + } + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + arith::Analyzer analyzer; + auto fmutate = [&](const PrimExpr& e) { return this->VisitExpr(e); }; + Array v = op->iter_values; + v.MutateByApply(fmutate); + PrimExpr pred = this->VisitExpr(op->predicate); + if (v.same_as(op->iter_values) && pred.same_as(op->predicate)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->iter_values = 
std::move(v); + n->predicate = std::move(analyzer.Simplify(pred)); + return Stmt(n); + } + } + + private: + const std::function fmap_; +}; + +Stmt SubstituteInScope(const Stmt& stmt, + const std::function& value_func) { + return IRSubstituteInScope(value_func)(stmt); +} + +Stmt SubstituteInScope(const Stmt& stmt, + const std::unordered_map& var_map) { + auto vmap = [&](const VarNode* v) -> PrimExpr { + const auto& it = var_map.find(v); + if (it != var_map.end()) { + return it->second; + } else { + return NullValue(); + } + }; + return IRSubstituteInScope(vmap)(stmt); +} + +PrimExpr SubstituteInScope(const PrimExpr& expr, + const std::unordered_map& var_map) { + auto vmap = [&](const VarNode* v) -> PrimExpr { + const auto& it = var_map.find(v); + if (it != var_map.end()) { + return it->second; + } else { + return NullValue(); + } + }; + return IRSubstituteInScope(vmap)(expr); +} + +Stmt SubstituteInScope(const Stmt& stmt, + const std::unordered_map& var_map) { + auto vmap = [&](const VarNode* v) -> PrimExpr { + const auto& it = var_map.find(v); + if (it != var_map.end()) { + return GetRef(it->second); + } else { + return NullValue(); + } + }; + return IRSubstituteInScope(vmap)(stmt); +} + +PrimExpr SubstituteInScope(const PrimExpr& expr, + const std::unordered_map& var_map) { + auto vmap = [&](const VarNode* v) -> PrimExpr { + const auto& it = var_map.find(v); + if (it != var_map.end()) { + return GetRef(it->second); + } else { + return NullValue(); + } + }; + return IRSubstituteInScope(vmap)(expr); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index c66c2ca766..091266ee38 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -178,6 +178,18 @@ inline Array AsArray(const Stmt& stmt) { return {stmt}; } +/*! + * \brief Checks of a statement is a SeqStmt that contains multiple statements + * \param stmt The statement to be checked + * \return A boolean indicating the result + */ +inline bool IsSingleStmt(const Stmt& stmt) { + if (const auto* seq_stmt = stmt.as()) { + return seq_stmt->seq.size() == 1; + } + return true; +} + /******** IterVar ********/ /*! @@ -192,6 +204,36 @@ inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_va Var(std::move(name), loop->loop_var.dtype()), iter_var_type); } +/*! + * \brief Get the thread scope bound to the specific loop + * \param loop The loop to be inspected + * \return The thread scope bound to the loop + */ +inline runtime::ThreadScope GetThreadScope(const ForNode* loop) { + if (loop->kind == ForKind::kThreadBinding) { + return runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag); + } + return runtime::ThreadScope{-1, -1}; +} + +/*! + * \brief Check if the thread scope is blockIdx + * \param thread_scope The thread scope to be checked + * \return True if the thread scope is blockIdx + */ +inline bool IsBlockIdx(const runtime::ThreadScope& thread_scope) { + return thread_scope.rank == 0; // The rank of blockIdx is 0 +} + +/*! + * \brief Check if the thread scope is threadIdx + * \param thread_scope The thread scope to be checked + * \return True if the thread scope is threadIdx + */ +inline bool IsThreadIdx(const runtime::ThreadScope& thread_scope) { + return thread_scope.rank == 1 && thread_scope.dim_index >= 0; +} + /******** Integer set ********/ /*! 
@@ -210,28 +252,115 @@ inline Map AsIntSet(const Map& var_dom) { return {result.begin(), result.end()}; } -/**************** Loop extents ****************/ +/**************** PrimExpr parsing and extents ****************/ /*! * \brief Get the extents of a loop * \param loop The loop to be queried - * \return The extents of the loop + * \return The extent of the loop, nullptr if the extent is not constant */ -inline int64_t GetLoopIntExtent(const ForNode* loop) { - const auto* int_extent = loop->extent.as(); - return int_extent ? int_extent->value : -1; -} +inline const int64_t* GetLoopIntExtent(const ForNode* loop) { return as_const_int(loop->extent); } /*! * \brief Get the extents of a loop * \param loop_sref The loop to be queried - * \return The extents of the loop + * \return The extent of the loop, nullptr if the extent is not constant */ -inline int64_t GetLoopIntExtent(const StmtSRef& loop_sref) { +inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) { const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); - return GetLoopIntExtent(loop); + return as_const_int(loop->extent); } +/*! + * \brief Check if an expression consists of a single variable, + * or a variable plus/minus an constant integer shift + * \param expr The expression to be checked + * \return result Output, the var if it satisfies the condition; otherwise NullOpt + */ +inline Optional AnalyzeVarWithShift(const PrimExpr& expr, Optional* constant) { + if (const auto* var = expr.as()) { + *constant = NullOpt; + return GetRef(var); + } + arith::PVar var; + arith::PVar shift; + // match: "var + shift" + if ((var + shift).Match(expr) || (shift + var).Match(expr)) { + *constant = shift.Eval(); + return var.Eval(); + } + // match: "var - shift" + if ((var - shift).Match(expr)) { + IntImm result = shift.Eval(); + *constant = IntImm(result->dtype, -result->value); + return var.Eval(); + } + return NullOpt; +} + +/******** Annotation ********/ + +/*! + * \brief Get the annotation on a Block/For + * \tparam TObjectRef The type of the annotation value + * \param sref The sref to the block or the for loop + * \param ann_key The annotation key to be looked up + * \return NullOpt if not found; otherwise the annotation value + */ +template +inline Optional GetAnn(const TStmtNode* stmt, const String& ann_key) { + const Map* annotations = &stmt->annotations; + for (const auto& ann : *annotations) { + if (ann.first == ann_key) { + return Downcast(ann.second); + } + } + return NullOpt; +} + +/*! + * \brief Get the annotation on a Block/For + * \tparam TObjectRef The type of the annotation value + * \param sref The sref to the block or the for loop + * \param ann_key The annotation key to be looked up + * \return NullOpt if not found; otherwise the annotation value + */ +template +inline Optional GetAnn(const StmtSRef& sref, const String& ann_key) { + if (const auto* loop = sref->StmtAs()) { + return GetAnn(loop, ann_key); + } else if (const auto* block = sref->StmtAs()) { + return GetAnn(block, ann_key); + } else { + LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + throw; + } +} + +/*! + * \brief Substitute the var in current block scope specified in key->var to be value. + * \param stmt The source stmt to be substituted + * \param value_func The function of new values mapping. + * \return The converted stmt. + */ +Stmt SubstituteInScope(const Stmt& stmt, const std::function& value_func); + +/*! 
+ * \brief Substitute the var in current block scope specified in var map + * \param stmt The source stmt to be substituted + * \param var_map The mapping of var + * \return The converted stmt + */ +Stmt SubstituteInScope(const Stmt& stmt, + const std::unordered_map& var_map); + +/*! + * \param var_map The mapping of var + * \return The converted stmt + */ +Stmt SubstituteInScope(const Stmt& stmt, + const std::unordered_map& var_map); + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc index f7629d1006..ddc2e17569 100644 --- a/src/tir/transforms/convert_blocks_to_opaque.cc +++ b/src/tir/transforms/convert_blocks_to_opaque.cc @@ -80,7 +80,7 @@ class OpaqueBlockConverter : public StmtExprMutator { return std::move(new_realize); } - /*! \brief The map from block vars to thier binding values. */ + /*! \brief The map from block vars to their binding values. */ std::unordered_map var_substitutes_; }; diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 630c00f8c1..aa811b49d7 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -149,22 +149,22 @@ Array RemoveBufferFromBufferRegions(const Array& buf /*! * \brief Substitute a given source buffer with a given target buffer in statements or expressions */ -class BufferReplacer : private StmtExprMutator { +class BufferMutator : private StmtExprMutator { public: static Stmt Run(Buffer src_buffer, Buffer tgt_buffer, Stmt stmt) { - return BufferReplacer(src_buffer, tgt_buffer)(std::move(stmt)); + return BufferMutator(src_buffer, tgt_buffer)(std::move(stmt)); } private: - explicit BufferReplacer(Buffer src_buffer, Buffer tgt_buffer) + explicit BufferMutator(Buffer src_buffer, Buffer tgt_buffer) : src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {} - PrimExpr VisitExpr_(const BufferLoadNode* load) final { + PrimExpr VisitExpr_(const BufferLoadNode* load) override { return load->buffer.same_as(src_buffer_) ? BufferLoad(tgt_buffer_, {0}) : GetRef(load); } - Stmt VisitStmt_(const BufferStoreNode* store) final { + Stmt VisitStmt_(const BufferStoreNode* store) override { if (store->buffer.same_as(src_buffer_)) { PrimExpr value = StmtExprMutator::VisitExpr(store->value); return BufferStore(tgt_buffer_, value, {0}); @@ -287,7 +287,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optionalwrites = {it_buffer_region.value()}; new_block->name_hint = new_block->name_hint + "_in_thread"; new_block->body = - BufferReplacer::Run(wb_buffer, it_buffer.value(), std::move(new_block->body)); + BufferMutator::Run(wb_buffer, it_buffer.value(), std::move(new_block->body)); new_block->init = NullOpt; ObjectPtr n = make_object(*realize); n->block = Block(new_block); diff --git a/tests/python/meta_schedule/tir_tensor_intrin.py b/tests/python/meta_schedule/tir_tensor_intrin.py new file mode 100644 index 0000000000..76f1920c27 --- /dev/null +++ b/tests/python/meta_schedule/tir_tensor_intrin.py @@ -0,0 +1,307 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A collection of TIR tensor intrinsics""" +# pylint: disable=missing-function-docstring +import tvm +from tvm import tir +from tvm.script import tir as T + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks +# fmt: off + +@T.prim_func +def tensorcore_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + vkk = T.axis.R(16, vk + k) + C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + + +@T.prim_func +def tensorcore_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + T.reads([ + C[vi : vi + 16, vj : vj + 16], + A[vi : vi + 16, vk : vk + 16], + B[vj : vj + 16, vk : vk + 16], + ]) + T.writes(C[vi : vi + 16, vj : vj + 16]) + T.evaluate( + T.tvm_mma_sync( + C.data, + C.elem_offset // 256, + A.data, + A.elem_offset // 256, + B.data, + B.elem_offset // 256, + C.data, + C.elem_offset // 256, + dtype="handle", + ) + ) + + +@T.prim_func +def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,)) + B = T.match_buffer(b, (4,)) + C = T.match_buffer(c, (1,)) + + with T.block("root"): + v0 = T.axis.R(4, 0) + for i in range(0, 4): + with T.block("update"): + vi = T.axis.R(4, v0 + i) + C[0] = C[0] + A[vi] * B[vi] + + +@T.prim_func +def dot_product_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,)) + B = T.match_buffer(b, (4,)) + C = T.match_buffer(c, (1,)) + + with T.block("root"): + v0 = T.axis.R(4, 0) + T.reads([C[0 : 1], A[v0 : v0 + 4], B[v0 : v0 + 4]]) + T.writes([C[0 : 1]]) + T.evaluate(T.call_extern( # pylint: disable=redundant-keyword-arg + "vec4add", + C.data, C.elem_offset, + A.data, A.elem_offset, + B.data, B.elem_offset, + dtype="int32", + )) + +@T.prim_func +def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_a") + B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_b") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=1, scope="wmma.accumulator") + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + vkk = T.axis.R(16, vk + k) + C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], "float32") * T.cast(B[vkk, vjj], + "float32") + + +@T.prim_func +def wmma_sync_impl(a: T.handle, b: T.handle, c: 
T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a") + B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, + scope="wmma.accumulator") + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + T.reads([C[vi: vi+16, vj: vj+16], A[vi: vi+16, vk: vk+16], B[vk: vk+16, vj: vj+16]]) + T.writes(C[vi: vi+16, vj: vj+16]) + T.evaluate(T.tvm_mma_sync(C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), + A.data, A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16), + B.data, B.elem_offset // 256 + T.floordiv(T.floormod(B.elem_offset, 256), 16), + C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), + dtype="handle")) + + +@T.prim_func +def wmma_load_a_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, + scope="shared") + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, + scope="wmma.matrix_a") + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + for i, j in T.grid(16, 16): + with T.block("load"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + C[vii, vjj] = A[vii, vjj] + + +@T.prim_func +def wmma_load_a_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0]) + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a") + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + T.reads(A[vi: vi+16, vj: vj+16]) + T.writes(C[vi: vi+16, vj: vj+16]) + T.evaluate(T.tvm_load_matrix_sync( + C.data, 16, 16, 16, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), A.access_ptr("r"), s1, "row_major", + dtype="handle")) + + +@T.prim_func +def wmma_load_b_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared") + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + for i, j in T.grid(16, 16): + with T.block("load"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + C[vii, vjj] = A[vii, vjj] + + +@T.prim_func +def wmma_load_b_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0]) + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + T.reads(A[vi: vi+16, vj: vj+16]) + T.writes(C[vi: vi+16, vj: vj+16]) + T.evaluate(T.tvm_load_matrix_sync( + C.data, 16, 16, 16, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), A.access_ptr("r"), s1, "row_major", + dtype="handle")) + + +@T.prim_func +def wmma_fill_desc(c: T.handle) -> None: + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + for i, j in T.grid(16, 16): + with T.block("init"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + C[vii, vjj] = T.float32(0) + + +@T.prim_func +def 
wmma_fill_impl(c: T.handle) -> None: + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + T.reads([]) + T.writes(C[vi : vi + 16, vj : vj + 16]) + T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), T.float32(0), dtype="handle")) + + +@T.prim_func +def wmma_store_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="global") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + for i, j in T.grid(16, 16): + with T.block("store"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + C[vii, vjj] = A[vii, vjj] + + +@T.prim_func +def wmma_store_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer(a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="global", strides=[s1, s0]) + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + T.reads(A[vi: vi + 16, vj: vj + 16]) + T.writes(C[vi: vi+16, vj: vj+16]) + T.evaluate(T.tvm_store_matrix_sync( + A.data, 16, 16, 16, A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16), C.access_ptr("w"), s1, "row_major", + dtype="handle")) + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks + +TENSORCORE_WMMA = tir.TensorIntrin.register( + "test.tensorcore.wmma", + tensorcore_desc, + tensorcore_impl, +) + +NEON_DOT = tir.TensorIntrin.register( + "test.neon.dot", + dot_product_desc, + dot_product_impl, +) + +WMMA_SYNC = tir.TensorIntrin.register( + "wmma_sync", + wmma_sync_desc, + wmma_sync_impl, +) + +WMMA_LOAD_A = tir.TensorIntrin.register( + "wmma_load_a", + wmma_load_a_desc, + wmma_load_a_impl, +) + +WMMA_LOAD_B = tir.TensorIntrin.register( + "wmma_load_b", + wmma_load_b_desc, + wmma_load_b_impl, +) + +WMMA_FILL = tir.TensorIntrin.register( + "wmma_fill", + wmma_fill_desc, + wmma_fill_impl, +) + +WMMA_FILL = tir.TensorIntrin.register( + "wmma_store", + wmma_store_desc, + wmma_store_impl, +) diff --git a/tests/python/unittest/test_meta_schedule_builder.py b/tests/python/unittest/test_meta_schedule_builder.py index fb3fa135a9..03476ddefa 100644 --- a/tests/python/unittest/test_meta_schedule_builder.py +++ b/tests/python/unittest/test_meta_schedule_builder.py @@ -201,7 +201,7 @@ def test_meta_schedule_error_handle_time_out(): def initializer(): @register_func("meta_schedule.builder.test_time_out") - def timeout_build(mod, target): # pylint: disable=unused-argument, unused-variable + def timeout_build(mod, target, _): # pylint: disable=unused-argument, unused-variable time.sleep(2) builder = LocalBuilder( diff --git a/tests/python/unittest/test_meta_schedule_byoc.py b/tests/python/unittest/test_meta_schedule_byoc.py new file mode 100644 index 0000000000..a420e41a72 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_byoc.py @@ -0,0 +1,196 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test Meta Schedule Builder """ +# pylint: disable=missing-docstring + +import sys + +import pytest +import tvm +from tvm import relay +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder +from tvm.meta_schedule.runner import EvaluatorConfig, LocalRunner, RunnerInput +from tvm.meta_schedule.testing import get_network +from tvm.meta_schedule.testing.byoc_trt import ( + build_relay, + build_relay_with_tensorrt, + run_with_graph_executor, +) +from tvm.relay import testing +from tvm.relay.op.contrib import tensorrt +from tvm.target import Target +from tvm.tir import FloatImm + +has_tensorrt_codegen = pytest.mark.skipif( + not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" +) +has_tensorrt_runtime = pytest.mark.skipif( + not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" +) + +# conv2d+relu network +def get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, +): + + data = relay.var("data", relay.TensorType(data_shape, dtype)) + weight = relay.var("weight") + + net = relay.nn.conv2d( + data=data, + weight=weight, # conv kernel + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + channels=out_channels, + kernel_size=kernel_size, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + net = relay.add(net, net) + net = relay.nn.relu(net) + + inputs = relay.analysis.free_vars(net) + return relay.Function(inputs, net) + + +def verify_meta_schedule_with_tensorrt( + mod, + params, + data_shape, + use_meta_sched: bool = True, + use_trt: bool = True, + mode: str = "vm", +): + if use_meta_sched: + # With meta_schedule + dev = "nvidia/geforce-rtx-2080" + # Build + builder = LocalBuilder( + f_build=build_relay_with_tensorrt if use_trt else build_relay, + timeout_sec=1000, + ) + builder_input = BuilderInput(mod, Target(dev, host="llvm"), params) + builder_result = builder.build([builder_input])[0] + assert builder_result.error_msg is None, builder_result.error_msg + assert builder_result.artifact_path is not None + + # Run + runner_input = RunnerInput( + builder_result.artifact_path, + device_type="cuda", + args_info=[TensorInfo("float32", data_shape)], + ) + runner = LocalRunner( + evaluator_config=EvaluatorConfig( + number=5, + repeat=2, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ), + f_run_evaluator=run_with_graph_executor, + ) + + # Run the module + runner_future = runner.run([runner_input])[0] + runner_result = runner_future.result() + assert runner_result is not None + assert runner_result.error_msg is None, runner_result.error_msg + assert runner_result.run_secs is not None + + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + + else: + # Without 
meta_schedule + if use_trt: + mod, config = tensorrt.partition_for_tensorrt(mod) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + _func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() + else: + with tvm.transform.PassContext(opt_level=3): + _func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params + ).evaluate() + + +def test_conv2d_relu(): + data_shape = (1, 1280, 14, 14) + out_channels = 256 + kernel_size, strides, padding, dilation, groups = (1, 1), (1, 1), (0, 0, 0, 0), (1, 1), 1 + data_layout, kernel_layout = "NCHW", "OIHW" + dtype = "float32" + + f = get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, + ) + + mod, params = testing.create_workload(f) + verify_meta_schedule_with_tensorrt(mod, params, data_shape) + + +@pytest.mark.parametrize( + "model_name", + ["resnet-50", "mobilenet"], +) +@pytest.mark.parametrize("batch_size", [1, 8]) +@pytest.mark.parametrize("use_meta_sched", [True]) +@pytest.mark.parametrize("use_trt", [True, False]) +def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool): + mod, params, input_shape, _oshape = get_network( + name=model_name, + batch_size=batch_size, + ) + verify_meta_schedule_with_tensorrt( + mod, + params, + input_shape, + use_meta_sched=use_meta_sched, + use_trt=use_trt, + mode="vm", + ) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py new file mode 100644 index 0000000000..cdc72d30b6 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -0,0 +1,220 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
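An aside on MakeMultinomialSampler, added to sampling.cc earlier in this patch: it folds the weights into a running prefix sum and binary-searches that table with a uniform draw, clamping the result to the last index. A minimal Python sketch of the same idea, for illustration only (it uses the standard-library RNG rather than TVM's LinearCongruentialEngine, and the helper name is made up):

import bisect
import random
from typing import Callable, List


def make_multinomial_sampler(weights: List[float], seed: int = 0) -> Callable[[], int]:
    """Sample an index i with probability weights[i] / sum(weights)."""
    sums: List[float] = []
    total = 0.0
    for w in weights:  # prefix sums, e.g. [1, 2, 7] -> [1, 3, 10]
        total += w
        sums.append(total)
    rng = random.Random(seed)

    def sample() -> int:
        p = rng.uniform(0.0, total)      # uniform draw over the total weight
        i = bisect.bisect_left(sums, p)  # first prefix sum >= p
        return min(i, len(sums) - 1)     # clamp, like the lower_bound guard in the C++ version

    return sample


sampler = make_multinomial_sampler([1.0, 2.0, 7.0], seed=42)
counts = [0, 0, 0]
for _ in range(10_000):
    counts[sampler()] += 1
# counts[2] should account for roughly 70% of the draws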
+from typing import List + +import tempfile +import os +import re +import sys +import shutil +import pytest +import numpy as np + +import tvm +from tvm.script import tir as T +from tvm.tir.schedule.schedule import Schedule +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor +from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel +from tvm.meta_schedule.tune_context import TuneContext + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,disable=unused-argument + + +def test_meta_schedule_cost_model(): + class FancyCostModel(PyCostModel): + def load(self, path: str) -> None: + pass + + def save(self, path: str) -> None: + pass + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + pass + + def predict( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> np.ndarray: + return np.random.rand(10) + + model = FancyCostModel() + model.save("fancy_test_location") + model.load("fancy_test_location") + model.update(TuneContext(), [], []) + results = model.predict(TuneContext, [MeasureCandidate(Schedule(mod=Matmul), [])]) + assert results.shape == (10,) + + +def test_meta_schedule_cost_model_as_string(): + class NotSoFancyCostModel(PyCostModel): + def load(self, path: str) -> None: + pass + + def save(self, path: str) -> None: + pass + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + pass + + def predict( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> np.ndarray: + return np.random.rand(10) + + cost_model = NotSoFancyCostModel() + pattern = re.compile(r"NotSoFancyCostModel\(0x[a-f|0-9]*\)") + assert pattern.match(str(cost_model)) + + +def test_meta_schedule_random_model(): + model = RandomModel() + model.update(TuneContext(), [], []) + res = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(10)]) + assert len(res) == 10 + assert min(res) >= 0 and max(res) <= model.max_range + + +def test_meta_schedule_random_model_reseed(): + model = RandomModel(seed=100) + res = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(20)]) + new_model = RandomModel(seed=100) + new_res = new_model.predict( + TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(20)] + ) + assert (res == new_res).all() + + +def test_meta_schedule_random_model_reload(): + model = RandomModel(seed=25973) + model.predict( + TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(30)] + ) # change state + path = os.path.join(tempfile.mkdtemp(), "test_output_meta_schedule_random_model.npy") + model.save(path) 
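+    # `save` above snapshots the model's random state; the predictions below advance it,
+    # and `load` restores the snapshot, so `res1` and `res2` are expected to match exactly.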
+ res1 = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(70)]) + model.load(path) + res2 = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(70)]) + shutil.rmtree(os.path.dirname(path)) + assert (res1 == res2).all() + + +def _dummy_candidate(): + return MeasureCandidate(Schedule(Matmul), []) + + +def _dummy_result(num_samples: int = 4, max_run_sec: int = 10): + return RunnerResult(list(np.random.rand(num_samples) * max_run_sec + 1e-6), None) + + +def test_meta_schedule_xgb_model(): + extractor = RandomFeatureExtractor() + model = XGBModel(extractor=extractor, num_warmup_samples=2) + update_sample_count = 10 + predict_sample_count = 100 + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + + +def test_meta_schedule_xgb_model_reload(): + extractor = RandomFeatureExtractor() + model = XGBModel(extractor=extractor, num_warmup_samples=10) + update_sample_count = 20 + predict_sample_count = 30 + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + random_state = model.extractor.random_state # save feature extractor's random state + path = os.path.join(tempfile.mkdtemp(), "test_output_meta_schedule_xgb_model.bin") + cached = (model.cached_features.copy(), model.cached_mean_costs.copy()) + model.save(path) + res1 = model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + model.extractor.random_state = random_state # load feature extractor's random state + model.cached_features = None + model.cached_mean_costs = None + model.load(path) + new_cached = (model.cached_features.copy(), model.cached_mean_costs.copy()) + res2 = model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + shutil.rmtree(os.path.dirname(path)) + assert (res1 == res2).all() + # cached feature does not change + assert len(cached[0]) == len(new_cached[0]) + for i in range(len(cached[0])): + assert (cached[0][i] == new_cached[0][i]).all() + # cached meaen cost does not change + assert (cached[1] == new_cached[1]).all() + + +def test_meta_schedule_xgb_model_reupdate(): + extractor = RandomFeatureExtractor() + model = XGBModel(extractor=extractor, num_warmup_samples=2) + update_sample_count = 60 + predict_sample_count = 100 + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor.py b/tests/python/unittest/test_meta_schedule_feature_extractor.py new file mode 100644 index 0000000000..4f068d7a83 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_feature_extractor.py @@ -0,0 +1,59 
@@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +import re +import numpy as np + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.feature_extractor import PyFeatureExtractor + + +def test_meta_schedule_feature_extractor(): + class FancyFeatureExtractor(PyFeatureExtractor): + def extract_from( + self, + tune_context: TuneContext, # pylint: disable = unused-argument + candidates: List[MeasureCandidate], # pylint: disable = unused-argument + ) -> List[np.ndarray]: + return [np.random.rand(4, 5)] + + extractor = FancyFeatureExtractor() + features = extractor.extract_from(TuneContext(), []) + assert len(features) == 1 + assert features[0].shape == (4, 5) + + +def test_meta_schedule_feature_extractor_as_string(): + class NotSoFancyFeatureExtractor(PyFeatureExtractor): + def extract_from( + self, + tune_context: TuneContext, # pylint: disable = unused-argument + candidates: List[MeasureCandidate], # pylint: disable = unused-argument + ) -> List[np.ndarray]: + return [] + + feature_extractor = NotSoFancyFeatureExtractor() + pattern = re.compile(r"NotSoFancyFeatureExtractor\(0x[a-f|0-9]*\)") + assert pattern.match(str(feature_extractor)) + + +if __name__ == "__main__": + test_meta_schedule_feature_extractor() + test_meta_schedule_feature_extractor_as_string() diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py new file mode 100644 index 0000000000..210bc01499 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py @@ -0,0 +1,1536 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
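# A minimal sketch, assuming only the APIs already exercised in the tests above: a user-defined
# PyFeatureExtractor such as the FancyFeatureExtractor tested earlier can be plugged into
# XGBModel in place of RandomFeatureExtractor. The class name, the (4, 5) feature shape, and the
# import paths below are illustrative and simply mirror those test modules.
from typing import List

import numpy as np

from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.cost_model import XGBModel
from tvm.meta_schedule.feature_extractor import PyFeatureExtractor
from tvm.meta_schedule.search_strategy import MeasureCandidate


class ConstantShapeExtractor(PyFeatureExtractor):
    # Returns one fixed-shape random feature matrix per measure candidate.
    def extract_from(
        self, tune_context: TuneContext, candidates: List[MeasureCandidate]
    ) -> List[np.ndarray]:
        return [np.random.rand(4, 5) for _ in candidates]


model = XGBModel(extractor=ConstantShapeExtractor(), num_warmup_samples=2)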
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import Callable, List + +from numpy.testing import assert_allclose +import tvm +from tvm import meta_schedule as ms, te, tir +from tvm.meta_schedule.testing import te_workload +from tvm.script import tir as T + +N_FEATURES = 164 + + +def _make_context(target) -> ms.TuneContext: + return ms.TuneContext( + target=target, + num_threads=1, + ) + + +def _make_candidate(f_sch: Callable[[], tir.Schedule]) -> ms.MeasureCandidate: + return ms.MeasureCandidate(sch=f_sch(), args_info=[]) + + +def _feature_names( # pylint: disable=invalid-name + buffers_per_store: int = 5, + arith_intensity_curve_num_samples: int = 10, +) -> List[str]: + result = [ + "float_mad", + "float_addsub", + "float_mul", + "float_divmod", + "float_cmp", + "float_mathfunc", + "float_otherfunc", + "int_mad", + "int_addsub", + "int_mul", + "int_divmod", + "int_cmp", + "int_mathfunc", + "int_otherfunc", + "bool_op", + "select_op", + "vec_num", + "vec_prod", + "vec_len", + "vec_type.kPosNone", + "vec_type.kPosInnerSpatial", + "vec_type.kPosMiddleSpatial", + "vec_type.kPosOuterSpatial", + "vec_type.kPosInnerReduce", + "vec_type.kPosMiddleReduce", + "vec_type.kPosOuterReduce", + "vec_type.kPosMixed", + "unroll_num", + "unroll_prod", + "unroll_len", + "unroll_type.kPosNone", + "unroll_type.kPosInnerSpatial", + "unroll_type.kPosMiddleSpatial", + "unroll_type.kPosOuterSpatial", + "unroll_type.kPosInnerReduce", + "unroll_type.kPosMiddleReduce", + "unroll_type.kPosOuterReduce", + "unroll_type.kPosMixed", + "parallel_num", + "parallel_prod", + "parallel_len", + "parallel_type.kPosNone", + "parallel_type.kPosInnerSpatial", + "parallel_type.kPosMiddleSpatial", + "parallel_type.kPosOuterSpatial", + "parallel_type.kPosInnerReduce", + "parallel_type.kPosMiddleReduce", + "parallel_type.kPosOuterReduce", + "parallel_type.kPosMixed", + "is_gpu", + "blockIdx_x_len", + "blockIdx_y_len", + "blockIdx_z_len", + "threadIdx_x_len", + "threadIdx_y_len", + "threadIdx_z_len", + "vthread_len", + ] + for i in range(buffers_per_store): + result.extend( + f"B{i}.{s}" + for s in [ + "acc_type.kRead", + "acc_type.kWrite", + "acc_type.kReadWrite", + "bytes", + "unique_bytes", + "lines", + "unique_lines", + "reuse_type.kLoopMultipleRead", + "reuse_type.kSerialMultipleReadWrite", + "reuse_type.kNoReuse", + "reuse_dis_iter", + "reuse_dis_bytes", + "reuse_ct", + "bytes_d_reuse_ct", + "unique_bytes_d_reuse_ct", + "lines_d_reuse_ct", + "unique_lines_d_reuse_ct", + "stride", + ] + ) + result.extend(f"arith_intensity_curve_{i}" for i in range(arith_intensity_curve_num_samples)) + result.extend( + [ + "alloc_size", + "alloc_prod", + "alloc_outer_prod", + "alloc_inner_prod", + "outer_prod", + "num_loops", + "auto_unroll_max_step", + ] + ) + # 57 + 18 * 5 + 10 + 4 + 3 + assert len(result) == N_FEATURES + return result + + +def _zip_feature(feature, names): + assert feature.ndim == 1 + assert feature.shape[0] == N_FEATURES + assert len(names) == N_FEATURES + return list(zip(names, feature)) + + +def _print_feature(feature, st, ed): # pylint: disable=invalid-name + named_feature = _zip_feature(feature, _feature_names()) + for k, v in named_feature[st:ed]: + print("\t", k, v) + + +def test_cpu_matmul(): + def _create_schedule(): + func = te.create_prim_func(te_workload.matmul(n=512, m=512, k=512)) + sch = tir.Schedule(func, debug_mask="all") + block = sch.get_block("C") + i, j, k = sch.get_loops(block) + i_o, i_i = sch.split(i, factors=[None, 16]) # outer: 32 + j_o, j_i = 
sch.split(j, factors=[None, 8]) # outer: 64 + sch.reorder(i_o, j_o, k, j_i, i_i) + sch.vectorize(j_i) + sch.parallel(i_o) + sch.parallel(j_o) + sch.unroll(k) + return sch + + extractor = ms.feature_extractor.PerStoreFeature() + (feature,) = extractor.extract_from( + _make_context(tvm.target.Target("llvm")), + candidates=[_make_candidate(_create_schedule)], + ) + feature = feature.numpy() + assert feature.shape == (1, N_FEATURES) + f = feature[0] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + # fmt: off + desired=[ + # float math ops + 0, 27, 27, 0, 0, 0, 0, + # int math ops + 0, 29, 29, 0, 0, 0, 0, + # bool/select ops + 0, 0, + ], + # fmt: on + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[1.0, 3.169924, 3.169924, 0, 0, 0, 0, 0, 0, 0, 1], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[1.0, 9.002815, 9.002815, 0, 0, 0, 0, 0, 0, 0, 1], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[1.58496, 11.0007, 6.022368, 0, 0, 0, 0, 0, 0, 0, 1], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer A + assert_allclose( + actual=f[57:75], + desired=[ + 1, + 0, + 0, + 29, + 20, + 27, + 14, + 1, + 0, + 0, + 4.087463, + 7.0552826, + 3.169925, + 26, + 17, + 24, + 11.0007038, + 9.002815, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer C + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 0.0, + 1.0, + 29.0, + 20.000001907348633, + 27.0, + 14.00008773803711, + 1.0, + 0.0, + 0.0, + 7.011227130889893, + 9.250298500061035, + 9.002815246582031, + 20.000001907348633, + 11.000703811645508, + 18.0000057220459, + 5.044394016265869, + 9.002815246582031, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Buffer B + assert_allclose( + actual=f[93:111], + desired=[ + 1.0, + 0.0, + 0.0, + 29.0, + 20.000001907348633, + 19.000001907348633, + 14.00008773803711, + 1.0, + 0.0, + 0.0, + 1.0, + 3.700439691543579, + 4.087462902069092, + 25.0, + 16.000022888183594, + 15.000043869018555, + 10.001408576965332, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[ + 0.7097842693328857, + 0.7408391237258911, + 0.8750449419021606, + 0.9449487924575806, + 1.0148526430130005, + 1.0847564935684204, + 1.113688349723816, + 1.1394684314727783, + 1.2119636535644531, + 1.2971993684768677, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 20.000001907348633, + 18.0000057220459, + 1.0, + 27.0, + 27.0, + 2.5849626064300537, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + + +def test_cpu_fusion(): + # pylint: disable=all + @T.prim_func + def func(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [64, 32], dtype="float32") + B = T.match_buffer(b, [64, 32], dtype="float32") + C = T.match_buffer(c, [64, 32], dtype="float32") + for i, j in T.grid(64, 32): # type: ignore + with T.block(): + T.reads([A[i, j], B[i, j]]) # type: ignore + T.writes([B[i, j], C[i, j]]) # type: ignore + with T.block("B"): + 
T.reads([A[i, j]]) # type: ignore + T.writes([B[i, j]]) # type: ignore + B[i, j] = A[i, j] # type: ignore + with T.block("C"): + T.reads([B[i, j]]) # type: ignore + T.writes([C[i, j]]) # type: ignore + C[i, j] = B[i, j] # type: ignore + + # pylint: enable=all + + def _create_schedule(): + return tir.Schedule(func, debug_mask="all") + + extractor = ms.feature_extractor.PerStoreFeature() + (feature,) = extractor.extract_from( + _make_context(tvm.target.Target("llvm")), + candidates=[_make_candidate(_create_schedule)], + ) + feature = feature.numpy() + assert feature.shape == (2, N_FEATURES) + ## Features for BufferStore(B) + f = feature[0] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + # fmt: off + desired=[0.0] * 16, + # fmt: on + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer A + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 14.00008773803711, + 14.00008773803711, + 8.005624771118164, + 8.005624771118164, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer B + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 1.0, + 0.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 14.00008773803711, + 14.00008773803711, + 8.005624771118164, + 8.005624771118164, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 13.000176, + 11.000703811645508, + 1.0, + 11.000703811645508, + 11.000703811645508, + 1.5849624872207642, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + ## Features for BufferStore(C) + f = feature[1] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + # fmt: off + desired=[0.0] * 16, + # fmt: on + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # 
Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer B + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 0.0, + 1.0, + 0.0, + 1.0, + 4.087462902069092, + 1.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer C + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 1.0, + 0.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 14.00008773803711, + 14.00008773803711, + 8.005624771118164, + 8.005624771118164, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 13.000176429748535, + 11.000703811645508, + 1.0, + 11.000703811645508, + 11.000703811645508, + 1.5849624872207642, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + + +def test_gpu(): + def _create_schedule(): + n = m = k = 512 + func = te.create_prim_func(te_workload.matmul(n=n, m=m, k=k)) + sch = tir.Schedule(func, debug_mask="all") + c = sch.get_block("C") + c_local = sch.cache_write(c, 0, "local") + i, j, k = sch.get_loops(c) + # pylint: disable=invalid-name + i0, i1, i2, i3, i4 = sch.split(i, factors=[None, 1, 16, 32, 1]) # outer: 1 + j0, j1, j2, j3, j4 = sch.split(j, factors=[None, 4, 1, 1, 16]) # outer: 8 + k0, k1, k2 = sch.split(k, factors=[None, 1, 2]) # outer: 256 + # pylint: enable=invalid-name + # fmt: off + sch.reorder( + i0, j0, # S + i1, j1, # S + i2, j2, # S + k0, # R + k1, # R + i3, j3, # S + k2, # R + i4, j4, # S + ) + # fmt: on + # thread binding + i0_j0 = sch.fuse(i0, j0) + i1_j1 = sch.fuse(i1, j1) + i2_j2 = sch.fuse(i2, j2) + sch.bind(i0_j0, "blockIdx.x") + sch.bind(i1_j1, "vthread.x") + sch.bind(i2_j2, "threadIdx.x") + # fusion + sch.reverse_compute_at(c_local, i2_j2) + # cache read 'A' + a_shared = sch.cache_read(c, 1, "shared") + sch.compute_at(a_shared, k0) + _, _, _, _, a_i, a_j = sch.get_loops(a_shared) + a_ij = sch.fuse(a_i, a_j) + _, a_j = sch.split(a_ij, factors=[None, 16]) # outer: 64 + sch.bind(a_j, "threadIdx.x") + # cache read 'B' + b_shared = sch.cache_read(c, 2, "shared") + sch.compute_at(b_shared, k0) + _, _, _, _, b_i, b_j = sch.get_loops(b_shared) + b_ij = sch.fuse(b_i, b_j) + _, b_j = sch.split(b_ij, factors=[None, 16]) # outer: 8 + sch.bind(b_j, "threadIdx.x") + # auto unroll + sch.annotate(i0_j0, "pragma_auto_unroll_max_step", tir.IntImm("int32", 1024)) + sch.annotate(i0_j0, "pragma_unroll_explicit", tir.IntImm("int32", 1)) + return sch + + extractor = ms.feature_extractor.PerStoreFeature() + (feature,) = extractor.extract_from( + _make_context(tvm.target.Target("cuda")), + candidates=[_make_candidate(_create_schedule)], + ) + feature = feature.numpy() + assert feature.shape == (4, 
N_FEATURES) + ### Check feature[0]: BufferStore(A_shared) <= A[...] + f = feature[0] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 24.000000085991324, + 24.000000085991324, + 24.000000085991324, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer A + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 25.000000042995662, + 20.000001375860553, + 23.00000017198264, + 14.000088052430122, + 1.0, + 0.0, + 0.0, + 18.00000550343433, + 20.00562591970089, + 2.321928094887362, + 23.00000017198264, + 18.00000550343433, + 21.000000687930438, + 12.0003521774803, + 12.0003521774803, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer A.shared + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 1.0, + 0.0, + 25.000000042995662, + 12.0003521774803, + 23.00000017198264, + 9.002815015607053, + 1.0, + 0.0, + 0.0, + 6.022367813028454, + 11.98049663618346, + 8.005624549193879, + 17.000011006847668, + 4.087462841250339, + 15.000044026886828, + 1.584962500721156, + 4.087462841250339, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 12.0003521774803, + 27.000000010748916, + 17.000011006847668, + 6.022367813028454, + 23.00000017198264, + 2.584962500721156, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + ### Check feature[1]: BufferStore(B_shared) <= B[...] 
+ f = feature[1] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 22.00000034396526, + 22.00000034396526, + 21.000000687930438, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer B + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 22.00000034396526, + 20.000001375860553, + 20.000001375860553, + 14.000088052430122, + 1.0, + 0.0, + 0.0, + 15.000044026886828, + 20.17555076886471, + 2.321928094887362, + 20.000001375860553, + 18.00000550343433, + 18.00000550343433, + 12.0003521774803, + 4.087462841250339, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer B.shared + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 1.0, + 0.0, + 22.00000034396526, + 9.002815015607053, + 20.000001375860553, + 3.169925001442312, + 1.0, + 0.0, + 0.0, + 3.169925001442312, + 10.001408194392809, + 8.005624549193879, + 14.000088052430122, + 1.584962500721156, + 12.0003521774803, + 0.044394119358453436, + 4.087462841250339, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 9.002815015607053, + 24.000000085991324, + 17.000011006847668, + 3.169925001442312, + 20.000001375860553, + 2.584962500721156, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + ### Check feature[2]: BufferStore(C_local) <= C_local[...] + A_shared[...] * B_shared[...] 
+ f = feature[2] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + desired=[ + 0.0, + 27.000000010748916, + 27.000000010748916, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 28.000000005374456, + 28.000000005374456, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer B.shared + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 29.00000000268723, + 9.002815015607053, + 23.00000017198264, + 3.169925001442312, + 1.0, + 0.0, + 0.0, + 5.044394119358453, + 7.651051691178929, + 5.044394119358453, + 24.000000085991324, + 4.087462841250339, + 18.00000550343433, + 0.32192809488736235, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer C.local + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 0.0, + 1.0, + 29.00000000268723, + 11.000704269011246, + 23.00000017198264, + 5.044394119358453, + 1.0, + 0.0, + 0.0, + 4.087462841250339, + 7.05528243550119, + 1.584962500721156, + 28.000000005374456, + 10.001408194392809, + 22.00000034396526, + 4.087462841250339, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Buffer A.shared + assert_allclose( + actual=f[93:111], + desired=[ + 1.0, + 0.0, + 0.0, + 29.00000000268723, + 12.0003521774803, + 19.00000275171979, + 9.002815015607053, + 1.0, + 0.0, + 0.0, + 1.0, + 3.700439718141092, + 4.087462841250339, + 25.000000042995662, + 8.005624549193879, + 15.000044026886828, + 5.044394119358453, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[ + 0.7097842504665767, + 0.7548801745187567, + 0.8775907547541741, + 0.9957389916154509, + 1.2446737395193135, + 1.493608487423176, + 1.7093103019954263, + 1.8031580276850985, + 1.9841832691827785, + 2.204648076869754, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 11.000704269011246, + 18.00000550343433, + 9.002815015607053, + 18.00000550343433, + 27.000000010748916, + 3.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + ### Check feature[3]: BufferStore(C) <= C_local[...] 
+ f = feature[3] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer C + assert_allclose( + actual=f[57:75], + desired=[ + 0.0, + 1.0, + 0.0, + 20.000001375860553, + 20.000001375860553, + 14.000088052430122, + 14.000088052430122, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 21.000000687930438, + 21.000000687930438, + 15.000044026886828, + 15.000044026886828, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer C.local + assert_allclose( + actual=f[75:93], + desired=[ + 1.0, + 0.0, + 0.0, + 20.000001375860553, + 11.000704269011246, + 14.000088052430122, + 5.044394119358453, + 1.0, + 0.0, + 0.0, + 9.002815015607053, + 12.0003521774803, + 4.087462841250339, + 16.00002201361136, + 7.011227255423254, + 10.001408194392809, + 1.584962500721156, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 20.000001375860553, + 18.00000550343433, + 1.0, + 18.00000550343433, + 18.00000550343433, + 2.584962500721156, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + + +if __name__ == "__main__": + test_cpu_matmul() + test_cpu_fusion() + test_gpu() diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py new file mode 100644 index 0000000000..b36d6ca7cf --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_measure_callback.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import re +from typing import List + +import pytest +import tvm +from tvm.ir.base import assert_structural_equal +from tvm.meta_schedule.builder import BuilderResult +from tvm.meta_schedule.measure_callback import PyMeasureCallback +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.task_scheduler.task_scheduler import TaskScheduler +from tvm.meta_schedule.utils import _get_hex_address +from tvm.script import tir as T +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_meta_schedule_measure_callback(): + class FancyMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: TaskScheduler, + task_id: int, + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> None: + assert len(measure_candidates) == 1 + assert_structural_equal(measure_candidates[0].sch.mod, Matmul) + assert ( + len(builds) == 1 + and builds[0].error_msg is None + and builds[0].artifact_path == "test_build" + ) + assert ( + len(results) == 1 and results[0].error_msg is None and len(results[0].run_secs) == 2 + ) + + measure_callback = FancyMeasureCallback() + measure_callback.apply( + TaskScheduler(), + 0, + [MeasureCandidate(Schedule(Matmul), None)], + [BuilderResult("test_build", None)], + [RunnerResult([1.0, 2.1], None)], + ) + + +def test_meta_schedule_measure_callback_fail(): + class FailingMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: TaskScheduler, + task_id: int, + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> None: + raise ValueError("test") + + measure_callback = FailingMeasureCallback() + with pytest.raises(ValueError, match="test"): + measure_callback.apply( + TaskScheduler(), + 0, + [MeasureCandidate(Schedule(Matmul), None)], + [BuilderResult("test_build", None)], + [RunnerResult([1.0, 2.1], None)], + ) + + +def test_meta_schedule_measure_callback_as_string(): + class NotSoFancyMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: "TaskScheduler", + task_id: int, + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> None: + pass + + def __str__(self) -> str: + return 
f"NotSoFancyMeasureCallback({_get_hex_address(self.handle)})" + + measure_callback = NotSoFancyMeasureCallback() + pattern = re.compile(r"NotSoFancyMeasureCallback\(0x[a-f|0-9]*\)") + assert pattern.match(str(measure_callback)) + + +if __name__ == "__main__": + test_meta_schedule_measure_callback() + test_meta_schedule_measure_callback_fail() + test_meta_schedule_measure_callback_as_string() diff --git a/tests/python/unittest/test_meta_schedule_mutator.py b/tests/python/unittest/test_meta_schedule_mutator.py new file mode 100644 index 0000000000..b4d94dc9a8 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from typing import List, Optional + +import re + +import tvm +from tvm.ir.base import assert_structural_equal +from tvm.script import tir as T + +from tvm.meta_schedule.mutator import PyMutator +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.utils import _get_hex_address +from tvm.tir.schedule import Schedule, Trace + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_meta_schedule_mutator(): + class FancyMutator(PyMutator): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + return Trace(trace.insts, {}) + + mutator = FancyMutator() + sch = Schedule(Matmul) + res = mutator.apply(sch.trace) + assert res is not None + new_sch = sch.copy() + res.apply_to_schedule(new_sch, remove_postproc=True) + assert_structural_equal(sch.mod, new_sch.mod) + + +def test_meta_schedule_mutator_as_string(): + class YetAnotherFancyMutator(PyMutator): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + pass + + def __str__(self) -> str: + return f"YetAnotherFancyMutator({_get_hex_address(self.handle)})" + + mutator = YetAnotherFancyMutator() + pattern = re.compile(r"YetAnotherFancyMutator\(0x[a-f|0-9]*\)") + assert 
pattern.match(str(mutator)) + + +if __name__ == "__main__": + test_meta_schedule_mutator() + test_meta_schedule_mutator_as_string() diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py new file mode 100644 index 0000000000..685cfe1017 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator import MutateParallel, Mutator +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule + +# pylint: disable=invalid-name, no-member + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [512, 512]) + B = T.match_buffer(b, [512, 512]) + C = T.match_buffer(c, [512, 512]) + for i, j, k in T.grid(512, 512, 512): # type: ignore + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore + with T.init(): + C[vi, vj] = 0.0 # type: ignore + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=invalid-name, no-member + + +def _sch(decisions: List[List[int]], ann_val: int) -> Schedule: + sch = Schedule(matmul, debug_mask="all") + # pylint: disable=invalid-name + d0, d1, d2 = decisions + b0 = sch.get_block(name="C", func_name="main") + root = sch.get_block(name="root", func_name="main") + sch.get_consumers(block=b0) + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, + n=4, + max_innermost_factor=64, + decision=d0, + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8]) + v13, v14, v15, v16 = sch.sample_perfect_tile( + loop=l3, + n=4, + max_innermost_factor=64, + decision=d1, + ) + l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16]) + v21, v22 = sch.sample_perfect_tile( + loop=l4, + n=2, + max_innermost_factor=64, + decision=d2, + ) + l23, l24 = sch.split(loop=l4, factors=[v21, v22]) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=1) + sch.annotate(block_or_loop=root, ann_key="meta_schedule.parallel", ann_val=ann_val) + # pylint: enable=invalid-name + return sch + + +def _make_mutator(target: Target, max_jobs_per_core: int) -> Mutator: + mutator = MutateParallel(max_jobs_per_core) + mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target)) + return mutator + + +def test_mutate_parallel_matmul(): + mutator = _make_mutator( + 
target=Target("llvm --num-cores=16"), + max_jobs_per_core=256, + ) + sch = _sch( + decisions=[ + [4, 32, 4, 1], + [8, 4, 8, 2], + [512, 1], + ], + ann_val=64, + ) + results = set() + for _ in range(100): + trace = mutator.apply(sch.trace) + ann_val = int(trace.insts[-1].inputs[1]) + results.add(ann_val) + if len(results) == 3: + break + assert len(results) == 3 + assert results == {4, 32, 4096} + + +if __name__ == """__main__""": + test_mutate_parallel_matmul() diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py new file mode 100644 index 0000000000..d30ec8bb99 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import math +from typing import List + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator import MutateTileSize, Mutator +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule + +# pylint: disable=invalid-name, no-member + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [512, 512]) + B = T.match_buffer(b, [512, 512]) + C = T.match_buffer(c, [512, 512]) + for i, j, k in T.grid(512, 512, 512): # type: ignore + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore + with T.init(): + C[vi, vj] = 0.0 # type: ignore + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=invalid-name, no-member + + +def _sch(decisions: List[List[int]]) -> Schedule: + sch = Schedule(matmul, debug_mask="all") + # pylint: disable=invalid-name + (d0,) = decisions + b0 = sch.get_block(name="C", func_name="main") + sch.get_consumers(block=b0) + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, + n=4, + max_innermost_factor=64, + decision=d0, + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8]) + l17, l18, l19, l20 = sch.split(loop=l3, factors=[8, 4, 8, 2]) + l23, l24 = sch.split(loop=l4, factors=[512, 1]) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=1) + # pylint: enable=invalid-name + return sch + + +def _make_mutator(target: Target) -> Mutator: + mutator = MutateTileSize() + mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target)) + return mutator + + +def test_mutate_tile_size_matmul(): + mutator = _make_mutator( + target=Target("llvm --num-cores=16"), + ) + results = {} + sch = 
_sch(decisions=[[4, 32, 4, 1]]) + for _ in range(100): + trace = mutator.apply(sch.trace) + assert trace.insts[4].kind.name == "SamplePerfectTile" + decision = trace.decisions[trace.insts[4]] + decision = [int(x) for x in decision] + results[str(decision)] = decision + assert math.prod(decision) == 512 + assert len(results) > 15 + + +if __name__ == """__main__""": + test_mutate_tile_size_matmul() diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py new file mode 100644 index 0000000000..d51b70c78f --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator import MutateUnroll, Mutator +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule + +# pylint: disable=invalid-name, no-member + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [512, 512]) + B = T.match_buffer(b, [512, 512]) + C = T.match_buffer(c, [512, 512]) + for i, j, k in T.grid(512, 512, 512): # type: ignore + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore + with T.init(): + C[vi, vj] = 0.0 # type: ignore + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=invalid-name, no-member + + +def _sch(decisions: List[List[int]]) -> Schedule: + sch = Schedule(matmul, debug_mask="all") + # pylint: disable=invalid-name + d0, d1, d2 = decisions + b0 = sch.get_block(name="C", func_name="main") + root = sch.get_block(name="root", func_name="main") + sch.get_consumers(block=b0) + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, + n=4, + max_innermost_factor=64, + decision=d0, + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8]) + v13, v14, v15, v16 = sch.sample_perfect_tile( + loop=l3, + n=4, + max_innermost_factor=64, + decision=d1, + ) + l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16]) + v21, v22 = sch.sample_perfect_tile( + loop=l4, + n=2, + max_innermost_factor=64, + decision=d2, + ) + l23, l24 = sch.split(loop=l4, factors=[v21, v22]) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=1) + v57 = sch.sample_categorical( + candidates=[0, 16, 64, 512], + probs=[0.25, 0.25, 0.25, 0.25], + decision=0, + ) + sch.annotate(block_or_loop=root, 
ann_key="meta_schedule.unroll_explicit", ann_val=v57) + # pylint: enable=invalid-name + return sch + + +def _make_mutator(target: Target) -> Mutator: + mutator = MutateUnroll() + mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target)) + return mutator + + +def test_mutate_unroll_matmul(): + mutator = _make_mutator(target=Target("llvm --num-cores=16")) + sch = _sch( + decisions=[ + [4, 32, 4, 1], + [8, 4, 8, 2], + [512, 1], + ], + ) + results = set() + for _ in range(100): + trace = mutator.apply(sch.trace) + decision = trace.decisions[trace.insts[-2]] + results.add(decision) + if len(results) == 3: + break + assert len(results) == 3 + assert results == {1, 2, 3} + + +if __name__ == """__main__""": + test_mutate_unroll_matmul() diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py new file mode 100644 index 0000000000..b78e67817e --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -0,0 +1,342 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import math +import sys +from typing import List + +import pytest +import tvm +from tvm.error import TVMError +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.schedule_rule import PyScheduleRule +from tvm.meta_schedule.space_generator import PostOrderApply +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir.schedule import BlockRV, Schedule + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class DuplicateMatmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class TrinityMatmul: + @T.prim_func + def main(a: T.handle, d: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.alloc_buffer((1024, 1024), "float32") + C = T.alloc_buffer((1024, 1024), "float32") + D = T.match_buffer(d, (1024, 1024), "float32") + for i, j in T.grid(1024, 1024): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(1024, 1024): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 3.0 + for i, j in T.grid(1024, 1024): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = C[vi, vj] * 5.0 + + +@tvm.script.ir_module +class TrinityMatmulProcessedForReference: + @T.prim_func + def main(a: T.handle, d: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024], dtype="float32") + D = T.match_buffer(d, [1024, 1024], dtype="float32") + # body + # with tir.block("root") + B = T.alloc_buffer([1024, 1024], dtype="float32") + for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16): + with T.block("A"): + vi = T.axis.S(1024, i0_0 * 64 + i0_1) + vj = T.axis.S(1024, i1_0 * 16 + i1_1) + T.reads([A[vi, vj]]) + T.writes([B[vi, vj]]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16): + with T.block("C"): + vi = T.axis.S(1024, i0_0 * 64 + i0_1) + vj = T.axis.S(1024, i1_0 * 16 + i1_1) + T.reads([B[vi, vj]]) + T.writes([D[vi, vj]]) + D[vi, vj] = (B[vi, vj] + T.float32(3)) * T.float32(5) + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _is_root(sch: Schedule, block: BlockRV) -> bool: + return 
sch.get_sref(block).parent is None + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +class WowSoFancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + if _is_root(sch, block): + return [sch] + new_sch = sch.copy() + i, j, k = new_sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = new_sch.split(loop=k, factors=[32, 32]) + new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + return [new_sch] + + +class DoubleScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + if _is_root(sch, block): + return [sch] + new_sch = sch.copy() + i, j, k = new_sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[4, 64, 2, 2]) + j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[2, 4, 64, 2]) + k_0, k_1 = new_sch.split(loop=k, factors=[32, 32]) + new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + result = [new_sch] + new_sch = sch.copy() + i, j, k = new_sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[4, 64, 2, 2]) + j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[2, 4, 64, 2]) + k_0, k_1 = new_sch.split(loop=k, factors=[32, 32]) + new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + result.append(new_sch) + return result + + +class ReorderScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + if _is_root(sch, block): + return [sch] + new_sch = sch.copy() + i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3 = new_sch.get_loops(block=block) + new_sch.reorder(i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3, i_0, j_0) + result = [new_sch] + new_sch = sch.copy() + i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3 = new_sch.get_loops(block=block) + new_sch.reorder(i_1, j_3, i_0, j_0, j_1, k_0, i_2, j_2, k_1, i_3) + result.append(new_sch) + return result + + +def test_meta_schedule_post_order_apply(): + mod = Matmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Test Task", + sch_rules=[WowSoFancyScheduleRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 1 + assert not tvm.ir.structural_equal(schs[0].mod, mod) + _check_correct(schs[0]) + + +def test_meta_schedule_post_order_apply_double(): + mod = Matmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Double Rules Task", + sch_rules=[DoubleScheduleRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 2 + for sch in schs: + assert not tvm.ir.structural_equal(sch.mod, mod) + _check_correct(sch) + + +def test_meta_schedule_post_order_apply_multiple(): + mod = Matmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Double Rules Task", + sch_rules=[DoubleScheduleRule(), ReorderScheduleRule()], + ) + post_order_apply = 
PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 4 + for sch in schs: + assert not tvm.ir.structural_equal(sch.mod, mod) + _check_correct(sch) + + +def test_meta_schedule_post_order_apply_duplicate_matmul(): + mod = DuplicateMatmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Duplicate Matmul Task", + sch_rules=[WowSoFancyScheduleRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + with pytest.raises( + TVMError, + match=r".*TVMError: Check failed: \(block_names_.count\(block->name_hint\) == 0\)" + r" is false: Duplicated block name matmul in function main not supported!", + ): + post_order_apply.generate_design_space(mod) + + +def test_meta_schedule_post_order_apply_remove_block(): + class TrinityDouble(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + if _is_root(sch, block): + return [sch] + new_sch = sch.copy() + i, j = new_sch.get_loops(block=block) + i_0, i_1 = new_sch.split(loop=i, factors=[16, 64]) + j_0, j_1 = new_sch.split(loop=j, factors=[64, 16]) + new_sch.reorder(i_0, j_0, i_1, j_1) + result = [new_sch] + new_sch = sch.copy() + i, j = new_sch.get_loops(block=block) + i_0, i_1 = new_sch.split(loop=i, factors=[2, 512]) + j_0, j_1 = new_sch.split(loop=j, factors=[2, 512]) + new_sch.reorder(i_0, j_0, i_1, j_1) + result.append(new_sch) + return result + + class RemoveBlock(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + if _is_root(sch, block): + return [sch] + sch = sch.copy() + if sch.get(block).name_hint == "B": + sch.compute_inline(block) + return [sch] + + def correct_trace(a, b, c, d): + return "\n".join( + [ + 'b0 = sch.get_block(name="A", func_name="main")', + 'b1 = sch.get_block(name="B", func_name="main")', + 'b2 = sch.get_block(name="C", func_name="main")', + "sch.compute_inline(block=b1)", + "l3, l4 = sch.get_loops(block=b2)", + "l5, l6 = sch.split(loop=l3, factors=" + str(a) + ")", + "l7, l8 = sch.split(loop=l4, factors=" + str(b) + ")", + "sch.reorder(l5, l7, l6, l8)", + "l9, l10 = sch.get_loops(block=b0)", + "l11, l12 = sch.split(loop=l9, factors=" + str(c) + ")", + "l13, l14 = sch.split(loop=l10, factors=" + str(d) + ")", + "sch.reorder(l11, l13, l12, l14)", + ] + ) + + mod = TrinityMatmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Remove Block Task", + sch_rules=[RemoveBlock(), TrinityDouble()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 4 + for sch in schs: + with pytest.raises( + tvm.tir.schedule.schedule.ScheduleError, + match="ScheduleError: An error occurred in the schedule primitive 'get-block'.", + ): + sch.get_block("B", "main") + sch_trace = sch.trace.simplified(True) + assert ( + str(sch_trace) == correct_trace([16, 64], [64, 16], [2, 512], [2, 512]) + or str(sch_trace) == correct_trace([2, 512], [2, 512], [2, 512], [2, 512]) + or str(sch_trace) == correct_trace([16, 64], [64, 16], [16, 64], [64, 16]) + or str(sch_trace) == correct_trace([2, 512], [2, 512], [16, 64], [64, 16]) + ) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff 
--git a/tests/python/unittest/test_meta_schedule_postproc.py b/tests/python/unittest/test_meta_schedule_postproc.py new file mode 100644 index 0000000000..6e17e7bac3 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import math +import re + +import tvm +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import PyPostproc +from tvm.meta_schedule.utils import _get_hex_address +from tvm.script import tir as T +from tvm.target.target import Target +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +def schedule_matmul(sch: Schedule): + block = sch.get_block("matmul") + i, j, k = sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + + +def test_meta_schedule_postproc(): + class FancyPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule) -> bool: + schedule_matmul(sch) + return True + + postproc = FancyPostproc() + mod = Matmul + sch = Schedule(mod) + assert postproc.apply(sch) + try: + tvm.ir.assert_structural_equal(sch.mod, mod) + raise Exception("The postprocessors did not change the schedule.") + except ValueError: + _check_correct(sch) + + +def test_meta_schedule_postproc_fail(): + class FailingPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule) -> bool: + return False + + postproc = FailingPostproc() + sch = Schedule(Matmul) + assert not postproc.apply(sch) + + 
+def test_meta_schedule_postproc_as_string(): + class NotSoFancyPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule) -> bool: + pass + + def __str__(self) -> str: + return f"NotSoFancyPostproc({_get_hex_address(self.handle)})" + + postproc = NotSoFancyPostproc() + pattern = re.compile(r"NotSoFancyPostproc\(0x[a-f|0-9]*\)") + assert pattern.match(str(postproc)) + + +if __name__ == "__main__": + test_meta_schedule_postproc() + test_meta_schedule_postproc_fail() + test_meta_schedule_postproc_as_string() diff --git a/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py b/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py new file mode 100644 index 0000000000..d27e3e6108 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import DisallowDynamicLoop +from tvm.script import tir as T +from tvm.target import Target + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + DisallowDynamicLoop(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class DynamicLoop: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j in T.grid(1024, 1024): + for k in T.serial(0, i): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: 
enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_postproc_disallow_dynamic_loops(): + mod = Matmul + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert ctx.postprocs[0].apply(sch) + + +def test_postproc_disallow_dynamic_loops_fail(): + mod = DynamicLoop + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +if __name__ == "__main__": + test_postproc_disallow_dynamic_loops() + test_postproc_disallow_dynamic_loops_fail() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py new file mode 100644 index 0000000000..ec40c592a8 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import RewriteCooperativeFetch +from tvm.meta_schedule.testing import te_workload +from tvm.script import tir as T +from tvm.target import Target +from tvm.te import create_prim_func + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + RewriteCooperativeFetch(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class AfterRewrite: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i2_0 in T.serial(0, 1): + for ax0_ax1_fused_0 in T.serial(0, 32768): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, 
(ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512) + v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 1024): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(0, 2): + with T.block("B_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":2}) + B_shared[v0, v1] = B[v0, v1] + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): + with T.block("C"): + i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) + j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) + k = T.axis.reduce(512, i2_1 * 32 + i2_2) + T.reads([C_local[i, j], A_shared[i, k], B_shared[k, j]]) + T.writes([C_local[i, j]]) + with T.init(): + C_local[i, j] = T.float32(0) + C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] + for ax0, ax1 in T.grid(32, 4): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def test_rewrite_cooperative_fetch(): + mod = create_prim_func(te_workload.matmul(n=512, m=512, k=512)) + target = _target() + ctx = _create_context(mod, target) + + sch = tir.Schedule(mod, debug_mask="all") + # fmt: off + # pylint: disable=line-too-long,invalid-name + b0 = sch.get_block(name="C", func_name="main") + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 16, 1, 2, 16]) + l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9]) + v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 8, 2, 2]) + l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19]) + v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64, decision=[1, 16, 32]) + l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27]) + sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24) + l31 = sch.fuse(l10, l20) + sch.bind(loop=l31, thread_axis="blockIdx.x") + l32 = sch.fuse(l11, l21) + sch.bind(loop=l32, thread_axis="vthread.x") + l33 = sch.fuse(l12, l22) + sch.bind(loop=l33, thread_axis="threadIdx.x") + b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared") + sch.compute_at(block=b34, loop=l28, preserve_unit_loops=1) + _, _, _, _, l39, l40 = sch.get_loops(block=b34) + l41 = sch.fuse(l39, l40) + _, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4, decision=[262144, 1]) + sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43) + b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared") + sch.compute_at(block=b44, loop=l28, 
preserve_unit_loops=1) + _, _, _, _, l49, l50 = sch.get_loops(block=b44) + l51 = sch.fuse(l49, l50) + _, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4, decision=[8192, 2]) + sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53) + sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=1) + # pylint: enable=line-too-long,invalid-name + # fmt: on + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, AfterRewrite) + + +if __name__ == "__main__": + test_rewrite_cooperative_fetch() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py new file mode 100644 index 0000000000..b7f8f507d3 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.script import tir as T + +from tvm.meta_schedule.postproc import RewriteParallelVectorizeUnroll +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant +# fmt: off + +@tvm.script.ir_module +class Move_PUV: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32") + B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32") + # body + with T.block("root"): + T.block_attr({"meta_schedule.parallel":128, "meta_schedule.vectorize":32}) + for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): + with T.block("move"): + vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(1024, k0 * 32 + k1) + T.where((i0 * 4 + i1) * 4 + i2 < 1024 and (j0 * 4 + j1) * 8 + j2 < 1024 and k0 * 32 + k1 < 1024) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] + + + +@T.prim_func +def Move_PUV0(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32") + B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32") + # body + with T.block("root"): + for i0_j0_fused in T.parallel(0, 8192): + for i1, j1, k0, i2, j2 in T.grid(4, 4, 64, 4, 8): + for k1_fused in T.vectorized(0, 32): + with T.block("move"): + vi = T.axis.spatial(1024, i0_j0_fused // 64 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(1024, i0_j0_fused % 64 * 32 + j1 * 
8 + j2) + vk = T.axis.spatial(1024, k0 * 32 + k1_fused) + T.where( + (i0_j0_fused // 64 % 128 * 4 + i1) * 4 + i2 < 1024 + and (i0_j0_fused % 64 * 4 + j1) * 8 + j2 < 1024 + and k0 * 32 + k1_fused % 32 < 1024 + ) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable + + +def test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize(): + postproc = RewriteParallelVectorizeUnroll() + sch = Schedule(Move_PUV) + assert postproc.apply(sch) + tvm.ir.assert_structural_equal(sch.mod["main"], Move_PUV0) + + +if __name__ == "__main__": + test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py new file mode 100644 index 0000000000..93ea76ec5d --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py @@ -0,0 +1,172 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import RewriteReductionBlock +from tvm.script import tir as T +from tvm.target import Target + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + RewriteReductionBlock(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Before: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i2_0 in T.serial(0, 1): + for ax0_ax1_fused_0 in T.serial(0, 32768): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512) + v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 1024): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(0, 2): + with T.block("B_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":2}) + B_shared[v0, v1] = B[v0, v1] + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): + with T.block("C"): + i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) + j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) + k = T.axis.reduce(512, i2_1 * 32 + i2_2) + T.reads([C_local[i, j], A_shared[i, k], B_shared[k, j]]) + T.writes([C_local[i, j]]) + with T.init(): + C_local[i, j] = T.float32(0) + C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] + for ax0, ax1 in T.grid(32, 4): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + + +@tvm.script.ir_module +class After: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + C = T.match_buffer(var_C, 
[512, 512], dtype="float32") + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i2_0 in T.serial(0, 1): + for ax0_ax1_fused_0 in T.serial(0, 32768): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512) + v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 1024): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(0, 2): + with T.block("B_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":2}) + B_shared[v0, v1] = B[v0, v1] + for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(2, 2, 16, 2): + with T.block("C_init"): + i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3_init * 16 + i0_4_init) + j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3_init * 2 + i1_4_init) + T.reads([]) + T.writes([C_local[i, j]]) + C_local[i, j] = T.float32(0) + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): + with T.block("C_update"): + i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) + j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) + k = T.axis.reduce(512, i2_1 * 32 + i2_2) + T.reads([C_local[i, j], A_shared[i, k], B_shared[k, j]]) + T.writes([C_local[i, j]]) + C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] + for ax0, ax1 in T.grid(32, 4): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def test_rewrite_reduction_block(): + mod = Before + target = _target() + ctx = _create_context(mod, target) + sch = tir.Schedule(mod, debug_mask="all") + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, After) + + +if __name__ == "__main__": + test_rewrite_reduction_block() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py new file mode 100644 index 0000000000..8b062a11b5 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+
+import tvm
+from tvm import tir
+from tvm.meta_schedule import TuneContext
+from tvm.meta_schedule.postproc import RewriteUnboundBlock
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def _target() -> Target:
+    return Target("cuda", host="llvm")
+
+
+def _create_context(mod, target) -> TuneContext:
+    ctx = TuneContext(
+        mod=mod,
+        target=target,
+        postprocs=[
+            RewriteUnboundBlock(),
+        ],
+        task_name="test",
+    )
+    for rule in ctx.postprocs:
+        rule.initialize_with_tune_context(ctx)
+    return ctx
+
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
+
+
+@tvm.script.ir_module
+class Before:
+    @T.prim_func
+    def main(var_A: T.handle, var_B: T.handle) -> None:
+        A = T.match_buffer(var_A, [512, 512], dtype="float32")
+        B = T.match_buffer(var_B, [512, 512], dtype="float32")
+        for i, j in T.grid(512, 512):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] + 1.0
+
+
+@tvm.script.ir_module
+class After:
+    @T.prim_func
+    def main(var_A: T.handle, var_B: T.handle) -> None:
+        A = T.match_buffer(var_A, [512, 512], dtype="float32")
+        B = T.match_buffer(var_B, [512, 512], dtype="float32")
+        for i_j_fused_0 in T.thread_binding(0, 8192, thread="blockIdx.x"):
+            for i_j_fused_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+                with T.block("C"):
+                    vi = T.axis.spatial(512, (i_j_fused_0 * 32 + i_j_fused_1) // 512)
+                    vj = T.axis.spatial(512, (i_j_fused_0 * 32 + i_j_fused_1) % 512)
+                    B[vi, vj] = A[vi, vj] + 1.0
+
+
+# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
+# fmt: on
+
+
+def test_rewrite_unbound_block():
+    mod = Before
+    target = _target()
+    ctx = _create_context(mod, target)
+    sch = tir.Schedule(mod, debug_mask="all")
+    sch.enter_postproc()
+    assert ctx.postprocs[0].apply(sch)
+    tvm.ir.assert_structural_equal(sch.mod, After)
+
+
+if __name__ == "__main__":
+    test_rewrite_unbound_block()
diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
new file mode 100644
index 0000000000..cdebcddf5d
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
@@ -0,0 +1,232 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import VerifyGPUCode +from tvm.script import tir as T +from tvm.target import Target + + +def _target() -> Target: + return Target("nvidia/geforce-rtx-3080") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + VerifyGPUCode(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant +# fmt: off + +@tvm.script.ir_module +class Conv2dCuda0: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@tvm.script.ir_module +class Conv2dCuda1: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: 
+ # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([6400000], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@tvm.script.ir_module +class Conv2dCuda2: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512000], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 
65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@tvm.script.ir_module +class Conv2dCuda3: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 800000) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant + + +def test_postproc_verify_gpu_0(): + mod = Conv2dCuda0 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_1(): + mod = Conv2dCuda1 + 
ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_2(): + mod = Conv2dCuda2 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_3(): + mod = Conv2dCuda3 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +if __name__ == "__main__": + test_postproc_verify_gpu_0() + test_postproc_verify_gpu_1() + test_postproc_verify_gpu_2() + test_postproc_verify_gpu_3() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule.py b/tests/python/unittest/test_meta_schedule_schedule_rule.py new file mode 100644 index 0000000000..1d34d94bfe --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import math +import re +from typing import List + +import tvm +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.schedule_rule import PyScheduleRule +from tvm.script import tir as T +from tvm.tir.schedule import BlockRV, Schedule + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +def test_meta_schedule_schedule_rule(): + class FancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: TuneContext) -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + i, j, k = sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + return [sch] + + sch_rule = FancyScheduleRule() + mod = Matmul + sch = 
Schedule(mod) + res = sch_rule.apply(sch, block=sch.get_block("matmul")) + assert len(res) == 1 + try: + tvm.ir.assert_structural_equal(mod, res[0].mod) + raise Exception("The schedule rule did not change the schedule.") + except ValueError: + _check_correct(res[0]) + + +def test_meta_schedule_schedule_rule_as_string(): + class YetStillSomeFancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: TuneContext) -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + pass + + sch_rule = YetStillSomeFancyScheduleRule() + pattern = re.compile(r"YetStillSomeFancyScheduleRule\(0x[a-f|0-9]*\)") + assert pattern.match(str(sch_rule)) + + +if __name__ == "__main__": + test_meta_schedule_schedule_rule() + test_meta_schedule_schedule_rule_as_string() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py new file mode 100644 index 0000000000..0273a2bdf6 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -0,0 +1,305 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import ( + auto_inline, + auto_inline_after_tiling, +) +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T +from tvm.target import Target + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Conv2DBiasBnReLU: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bias_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bn_mul = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bn_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([X[i0_1, i1_1, i2_1 - 1, i3_1 - 1]]) + T.writes([pad_temp[i0_1, i1_1, i2_1, i3_1]]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3): + with T.block("compute"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads([compute_1[nn, ff, yy, xx], pad_temp[nn, rc, yy + ry, xx + rx], W[ff, rc, ry, rx]]) + T.writes([compute_1[nn, ff, yy, xx]]) + with T.init(): + compute_1[nn, ff, yy, xx] = T.float32(0) + compute_1[nn, ff, yy, xx] = compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bias_add"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([compute_1[i, j, k, l], B[j, 0, 0]]) + T.writes([bias_add[i, j, k, l]]) + bias_add[i, j, k, l] = compute_1[i, j, k, l] + B[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bn_mul"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([bias_add[i, j, k, l], bn_scale[j, 0, 0]]) + T.writes([bn_mul[i, j, k, l]]) + bn_mul[i, j, k, l] = bias_add[i, j, k, l] * bn_scale[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bn_add"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([bn_mul[i, j, k, l], bn_offset[j, 0, 0]]) + T.writes([bn_add[i, j, k, l]]) + bn_add[i, j, k, l] = bn_mul[i, j, k, l] + bn_offset[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([bn_add[i0_2, i1_2, i2_2, i3_2]]) + T.writes([compute[i0_2, i1_2, i2_2, i3_2]]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max(bn_add[i0_2, i1_2, i2_2, i3_2], T.float32(0)) + + 
+@tvm.script.ir_module +class Conv2DBiasBnReLUInlined: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([X[i0_1, i1_1, i2_1 - 1, i3_1 - 1]]) + T.writes([pad_temp[i0_1, i1_1, i2_1, i3_1]]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3): + with T.block("compute"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads([compute_1[nn, ff, yy, xx], pad_temp[nn, rc, yy + ry, xx + rx], W[ff, rc, ry, rx]]) + T.writes([compute_1[nn, ff, yy, xx]]) + with T.init(): + compute_1[nn, ff, yy, xx] = T.float32(0) + compute_1[nn, ff, yy, xx] = compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([compute_1[i0_2, i1_2, i2_2, i3_2], B[i1_2, 0, 0], bn_scale[i1_2, 0, 0], bn_offset[i1_2, 0, 0]]) + T.writes([compute[i0_2, i1_2, i2_2, i3_2]]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max((compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) * bn_scale[i1_2, 0, 0] + bn_offset[i1_2, 0, 0], T.float32(0)) + + +@tvm.script.ir_module +class NeedsInlinePaddingAndEpilogue: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local") + pad_temp_shared = T.alloc_buffer([1, 512, 58, 58], dtype="float32", scope="shared") + W_shared = T.alloc_buffer([512, 512, 3, 3], dtype="float32", scope="shared") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([X[i0_1, i1_1, i2_1 - 1, i3_1 - 1]]) + T.writes([pad_temp[i0_1, i1_1, i2_1, i3_1]]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0_0_i1_0_i2_0_i3_0_fused in 
T.thread_binding(0, 224, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(0, 2, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i4_0, i5_0, i6_0 in T.grid(1, 3, 1): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 40960, annotations={"meta_schedule.cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 3): + with T.block("pad_temp_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 // 8 % 512) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i5_0 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 % 8) + v3 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) % 30) + T.reads([pad_temp[v0, v1, v2, v3]]) + T.writes([pad_temp_shared[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":0}) + pad_temp_shared[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 12288, annotations={"meta_schedule.cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 4): + with T.block("W_shared"): + v0 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 1536) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 3 % 512) + v2 = T.axis.spatial(3, i5_0) + v3 = T.axis.spatial(3, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) % 3) + T.reads([W[v0, v1, v2, v3]]) + T.writes([W_shared[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":0}) + W_shared[v0, v1, v2, v3] = W[v0, v1, v2, v3] + for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28): + with T.block("compute"): + nn = T.axis.spatial(1, 0) + ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4) + rc = T.axis.reduce(512, i4_1 * 16 + i4_2) + ry, rx = T.axis.remap("RR", [i5_0, i6_2]) + T.reads([compute_local[nn, ff, yy, xx], pad_temp_shared[nn, rc, yy + ry, xx + rx], W_shared[ff, rc, ry, rx]]) + T.writes([compute_local[nn, ff, yy, xx]]) + with T.init(): + compute_local[nn, ff, yy, xx] = T.float32(0) + compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + pad_temp_shared[nn, rc, yy + ry, xx + rx] * W_shared[ff, rc, ry, rx] + for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28): + with T.block("compute_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2) + v3 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3) + T.block_attr({"meta_schedule.cache_type":1}) + T.reads([compute_local[v0, v1, v2, v3]]) + T.writes([compute_1[v0, v1, v2, v3]]) + compute_1[v0, v1, v2, v3] = compute_local[v0, v1, v2, v3] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([compute_1[i0_2, i1_2, i2_2, i3_2], B[i1_2, 0, 0], bn_scale[i1_2, 0, 0], bn_offset[i1_2, 0, 0]]) + 
T.writes([compute[i0_2, i1_2, i2_2, i3_2]]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max((compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) * bn_scale[i1_2, 0, 0] + bn_offset[i1_2, 0, 0], T.float32(0)) + + +@tvm.script.ir_module +class PaddingAndEpilogueInlined: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local") + pad_temp_shared = T.alloc_buffer([1, 512, 58, 58], dtype="float32", scope="shared") + W_shared = T.alloc_buffer([512, 512, 3, 3], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(0, 224, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(0, 2, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i4_0, i5_0, i6_0 in T.grid(1, 3, 1): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 40960, annotations={"meta_schedule.cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 3): + with T.block("pad_temp_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 // 8 % 512) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i5_0 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 % 8) + v3 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) % 30) + T.reads([X[v0, v1, v2 - 1, v3 - 1]]) + T.writes([pad_temp_shared[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":0}) + pad_temp_shared[v0, v1, v2, v3] = T.if_then_else(v2 >= 1 and v2 < 57 and v3 >= 1 and v3 < 57, X[v0, v1, v2 - 1, v3 - 1], T.float32(0), dtype="float32") + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 12288, annotations={"meta_schedule.cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 4): + with T.block("W_shared"): + v0 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 1536) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 3 % 512) + v2 = T.axis.spatial(3, i5_0) + v3 = T.axis.spatial(3, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) % 3) + T.reads([W[v0, v1, v2, v3]]) + T.writes([W_shared[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":0}) + W_shared[v0, v1, v2, v3] = W[v0, v1, v2, v3] + for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28): + with T.block("compute"): + nn = T.axis.spatial(1, 0) + ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4) + rc = T.axis.reduce(512, i4_1 * 16 + i4_2) + ry, rx = T.axis.remap("RR", [i5_0, i6_2]) + 
T.reads([compute_local[nn, ff, yy, xx], pad_temp_shared[nn, rc, yy + ry, xx + rx], W_shared[ff, rc, ry, rx]]) + T.writes([compute_local[nn, ff, yy, xx]]) + with T.init(): + compute_local[nn, ff, yy, xx] = T.float32(0) + compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + pad_temp_shared[nn, rc, yy + ry, xx + rx] * W_shared[ff, rc, ry, rx] + for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28): + with T.block("compute_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2) + v3 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3) + T.reads([compute_local[v0, v1, v2, v3], B[v1, 0, 0], bn_scale[v1, 0, 0], bn_offset[v1, 0, 0]]) + T.writes([compute[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":1}) + compute[v0, v1, v2, v3] = T.max((compute_local[v0, v1, v2, v3] + B[v1, 0, 0]) * bn_scale[v1, 0, 0] + bn_offset[v1, 0, 0], T.float32(0)) + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def _create_context(mod, target, rule): + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_inline_consumer_chain(): + mod = Conv2DBiasBnReLU + target = Target("llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=auto_inline(target=target), + ) + (space,) = ctx.space_generator.generate_design_space(mod=mod) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=Conv2DBiasBnReLUInlined) + + +def test_inline_into_cache(): + mod = NeedsInlinePaddingAndEpilogue + target = Target("cuda", host="llvm") + ctx = _create_context( + mod=NeedsInlinePaddingAndEpilogue, + target=target, + rule=auto_inline_after_tiling(target=target), + ) + (space,) = ctx.space_generator.generate_design_space(mod=mod) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=PaddingAndEpilogueInlined) + + +if __name__ == "__main__": + test_inline_consumer_chain() + test_inline_into_cache() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py new file mode 100644 index 0000000000..240e4eb86f --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -0,0 +1,268 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import multi_level_tiling +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.te import create_prim_func +from tvm.meta_schedule.testing import te_workload +from tvm.target import Target + + +def _create_context(mod, target, rule) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_cpu_matmul(): + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l17, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + ], + ] + target = Target("llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_cpu_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "b1, = sch.get_consumers(block=b0)", + "l2, l3, l4 = 
sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "b1, = sch.get_consumers(block=b0)", + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l17, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + ], + ] + # pylint: enable=line-too-long + target = Target("llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_cuda_matmul(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])", + "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])", + "v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64)", + "l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])", + "sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)", + "l31 = sch.fuse(l10, l20)", + 'sch.bind(loop=l31, thread_axis="blockIdx.x")', + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="vthread.x")', + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="threadIdx.x")', + 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=1)", + 
"l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", + "l41 = sch.fuse(l39, l40)", + "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)', + 'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b44, loop=l28, preserve_unit_loops=1)", + "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)", + "l51 = sch.fuse(l49, l50)", + "v52, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)', + "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=1)", + ] + ] + # pylint: enable=line-too-long + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_cuda_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])", + "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])", + "v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64)", + "l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])", + "sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)", + "l31 = sch.fuse(l10, l20)", + 'sch.bind(loop=l31, thread_axis="blockIdx.x")', + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="vthread.x")', + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="threadIdx.x")', + 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=1)", + "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", + "l41 = sch.fuse(l39, l40)", + "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)', + 'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b44, loop=l28, preserve_unit_loops=1)", + "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)", + "l51 = sch.fuse(l49, l50)", + "v52, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)', + "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=1)", + ] + ] + # pylint: enable=line-too-long + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_cpu_matmul() + test_cpu_matmul_relu() + test_cuda_matmul() + 
test_cuda_matmul_relu() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py new file mode 100644 index 0000000000..e57799f604 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import parallel_vectorize_unroll +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T +from tvm.target import Target + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class ParallelizeVectorizeUnroll: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + with T.block("root"): + T.reads([]) + T.writes([]) + T.block_attr({"meta_schedule.parallel": 128, "meta_schedule.vectorize": 16, "meta_schedule.unroll_explicit": 2}) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def _create_context(mod, target, rule): + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_parallel_vectorize_unroll(): + expected = [ + [ + 'b0 = sch.get_block(name="root", 
func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.parallel", ann_val=512)', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.vectorize", ann_val=32)', + "v1 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.unroll_explicit", ann_val=v1)', + ] + ] + mod = Matmul + target = Target("llvm --num-cores=32") + ctx = _create_context( + mod=mod, + target=target, + rule=parallel_vectorize_unroll(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_parallel_vectorize_unroll() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py new file mode 100644 index 0000000000..9f1c8d7842 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import random_compute_location +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T +from tvm.target import Target + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Add: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [2048, 2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048, 2048], dtype="float32") + A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32") + # body + for i, j, k in T.grid(2048, 2048, 2048): + with T.block("move"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + T.reads([A[vi, vj, vk]]) + T.writes([A_cached[vi, vj, vk]]) + A_cached[vi, vj, vk] = A[vi, vj, vk] + for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): + with T.block("add"): + vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(2048, k0 * 32 + k1) + T.reads([A_cached[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A_cached[vi, vj, vk] + T.float32(1) + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def _create_context(mod, target, rule): + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_random_compute_location(): + expected = [ + [ + 'b0 = sch.get_block(name="move", func_name="main")', + "b1, = sch.get_consumers(block=b0)", + "l2 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)", + ] + ] + mod = Add + target = Target("llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=random_compute_location(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_random_compute_location() diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 9b3ddfd7c7..4e51d497d0 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -16,17 +16,28 @@ # under the License. 
""" Test Meta Schedule SearchStrategy """ # pylint: disable=missing-function-docstring -from typing import List +from typing import List, Tuple, Union, Optional import sys - +import numpy as np import pytest import tvm + +from tvm.ir import IRModule from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator.mutator import PyMutator from tvm.meta_schedule.runner import RunnerResult from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.search_strategy import ReplayTrace +from tvm.meta_schedule.search_strategy import ( + SearchStrategy, + ReplayTrace, + ReplayFunc, + EvolutionarySearch, + MeasureCandidate, +) +from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord +from tvm.meta_schedule.cost_model import PyCostModel from tvm.script import tir as T from tvm.tir.schedule import Schedule, Trace @@ -34,7 +45,7 @@ MATMUL_M = 32 -# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking +# pylint: disable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking # fmt: off @tvm.script.ir_module @@ -53,48 +64,202 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on -# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# pylint: enable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument -def _is_trace_equal(sch_1: Schedule, sch_2: Schedule) -> bool: - trace_1 = Trace(sch_1.trace.insts, {}) - trace_2 = Trace(sch_2.trace.insts, {}) +def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions=True) -> bool: + if remove_decisions: + trace_1 = Trace(sch_1.trace.insts, {}) + trace_2 = Trace(sch_2.trace.insts, {}) + else: + trace_1 = sch_1.trace + trace_2 = sch_2.trace return str(trace_1) == str(trace_2) def _schedule_matmul(sch: Schedule): block = sch.get_block("matmul") i, j, k = sch.get_loops(block=block) - # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming - i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) - j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) - k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + i_0, i_1, i_2, i_3 = sch.split(i, sch.sample_perfect_tile(i, n=4)) + j_0, j_1, j_2, j_3 = sch.split(j, sch.sample_perfect_tile(j, n=4)) + k_0, k_1 = sch.split(k, sch.sample_perfect_tile(k, n=2)) sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) -def test_meta_schedule_replay_trace(): +@pytest.mark.parametrize("TestClass", [ReplayFunc, ReplayTrace]) +def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disable = invalid-name num_trials_per_iter = 7 num_trials_total = 20 - (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) - replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) - tune_context = TuneContext(mod=Matmul) - replay.initialize_with_tune_context(tune_context) - - num_trials_each_round: List[int] = [] - replay.pre_tuning([example_sch]) - while True: - candidates = replay.generate_measure_candidates() - if candidates is None: - break - num_trials_each_round.append(len(candidates)) + strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) + tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) + 
tune_context.space_generator.initialize_with_tune_context(tune_context) + spaces = tune_context.space_generator.generate_design_space(tune_context.mod) + + strategy.initialize_with_tune_context(tune_context) + strategy.pre_tuning(spaces) + (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + num_trials_each_iter: List[int] = [] + candidates = strategy.generate_measure_candidates() + while candidates is not None: + num_trials_each_iter.append(len(candidates)) + runner_results: List[RunnerResult] = [] + for candidate in candidates: + _is_trace_equal( + candidate.sch, + correct_sch, + remove_decisions=(isinstance(strategy, ReplayTrace)), + ) + runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) + strategy.notify_runner_results(runner_results) + candidates = strategy.generate_measure_candidates() + strategy.post_tuning() + assert num_trials_each_iter == [7, 7, 6] + + +def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name + class DummyMutator(PyMutator): + """Dummy Mutator for testing""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + return Trace(trace.insts, {}) + + class DummyDatabase(PyDatabase): + """Dummy Database for testing""" + + def __init__(self): + super().__init__() + self.records = [] + self.workload_reg = [] + + def has_workload(self, mod: IRModule) -> bool: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True + return False + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.records.append(record) + + def commit_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return workload + workload = Workload(mod) + self.workload_reg.append(workload) + return workload + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + return list( + filter( + lambda x: x.workload == workload, + sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), + ) + )[: int(top_k)] + + def __len__(self) -> int: + return len(self.records) + + def print_results(self) -> None: + print("\n".join([str(r) for r in self.records])) + + class RandomModel(PyCostModel): + """Random cost model for testing""" + + random_state: Union[Tuple[str, np.ndarray, int, int, float], dict] + path: Optional[str] + + def __init__( + self, + *, + seed: Optional[int] = None, + path: Optional[str] = None, + max_range: Optional[int] = 100, + ): + super().__init__() + if path is not None: + self.load(path) + else: + np.random.seed(seed) + self.random_state = np.random.get_state() + self.max_range = max_range + + def load(self, file_location: str) -> None: + self.random_state = tuple(np.load(file_location, allow_pickle=True)) + + def save(self, file_location: str) -> None: + np.save(file_location, np.array(self.random_state, dtype=object), allow_pickle=True) + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + pass + + def predict( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> np.ndarray: + np.random.set_state(self.random_state) + result = np.random.rand(len(candidates)) * self.max_range + self.random_state = np.random.get_state() + return result + + num_trials_per_iter = 10 + num_trials_total = 100 + + mutator = DummyMutator() + database = DummyDatabase() + 
cost_model = RandomModel()
+    strategy = EvolutionarySearch(
+        num_trials_per_iter=num_trials_per_iter,
+        num_trials_total=num_trials_total,
+        population=5,
+        init_measured_ratio=0.1,
+        genetic_algo_iters=3,
+        max_evolve_fail_cnt=10,
+        p_mutate=0.5,
+        eps_greedy=0.9,
+        database=database,
+        cost_model=cost_model,
+    )
+    tune_context = TuneContext(
+        mod=Matmul,
+        space_generator=ScheduleFn(sch_fn=_schedule_matmul),
+        mutator_probs={mutator: 1.0},
+        target=tvm.target.Target("llvm"),
+        num_threads=1,  # because we are using a mutator from the python side
+    )
+    tune_context.space_generator.initialize_with_tune_context(tune_context)
+    spaces = tune_context.space_generator.generate_design_space(tune_context.mod)
+
+    strategy.initialize_with_tune_context(tune_context)
+    strategy.pre_tuning(spaces)
+    (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
+    num_trials_each_iter: List[int] = []
+    candidates = strategy.generate_measure_candidates()
+    while candidates is not None:
+        num_trials_each_iter.append(len(candidates))
         runner_results: List[RunnerResult] = []
         for candidate in candidates:
-            assert _is_trace_equal(candidate.sch, example_sch)
-            runner_results.append(RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None))
-        replay.notify_runner_results(runner_results)
-    replay.post_tuning()
-    assert num_trials_each_round == [7, 7, 6]
+            _is_trace_equal(
+                candidate.sch,
+                correct_sch,
+                remove_decisions=(isinstance(strategy, ReplayTrace)),
+            )
+            runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
+        strategy.notify_runner_results(runner_results)
+        candidates = strategy.generate_measure_candidates()
+    strategy.post_tuning()
+    print(num_trials_each_iter)
+    correct_count = 10  # For each iteration except the last one
+    assert num_trials_each_iter == [correct_count] * (num_trials_total // correct_count) + (
+        [num_trials_total % correct_count] if num_trials_total % correct_count != 0 else []
+    )
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_meta_schedule_sketch_cpu.py b/tests/python/unittest/test_meta_schedule_sketch_cpu.py
new file mode 100644
index 0000000000..065ccf75b3
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_sketch_cpu.py
@@ -0,0 +1,430 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.space_generation import check_trace, create_context +from tvm.target import Target +from tvm.te import create_prim_func + + +def _target() -> Target: + return Target("llvm --num-cores=16") + + +def test_meta_schedule_cpu_sketch_matmul(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v25 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v25)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'b2 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9])", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17])", + "v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l24, l25 = sch.split(loop=l5, factors=[v22, v23])", + "sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)", + "sch.reverse_compute_at(block=b2, loop=l18, preserve_unit_loops=1)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'b2 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9])", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17])", + "v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l24, l25 = sch.split(loop=l5, factors=[v22, v23])", + "sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)", + "sch.reverse_compute_at(block=b2, loop=l19, preserve_unit_loops=1)", + 
'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v25 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v25)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "b2, = sch.get_consumers(block=b0)", + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9])", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17])", + "v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l24, l25 = sch.split(loop=l5, factors=[v22, v23])", + "sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)", + "sch.reverse_compute_at(block=b2, loop=l18, preserve_unit_loops=1)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "b2, = sch.get_consumers(block=b0)", + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9])", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17])", + "v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l24, l25 = sch.split(loop=l5, factors=[v22, 
v23])", + "sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)", + "sch.reverse_compute_at(block=b2, loop=l19, preserve_unit_loops=1)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_conv2d_nchw(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="compute", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "l2, l3, l4, l5, l6, l7, l8 = sch.get_loops(block=b0)", + "v9, v10, v11, v12 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l13, l14, l15, l16 = sch.split(loop=l2, factors=[v9, v10, v11, v12])", + "v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l21, l22, l23, l24 = sch.split(loop=l3, factors=[v17, v18, v19, v20])", + "v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l29, l30, l31, l32 = sch.split(loop=l4, factors=[v25, v26, v27, v28])", + "v33, v34, v35, v36 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l37, l38, l39, l40 = sch.split(loop=l5, factors=[v33, v34, v35, v36])", + "v41, v42 = sch.sample_perfect_tile(loop=l6, n=2, max_innermost_factor=64)", + "l43, l44 = sch.split(loop=l6, factors=[v41, v42])", + "v45, v46 = sch.sample_perfect_tile(loop=l7, n=2, max_innermost_factor=64)", + "l47, l48 = sch.split(loop=l7, factors=[v45, v46])", + "v49, v50 = sch.sample_perfect_tile(loop=l8, n=2, max_innermost_factor=64)", + "l51, l52 = sch.split(loop=l8, factors=[v49, v50])", + "sch.reorder(l13, l21, l29, l37, l14, l22, l30, l38, l43, l47, l51, l15, l23, l31, l39, l44, l48, l52, l16, l24, l32, l40)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v53 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v53)', + ], + [ + 'b0 = sch.get_block(name="compute", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'b2 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l3, l4, l5, l6, l7, l8, l9 = sch.get_loops(block=b0)", + "v10, v11, v12, v13 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l14, l15, l16, l17 = sch.split(loop=l3, factors=[v10, v11, v12, v13])", + "v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l22, l23, l24, l25 = sch.split(loop=l4, factors=[v18, v19, v20, v21])", + "v26, v27, v28, v29 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l30, l31, l32, l33 = sch.split(loop=l5, factors=[v26, v27, v28, v29])", + "v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l38, l39, l40, l41 = sch.split(loop=l6, factors=[v34, v35, v36, v37])", + "v42, v43 = 
sch.sample_perfect_tile(loop=l7, n=2, max_innermost_factor=64)", + "l44, l45 = sch.split(loop=l7, factors=[v42, v43])", + "v46, v47 = sch.sample_perfect_tile(loop=l8, n=2, max_innermost_factor=64)", + "l48, l49 = sch.split(loop=l8, factors=[v46, v47])", + "v50, v51 = sch.sample_perfect_tile(loop=l9, n=2, max_innermost_factor=64)", + "l52, l53 = sch.split(loop=l9, factors=[v50, v51])", + "sch.reorder(l14, l22, l30, l38, l15, l23, l31, l39, l44, l48, l52, l16, l24, l32, l40, l45, l49, l53, l17, l25, l33, l41)", + "sch.reverse_compute_at(block=b2, loop=l38, preserve_unit_loops=1)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v54 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v54)', + ], + [ + 'b0 = sch.get_block(name="compute", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'b2 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l3, l4, l5, l6, l7, l8, l9 = sch.get_loops(block=b0)", + "v10, v11, v12, v13 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l14, l15, l16, l17 = sch.split(loop=l3, factors=[v10, v11, v12, v13])", + "v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l22, l23, l24, l25 = sch.split(loop=l4, factors=[v18, v19, v20, v21])", + "v26, v27, v28, v29 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l30, l31, l32, l33 = sch.split(loop=l5, factors=[v26, v27, v28, v29])", + "v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l38, l39, l40, l41 = sch.split(loop=l6, factors=[v34, v35, v36, v37])", + "v42, v43 = sch.sample_perfect_tile(loop=l7, n=2, max_innermost_factor=64)", + "l44, l45 = sch.split(loop=l7, factors=[v42, v43])", + "v46, v47 = sch.sample_perfect_tile(loop=l8, n=2, max_innermost_factor=64)", + "l48, l49 = sch.split(loop=l8, factors=[v46, v47])", + "v50, v51 = sch.sample_perfect_tile(loop=l9, n=2, max_innermost_factor=64)", + "l52, l53 = sch.split(loop=l9, factors=[v50, v51])", + "sch.reorder(l14, l22, l30, l38, l15, l23, l31, l39, l44, l48, l52, l16, l24, l32, l40, l45, l49, l53, l17, l25, l33, l41)", + "sch.reverse_compute_at(block=b2, loop=l39, preserve_unit_loops=1)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v54 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v54)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.conv2d_nchw( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="compute", func_name="main")', + 'b1 = sch.get_block(name="bias_add", func_name="main")', + 'b2 = sch.get_block(name="bn_mul", func_name="main")', + 'b3 = sch.get_block(name="bn_add", func_name="main")', + 'b4 = 
sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b3)", + "sch.compute_inline(block=b2)", + "sch.compute_inline(block=b1)", + "l5, l6, l7, l8, l9, l10, l11 = sch.get_loops(block=b0)", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l5, factors=[v12, v13, v14, v15])", + "v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l24, l25, l26, l27 = sch.split(loop=l6, factors=[v20, v21, v22, v23])", + "v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l32, l33, l34, l35 = sch.split(loop=l7, factors=[v28, v29, v30, v31])", + "v36, v37, v38, v39 = sch.sample_perfect_tile(loop=l8, n=4, max_innermost_factor=64)", + "l40, l41, l42, l43 = sch.split(loop=l8, factors=[v36, v37, v38, v39])", + "v44, v45 = sch.sample_perfect_tile(loop=l9, n=2, max_innermost_factor=64)", + "l46, l47 = sch.split(loop=l9, factors=[v44, v45])", + "v48, v49 = sch.sample_perfect_tile(loop=l10, n=2, max_innermost_factor=64)", + "l50, l51 = sch.split(loop=l10, factors=[v48, v49])", + "v52, v53 = sch.sample_perfect_tile(loop=l11, n=2, max_innermost_factor=64)", + "l54, l55 = sch.split(loop=l11, factors=[v52, v53])", + "sch.reorder(l16, l24, l32, l40, l17, l25, l33, l41, l46, l50, l54, l18, l26, l34, l42, l47, l51, l55, l19, l27, l35, l43)", + 'sch.annotate(block_or_loop=b4, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b4, ann_key="meta_schedule.vectorize", ann_val=32)', + "v56 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v56)', + ], + [ + 'b0 = sch.get_block(name="compute", func_name="main")', + 'b1 = sch.get_block(name="bias_add", func_name="main")', + 'b2 = sch.get_block(name="bn_mul", func_name="main")', + 'b3 = sch.get_block(name="bn_add", func_name="main")', + 'b4 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b3)", + "sch.compute_inline(block=b2)", + "sch.compute_inline(block=b1)", + "b5, = sch.get_consumers(block=b0)", + "l6, l7, l8, l9, l10, l11, l12 = sch.get_loops(block=b0)", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l6, factors=[v13, v14, v15, v16])", + "v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l25, l26, l27, l28 = sch.split(loop=l7, factors=[v21, v22, v23, v24])", + "v29, v30, v31, v32 = sch.sample_perfect_tile(loop=l8, n=4, max_innermost_factor=64)", + "l33, l34, l35, l36 = sch.split(loop=l8, factors=[v29, v30, v31, v32])", + "v37, v38, v39, v40 = sch.sample_perfect_tile(loop=l9, n=4, max_innermost_factor=64)", + "l41, l42, l43, l44 = sch.split(loop=l9, factors=[v37, v38, v39, v40])", + "v45, v46 = sch.sample_perfect_tile(loop=l10, n=2, max_innermost_factor=64)", + "l47, l48 = sch.split(loop=l10, factors=[v45, v46])", + "v49, v50 = sch.sample_perfect_tile(loop=l11, n=2, max_innermost_factor=64)", + "l51, l52 = sch.split(loop=l11, factors=[v49, v50])", + "v53, v54 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l55, l56 = sch.split(loop=l12, factors=[v53, v54])", + "sch.reorder(l17, l25, l33, l41, l18, l26, l34, l42, l47, l51, l55, l19, l27, l35, l43, l48, l52, l56, l20, l28, l36, l44)", + "sch.reverse_compute_at(block=b5, loop=l41, preserve_unit_loops=1)", + 'sch.annotate(block_or_loop=b4, 
ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b4, ann_key="meta_schedule.vectorize", ann_val=32)', + "v57 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v57)', + ], + [ + 'b0 = sch.get_block(name="compute", func_name="main")', + 'b1 = sch.get_block(name="bias_add", func_name="main")', + 'b2 = sch.get_block(name="bn_mul", func_name="main")', + 'b3 = sch.get_block(name="bn_add", func_name="main")', + 'b4 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b3)", + "sch.compute_inline(block=b2)", + "sch.compute_inline(block=b1)", + "b5, = sch.get_consumers(block=b0)", + "l6, l7, l8, l9, l10, l11, l12 = sch.get_loops(block=b0)", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l6, factors=[v13, v14, v15, v16])", + "v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l25, l26, l27, l28 = sch.split(loop=l7, factors=[v21, v22, v23, v24])", + "v29, v30, v31, v32 = sch.sample_perfect_tile(loop=l8, n=4, max_innermost_factor=64)", + "l33, l34, l35, l36 = sch.split(loop=l8, factors=[v29, v30, v31, v32])", + "v37, v38, v39, v40 = sch.sample_perfect_tile(loop=l9, n=4, max_innermost_factor=64)", + "l41, l42, l43, l44 = sch.split(loop=l9, factors=[v37, v38, v39, v40])", + "v45, v46 = sch.sample_perfect_tile(loop=l10, n=2, max_innermost_factor=64)", + "l47, l48 = sch.split(loop=l10, factors=[v45, v46])", + "v49, v50 = sch.sample_perfect_tile(loop=l11, n=2, max_innermost_factor=64)", + "l51, l52 = sch.split(loop=l11, factors=[v49, v50])", + "v53, v54 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l55, l56 = sch.split(loop=l12, factors=[v53, v54])", + "sch.reorder(l17, l25, l33, l41, l18, l26, l34, l42, l47, l51, l55, l19, l27, l35, l43, l48, l52, l56, l20, l28, l36, l44)", + "sch.reverse_compute_at(block=b5, loop=l42, preserve_unit_loops=1)", + 'sch.annotate(block_or_loop=b4, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b4, ann_key="meta_schedule.vectorize", ann_val=32)', + "v57 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v57)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.conv2d_nchw_bias_bn_relu( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_sketch_cpu_max_pool2d_nchw(): + # pylint: disable=line-too-long + expected: List[List[str]] = [ + [ + 'b0 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.vectorize", ann_val=32)', + "v1 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.unroll_explicit", ann_val=v1)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.max_pool2d_nchw( + n=1, + h=56, + w=56, + ci=512, + padding=1, + ) + ), + target=_target(), + ) + spaces = 
ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_meta_schedule_cpu_sketch_matmul() + test_meta_schedule_cpu_sketch_matmul_relu() + test_meta_schedule_cpu_sketch_conv2d_nchw() + test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu() + test_meta_schedule_sketch_cpu_max_pool2d_nchw() diff --git a/tests/python/unittest/test_meta_schedule_sketch_cuda.py b/tests/python/unittest/test_meta_schedule_sketch_cuda.py new file mode 100644 index 0000000000..c9e645a778 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_sketch_cuda.py @@ -0,0 +1,292 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.space_generation import check_trace, create_context +from tvm.target import Target +from tvm.te import create_prim_func + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def test_meta_schedule_cuda_sketch_matmul(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'b2 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9, v10 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l11, l12, l13, l14, l15 = sch.split(loop=l3, factors=[v6, v7, v8, v9, v10])", + "v16, v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64)", + "l21, l22, l23, l24, l25 = sch.split(loop=l4, factors=[v16, v17, v18, v19, v20])", + "v26, v27, v28 = sch.sample_perfect_tile(loop=l5, n=3, max_innermost_factor=64)", + "l29, l30, l31 = sch.split(loop=l5, factors=[v26, v27, v28])", + "sch.reorder(l11, l21, l12, l22, l13, l23, l29, l30, l14, l24, l31, l15, l25)", + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="blockIdx.x")', + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="vthread.x")', + "l34 = sch.fuse(l13, l23)", + 'sch.bind(loop=l34, thread_axis="threadIdx.x")', + 'b35 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b35, loop=l29, preserve_unit_loops=1)", + "l36, l37, l38, l39, l40, l41 = sch.get_loops(block=b35)", + "l42 = sch.fuse(l40, l41)", + "v43, v44 = sch.sample_perfect_tile(loop=l42, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch", ann_val=v44)', + 'b45 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b45, loop=l29, preserve_unit_loops=1)", + "l46, l47, l48, l49, l50, 
l51 = sch.get_loops(block=b45)", + "l52 = sch.fuse(l50, l51)", + "v53, v54 = sch.sample_perfect_tile(loop=l52, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b45, ann_key="meta_schedule.cooperative_fetch", ann_val=v54)', + "sch.reverse_compute_at(block=b2, loop=l34, preserve_unit_loops=1)", + "v55 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v55)', + ] + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + 'b3 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l4, l5, l6 = sch.get_loops(block=b0)", + "v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64)", + "l12, l13, l14, l15, l16 = sch.split(loop=l4, factors=[v7, v8, v9, v10, v11])", + "v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64)", + "l22, l23, l24, l25, l26 = sch.split(loop=l5, factors=[v17, v18, v19, v20, v21])", + "v27, v28, v29 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64)", + "l30, l31, l32 = sch.split(loop=l6, factors=[v27, v28, v29])", + "sch.reorder(l12, l22, l13, l23, l14, l24, l30, l31, l15, l25, l32, l16, l26)", + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="blockIdx.x")', + "l34 = sch.fuse(l13, l23)", + 'sch.bind(loop=l34, thread_axis="vthread.x")', + "l35 = sch.fuse(l14, l24)", + 'sch.bind(loop=l35, thread_axis="threadIdx.x")', + 'b36 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b36, loop=l30, preserve_unit_loops=1)", + "l37, l38, l39, l40, l41, l42 = sch.get_loops(block=b36)", + "l43 = sch.fuse(l41, l42)", + "v44, v45 = sch.sample_perfect_tile(loop=l43, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b36, ann_key="meta_schedule.cooperative_fetch", ann_val=v45)', + 'b46 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b46, loop=l30, preserve_unit_loops=1)", + "l47, l48, l49, l50, l51, l52 = sch.get_loops(block=b46)", + "l53 = sch.fuse(l51, l52)", + "v54, v55 = sch.sample_perfect_tile(loop=l53, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)', + "sch.reverse_compute_at(block=b3, loop=l35, preserve_unit_loops=1)", + "sch.reverse_compute_inline(block=b1)", + "v56 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v56)', + ] + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + 
check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_conv2d_nchw(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + 'b3 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local")', + "l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b1)", + "v11, v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64)", + "l16, l17, l18, l19, l20 = sch.split(loop=l4, factors=[v11, v12, v13, v14, v15])", + "v21, v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64)", + "l26, l27, l28, l29, l30 = sch.split(loop=l5, factors=[v21, v22, v23, v24, v25])", + "v31, v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64)", + "l36, l37, l38, l39, l40 = sch.split(loop=l6, factors=[v31, v32, v33, v34, v35])", + "v41, v42, v43, v44, v45 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64)", + "l46, l47, l48, l49, l50 = sch.split(loop=l7, factors=[v41, v42, v43, v44, v45])", + "v51, v52, v53 = sch.sample_perfect_tile(loop=l8, n=3, max_innermost_factor=64)", + "l54, l55, l56 = sch.split(loop=l8, factors=[v51, v52, v53])", + "v57, v58, v59 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64)", + "l60, l61, l62 = sch.split(loop=l9, factors=[v57, v58, v59])", + "v63, v64, v65 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64)", + "l66, l67, l68 = sch.split(loop=l10, factors=[v63, v64, v65])", + "sch.reorder(l16, l26, l36, l46, l17, l27, l37, l47, l18, l28, l38, l48, l54, l60, l66, l55, l61, l67, l19, l29, l39, l49, l56, l62, l68, l20, l30, l40, l50)", + "l69 = sch.fuse(l16, l26, l36, l46)", + 'sch.bind(loop=l69, thread_axis="blockIdx.x")', + "l70 = sch.fuse(l17, l27, l37, l47)", + 'sch.bind(loop=l70, thread_axis="vthread.x")', + "l71 = sch.fuse(l18, l28, l38, l48)", + 'sch.bind(loop=l71, thread_axis="threadIdx.x")', + 'b72 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b72, loop=l66, preserve_unit_loops=1)", + "l73, l74, l75, l76, l77, l78, l79, l80, l81, l82 = sch.get_loops(block=b72)", + "l83 = sch.fuse(l79, l80, l81, l82)", + "v84, v85 = sch.sample_perfect_tile(loop=l83, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v85)', + 'b86 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b86, loop=l66, preserve_unit_loops=1)", + "l87, l88, l89, l90, l91, l92, l93, l94, l95, l96 = sch.get_loops(block=b86)", + "l97 = sch.fuse(l93, l94, l95, l96)", + "v98, v99 = sch.sample_perfect_tile(loop=l97, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b86, ann_key="meta_schedule.cooperative_fetch", ann_val=v99)', + "sch.reverse_compute_at(block=b3, loop=l71, preserve_unit_loops=1)", + "sch.compute_inline(block=b0)", + "v100 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v100)', + ] + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.conv2d_nchw( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ), + target=_target(), + ) + + spaces = 
ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="bias_add", func_name="main")', + 'b3 = sch.get_block(name="bn_mul", func_name="main")', + 'b4 = sch.get_block(name="bn_add", func_name="main")', + 'b5 = sch.get_block(name="compute_1", func_name="main")', + 'b6 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b4)", + "sch.compute_inline(block=b3)", + "sch.compute_inline(block=b2)", + 'b7 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local")', + "l8, l9, l10, l11, l12, l13, l14 = sch.get_loops(block=b1)", + "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64)", + "l20, l21, l22, l23, l24 = sch.split(loop=l8, factors=[v15, v16, v17, v18, v19])", + "v25, v26, v27, v28, v29 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64)", + "l30, l31, l32, l33, l34 = sch.split(loop=l9, factors=[v25, v26, v27, v28, v29])", + "v35, v36, v37, v38, v39 = sch.sample_perfect_tile(loop=l10, n=5, max_innermost_factor=64)", + "l40, l41, l42, l43, l44 = sch.split(loop=l10, factors=[v35, v36, v37, v38, v39])", + "v45, v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l11, n=5, max_innermost_factor=64)", + "l50, l51, l52, l53, l54 = sch.split(loop=l11, factors=[v45, v46, v47, v48, v49])", + "v55, v56, v57 = sch.sample_perfect_tile(loop=l12, n=3, max_innermost_factor=64)", + "l58, l59, l60 = sch.split(loop=l12, factors=[v55, v56, v57])", + "v61, v62, v63 = sch.sample_perfect_tile(loop=l13, n=3, max_innermost_factor=64)", + "l64, l65, l66 = sch.split(loop=l13, factors=[v61, v62, v63])", + "v67, v68, v69 = sch.sample_perfect_tile(loop=l14, n=3, max_innermost_factor=64)", + "l70, l71, l72 = sch.split(loop=l14, factors=[v67, v68, v69])", + "sch.reorder(l20, l30, l40, l50, l21, l31, l41, l51, l22, l32, l42, l52, l58, l64, l70, l59, l65, l71, l23, l33, l43, l53, l60, l66, l72, l24, l34, l44, l54)", + "l73 = sch.fuse(l20, l30, l40, l50)", + 'sch.bind(loop=l73, thread_axis="blockIdx.x")', + "l74 = sch.fuse(l21, l31, l41, l51)", + 'sch.bind(loop=l74, thread_axis="vthread.x")', + "l75 = sch.fuse(l22, l32, l42, l52)", + 'sch.bind(loop=l75, thread_axis="threadIdx.x")', + 'b76 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b76, loop=l70, preserve_unit_loops=1)", + "l77, l78, l79, l80, l81, l82, l83, l84, l85, l86 = sch.get_loops(block=b76)", + "l87 = sch.fuse(l83, l84, l85, l86)", + "v88, v89 = sch.sample_perfect_tile(loop=l87, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b76, ann_key="meta_schedule.cooperative_fetch", ann_val=v89)', + 'b90 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b90, loop=l70, preserve_unit_loops=1)", + "l91, l92, l93, l94, l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b90)", + "l101 = sch.fuse(l97, l98, l99, l100)", + "v102, v103 = sch.sample_perfect_tile(loop=l101, n=2, max_innermost_factor=4)", + 'sch.annotate(block_or_loop=b90, ann_key="meta_schedule.cooperative_fetch", ann_val=v103)', + "sch.reverse_compute_at(block=b7, loop=l75, preserve_unit_loops=1)", + "sch.reverse_compute_inline(block=b5)", + "sch.compute_inline(block=b0)", + 
"v104 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b6, ann_key="meta_schedule.unroll_explicit", ann_val=v104)', + ] + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.conv2d_nchw_bias_bn_relu( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_meta_schedule_cuda_sketch_matmul() + test_meta_schedule_cuda_sketch_matmul_relu() + test_meta_schedule_cuda_sketch_conv2d_nchw() + test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu() diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py index 49a3f63091..3eb050db3b 100644 --- a/tests/python/unittest/test_meta_schedule_space_generator.py +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -23,6 +23,9 @@ import pytest import tvm +from tvm._ffi.base import TVMError +from tvm.ir.module import IRModule +from tvm.meta_schedule.space_generator.space_generator import PySpaceGenerator from tvm.script import tir as T from tvm.tir.schedule import Schedule from tvm.meta_schedule.space_generator import ScheduleFn, PySpaceGenerator, SpaceGeneratorUnion diff --git a/tests/python/unittest/test_meta_schedule_task_extraction.py b/tests/python/unittest/test_meta_schedule_task_extraction.py new file mode 100644 index 0000000000..8d1eca5143 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_task_extraction.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring +import sys +from typing import Tuple + +import pytest + +import tvm +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing import MODEL_TYPE, MODEL_TYPES, get_torch_model + + +@pytest.mark.skip("Skip because it runs too slowly as a unittest") +@pytest.mark.parametrize( + "model_name", + [ + # Image classification + "resnet50", + "alexnet", + "vgg16", + "squeezenet1_0", + "densenet121", + "densenet161", + "densenet169", + "densenet201", + "inception_v3", + "googlenet", + "shufflenet_v2_x1_0", + "mobilenet_v2", + "mobilenet_v3_large", + "mobilenet_v3_small", + "resnext50_32x4d", + "wide_resnet50_2", + "mnasnet1_0", + # Segmentation + "fcn_resnet50", + "fcn_resnet101", + "deeplabv3_resnet50", + "deeplabv3_resnet101", + "deeplabv3_mobilenet_v3_large", + "lraspp_mobilenet_v3_large", + # Object detection + "fasterrcnn_resnet50_fpn", + "fasterrcnn_mobilenet_v3_large_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn", + "maskrcnn_resnet50_fpn", + # video classification + "r3d_18", + "mc3_18", + "r2plus1d_18", + ], +) +@pytest.mark.parametrize("batch_size", [1, 8, 16]) +@pytest.mark.parametrize("target", ["llvm", "cuda"]) +def test_meta_schedule_extract_from_torch_model(model_name: str, batch_size: int, target: str): + if model_name == "inception_v3" and batch_size == 1: + pytest.skip("inception_v3 does not handle batch_size of 1") + + input_shape: Tuple[int, ...] + if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: + input_shape = (batch_size, 3, 299, 299) + elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + input_shape = (batch_size, 3, 299, 299) + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + input_shape = (1, 3, 300, 300) + elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION: + input_shape = (batch_size, 3, 3, 299, 299) + else: + raise ValueError("Unsupported model: " + model_name) + + output_shape: Tuple[int, int] = (batch_size, 1000) + mod, params = get_torch_model( + model_name=model_name, + input_shape=input_shape, + output_shape=output_shape, + dtype="float32", + ) + target = tvm.target.Target(target) + ms.integration.extract_task(mod, params=params, target=target) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index edff3552d7..d3c4dbca82 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -16,24 +16,22 @@ # under the License. 
""" Test Meta Schedule Task Scheduler """ -from typing import List - -import sys import random +import sys +from typing import List import pytest - import tvm -from tvm.script import tir as T from tvm.ir import IRModule -from tvm.tir import Schedule -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.search_strategy import ReplayTrace -from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult -from tvm.meta_schedule.runner import PyRunner, RunnerInput, RunnerFuture, RunnerResult +from tvm.meta_schedule import TuneContext, measure_callback +from tvm.meta_schedule.builder import BuilderInput, BuilderResult, PyBuilder from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload -from tvm.meta_schedule.task_scheduler import RoundRobin, PyTaskScheduler +from tvm.meta_schedule.runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult +from tvm.meta_schedule.search_strategy import ReplayTrace +from tvm.meta_schedule.space_generator import ScheduleFn +from tvm.meta_schedule.task_scheduler import PyTaskScheduler, RoundRobin +from tvm.script import tir as T +from tvm.tir import Schedule # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @@ -140,6 +138,12 @@ def __init__(self): self.records = [] self.workload_reg = [] + def has_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True + return False + def commit_tuning_record(self, record: TuningRecord) -> None: self.records.append(record) @@ -180,7 +184,13 @@ def test_meta_schedule_task_scheduler_single(): rand_state=42, ) database = DummyDatabase() - round_robin = RoundRobin([task], DummyBuilder(), DummyRunner(), database) + round_robin = RoundRobin( + [task], + DummyBuilder(), + DummyRunner(), + database, + measure_callbacks=[measure_callback.AddToDatabase()], + ) round_robin.tune() assert len(database) == num_trials_total @@ -215,15 +225,29 @@ def test_meta_schedule_task_scheduler_multiple(): ), ] database = DummyDatabase() - round_robin = RoundRobin(tasks, DummyBuilder(), DummyRunner(), database) + round_robin = RoundRobin( + tasks, + DummyBuilder(), + DummyRunner(), + database, + measure_callbacks=[measure_callback.AddToDatabase()], + ) round_robin.tune() assert len(database) == num_trials_total * len(tasks) print(database.workload_reg) for task in tasks: - assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total + assert ( + len( + database.get_top_k( + database.commit_workload(task.mod), + 100000, + ) + ) + == num_trials_total + ) -def test_meta_schedule_task_scheduler_NIE(): +def test_meta_schedule_task_scheduler_not_implemented_error(): # pylint: disable=invalid-name class MyTaskScheduler(PyTaskScheduler): pass @@ -231,7 +255,7 @@ class MyTaskScheduler(PyTaskScheduler): MyTaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase()) -def test_meta_schedule_task_scheduler_override_next_task_id_only(): +def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: disable=invalid-name class MyTaskScheduler(PyTaskScheduler): done = set() @@ -288,11 +312,27 @@ def next_task_id(self) -> int: ), ] database = DummyDatabase() - scheduler = MyTaskScheduler(tasks, DummyBuilder(), DummyRunner(), database) + scheduler = MyTaskScheduler( + tasks, + DummyBuilder(), + DummyRunner(), + database, + measure_callbacks=[ + measure_callback.AddToDatabase(), + ], + ) 
scheduler.tune() assert len(database) == num_trials_total * len(tasks) for task in tasks: - assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total + assert ( + len( + database.get_top_k( + database.commit_workload(task.mod), + 100000, + ) + ) + == num_trials_total + ) if __name__ == "__main__": diff --git a/tests/python/unittest/test_meta_schedule_tune_te.py b/tests/python/unittest/test_meta_schedule_tune_te.py new file mode 100644 index 0000000000..19caa072ed --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_tune_te.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import logging + +import pytest +from tvm.meta_schedule import tune_te, ReplayTraceConfig +from tvm.meta_schedule.testing import te_workload +from tvm.target.target import Target +from tvm.tir import Schedule + + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) + + +@pytest.mark.skip("Integration test") +def test_tune_matmul(): + sch: Schedule = tune_te( + tensors=te_workload.batch_matmul_nkkm(B=1, N=128, M=128, K=128), + target=Target("llvm --num-cores=16"), + config=ReplayTraceConfig( + num_trials_per_iter=32, + num_trials_total=32, + ), + ) + if sch is None: + print("No valid schedule found!") + else: + print(sch.mod.script()) + print(sch.trace) + + +if __name__ == """__main__""": + test_tune_matmul() diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py new file mode 100644 index 0000000000..43e81fe401 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring +import logging + +import pytest +from tvm.meta_schedule import ReplayTraceConfig, tune_tir +from tvm.script import tir as T +from tvm.target.target import Target +from tvm.tir import Schedule + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) + + +# pylint: disable=no-member,invalid-name,unused-variable + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=no-member,invalid-name,unused-variable + + +@pytest.mark.skip("Integration test") +def test_tune_matmul_cpu(): + sch: Schedule = tune_tir( + mod=matmul, + target=Target("llvm --num-cores=16"), + config=ReplayTraceConfig( + num_trials_per_iter=32, + num_trials_total=32, + ), + ) + if sch is None: + print("No valid schedule found!") + else: + print(sch.mod.script()) + print(sch.trace) + + +@pytest.mark.skip("Integration test") +def test_tune_matmul_cuda(): + sch: Schedule = tune_tir( + mod=matmul, + target=Target("nvidia/geforce-rtx-3070"), + config=ReplayTraceConfig( + num_trials_per_iter=32, + num_trials_total=32, + ), + ) + if sch is None: + print("No valid schedule found!") + else: + print(sch.mod.script()) + print(sch.trace) + + +if __name__ == """__main__""": + test_tune_matmul_cpu() + test_tune_matmul_cuda() diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py new file mode 100644 index 0000000000..0bad0154a6 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring +from typing import List + +from tvm.tir import ( + Evaluate, + For, + ForKind, + IndexMap, + Var, + decl_buffer, + floordiv, + floormod, +) +from tvm.tir.analysis import expr_deep_equal +from tvm.tir.schedule.analysis import suggest_index_map + + +def _make_vars(*args: str) -> List[Var]: + return [Var(arg, dtype="int32") for arg in args] + + +def _make_loops(loop_vars: List[Var], extents: List[int]) -> List[For]: + assert len(loop_vars) == len(extents) + return [ + For( + loop_var=loop_var, + min_val=0, + extent=extent, + kind=ForKind.SERIAL, + body=Evaluate(0), + ) + for loop_var, extent in zip(loop_vars, extents) + ] + + +def _assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: + iters_1 = map1.apply(map2.src_iters) + iters_2 = map2.tgt_iters + assert len(iters_1) == len(iters_2) + for iter1, iter2 in zip(iters_1, iters_2): + assert expr_deep_equal(iter1, iter2) + + +def test_suggest_index_map_simple(): + i, j = _make_vars("i", "j") + index_map = suggest_index_map( + buffer=decl_buffer(shape=[8, 256]), + indices=[ + floordiv(i, 16) * 4 + floordiv(j, 16), + floormod(i, 16) * 16 + floormod(j, 16), + ], + loops=_make_loops( + loop_vars=[i, j], + extents=[32, 64], + ), + predicate=True, + ) + expected_index_map = IndexMap.from_func( + lambda x, y: [ + floordiv(x, 4), + floordiv(y, 16), + floormod(x, 4), + floormod(y, 16), + ], + ) + _assert_equal_index_map(index_map, expected_index_map) + + +def test_suggest_index_map_bijective(): + i, j = _make_vars("i", "j") + index_map = suggest_index_map( + buffer=decl_buffer(shape=[8]), + indices=[floormod(j, 4) * 2 + i], + loops=_make_loops( + loop_vars=[i, j], + extents=[2, 32], + ), + predicate=True, + ) + expected_index_map = IndexMap.from_func( + lambda x: [ + floormod(x, 2), + floordiv(x, 2), + ], + ) + _assert_equal_index_map(index_map, expected_index_map) + + +if __name__ == "__main__": + test_suggest_index_map_simple() + test_suggest_index_map_bijective() diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py new file mode 100644 index 0000000000..335f5027db --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -0,0 +1,220 @@ +import sys +import pytest +import tvm +from tvm import tir, te +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + + +@T.prim_func +def elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + B = T.alloc_buffer((128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + +@T.prim_func +def blockize(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, (128, 128), "float32") + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(8, 8): + with T.block("blockized_B"): + vi, vj = T.axis.remap("SS", [i, j]) + for ii, jj in T.grid(16, 16): + with T.block("B"): + vii = T.axis.S(128, vi * 16 + ii) + vjj = T.axis.S(128, vj * 16 + jj) + B[vii, vjj] = A[vii, vjj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + +@T.prim_func +def blockize_schedule_1(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, 
align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + T.reads([]) + T.writes([]) + B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0_outer in range(0, 8): + for i1_outer in range(0, 8): + with T.block("blockized_B"): + vio = T.axis.S(8, i0_outer) + vjo = T.axis.S(8, i1_outer) + T.reads([A[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) + T.writes([B[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) + for i0_inner in range(0, 16): + for i1_inner in range(0, 16): + with T.block("B"): + vi = T.axis.S(128, ((vio * 16) + i0_inner)) + vj = T.axis.S(128, ((vjo * 16) + i1_inner)) + T.reads([A[vi : (vi + 1), vj : (vj + 1)]]) + T.writes([B[vi : (vi + 1), vj : (vj + 1)]]) + B[vi, vj] = A[vi, vj] * T.float32(2) + with T.block("blockized_C"): + vio = T.axis.S(8, i0_outer) + vjo = T.axis.S(8, i1_outer) + T.reads([B[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) + T.writes([C[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) + for ax0 in range(0, 16): + for ax1 in range(0, 16): + with T.block("C"): + vi = T.axis.S(128, ((vio * 16) + ax0)) + vj = T.axis.S(128, ((vjo * 16) + ax1)) + T.reads([B[vi : (vi + 1), vj : (vj + 1)]]) + T.writes([C[vi : (vi + 1), vj : (vj + 1)]]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + +@T.prim_func +def blockize_schedule_2(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + T.reads([]) + T.writes([]) + B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0_outer in range(0, 4): + for i1_outer in range(0, 4): + for ax0 in range(0, 2): + for ax1 in range(0, 2): + with T.block("blockized_B"): + vio = T.axis.S(8, ((i0_outer * 2) + ax0)) + vjo = T.axis.S(8, ((i1_outer * 2) + ax1)) + T.reads( + [A[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]] + ) + T.writes( + [B[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]] + ) + for i0_inner in range(0, 16): + for i1_inner in range(0, 16): + with T.block("B"): + vi = T.axis.S(128, ((vio * 16) + i0_inner)) + vj = T.axis.S(128, ((vjo * 16) + i1_inner)) + T.reads([A[vi : (vi + 1), vj : (vj + 1)]]) + T.writes([B[vi : (vi + 1), vj : (vj + 1)]]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i0_inner_1 in range(0, 32): + for i1_inner_1 in range(0, 32): + with T.block("C"): + vi = T.axis.S(128, ((i0_outer * 32) + i0_inner_1)) + vj = T.axis.S(128, ((i1_outer * 32) + i1_inner_1)) + T.reads([B[vi : (vi + 1), vj : (vj + 1)]]) + T.writes([C[vi : (vi + 1), vj : (vj + 1)]]) + C[vi, vj] = B[vi, vj] + T.float32(1) + +@T.prim_func +def rowsum(a: T.handle, b:T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128,]) + for k, i in T.grid(128, 128): + with T.block("B"): + vk, vi = T.axis.remap('RS', [k, i]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def rowsum_blockized(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128]) + with T.block("blockized_B"): + vko = T.axis.R(1, 0) + vio = T.axis.S(1, 0) + with T.init(): + for i1 in T.serial(0, 128): + with T.block("B_init"): + vi_init = T.axis.S(128, i1) + B[vi_init] = T.float32(0) + for i0, i1_1 in T.grid(128, 128): + with T.block("B"): + vk, vi = T.axis.remap("RS", [i0, i1_1]) + B[vi] = 
(B[vi] + A[vi, vk]) + +def test_blockize(): + func = elementwise + # schedule + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + _ = s.get_block("C") + x, y = s.get_loops(B) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + s.blockize(xi) + tvm.ir.assert_structural_equal(blockize, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_blockize_schedule(): + func = elementwise + # test 1 + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + C = s.get_block("C") + x, y = s.get_loops(B) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + s.blockize(xi) + s.reverse_compute_at(C, yo) + s.blockize(s.get_loops(C)[-2]) + tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1) + verify_trace_roundtrip(sch=s, mod=func) + # test 2 + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + C = s.get_block("C") + x, y = s.get_loops(C) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + s.blockize(xi) + s.compute_at(B, yo) + s.blockize(s.get_loops(B)[-2]) + tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1) + verify_trace_roundtrip(sch=s, mod=func) + # test 3 + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + C = s.get_block("C") + x, y = s.get_loops(B) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + b_outer = s.blockize(xi) + xC, yC = s.get_loops(C) + xCo, xCi = s.split(xC, factors=[None, 32]) + yCo, yCi = s.split(yC, factors=[None, 32]) + s.reorder(xCo, yCo, xCi, yCi) + s.compute_at(b_outer, yCo) + tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_2) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_blockize_init_loops(): + s = tir.Schedule(rowsum, debug_mask='all') + k, _ = s.get_loops(s.get_block("B")) + s.blockize(k) + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_blockized) + verify_trace_roundtrip(sch=s, mod=rowsum) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index a078c0ed4c..94683e6e1f 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -329,6 +329,27 @@ def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None: B[vi] = A_cache[vi] * 2.0 + 1.0 +@T.prim_func +def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + compute = T.match_buffer(var_compute, [512, 512], dtype="float32") + C = T.alloc_buffer([512, 512], dtype="float32") + for i0, i1, i2 in T.grid(512, 512, 512): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads([C[i, j], A[i, k], B[k, j]]) + T.writes([C[i, j]]) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + for i0, i1 in T.grid(512, 512): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads([C[i0_1, i1_1]]) + T.writes([compute[i0_1, i1_1]]) + compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) + # pylint: enable=no-member,invalid-name,unused-variable @@ -458,6 +479,13 @@ def test_buffer_matched(): 
sch.compute_inline(block_b) +def test_output_block(): + sch = tir.Schedule(matmul_relu, debug_mask="all") + block= sch.get_block("compute") + with pytest.raises(tvm.tir.ScheduleError): + sch.compute_inline(block) + + def test_compute_inline_predicate(): sch = tir.Schedule(elementwise_predicate, debug_mask="all") block_b = sch.get_block("B") diff --git a/tests/python/unittest/test_tir_schedule_read_write_at.py b/tests/python/unittest/test_tir_schedule_read_write_at.py new file mode 100644 index 0000000000..79a7aad10f --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_read_write_at.py @@ -0,0 +1,221 @@ +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest + +import tvm +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable + +@T.prim_func +def cuda_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=undefined-loop-variable + A = T.match_buffer(a, [2048, 2048], "float32") + B = T.match_buffer(b, [2048, 2048], "float32") + C = T.match_buffer(c, [2048, 2048], "float32") + for by in T.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in T.thread_binding(0, 2, thread = "vthread.y"): + for vx in T.thread_binding(0, 2, thread = "vthread.x"): + for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): + for k0 in T.serial(0, 256): + for k1 in T.unroll(0, 8): + for _, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A[vi, vk], B[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@T.prim_func +def cuda_matmul_read_at_a(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = 
T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A_shared[vi, vk], B[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B[vk, vj] + + +@T.prim_func +def cuda_matmul_read_at_ab(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + with T.block("B_shared"): + v0 = T.axis.S(256, k0) + v1 = T.axis.S(32, bx) + T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(8, 64): + B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A_shared[vi, vk], B_shared[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj] + +@T.prim_func +def cuda_matmul_write_at_c(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + C_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 
8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + with T.block("B_shared"): + v0 = T.axis.S(256, k0) + v1 = T.axis.S(32, bx) + T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(8, 64): + B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C_shared[vi, vj], A_shared[vi, vk], B_shared[vk, vj]]) + T.writes([C_shared[vi, vj]]) + with T.init(): + C_shared[vi, vj] = T.float32(0) + C_shared[vi, vj] = C_shared[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj] + with T.block("C_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(32, bx) + T.reads([C_shared[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) + T.writes([C[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 64): + C[v0 * 64 + ax0, v1 * 64 + ax1] = C_shared[v0 * 64 + ax0, v1 * 64 + ax1] + + +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable +# fmt: on + + +def test_read_at_global_to_shared_a(): + sch = tir.Schedule(cuda_matmul, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.read_at(k0, block, 1, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_a) + verify_trace_roundtrip(sch, cuda_matmul) + + +def test_read_at_global_to_shared_ab(): + sch = tir.Schedule(cuda_matmul_read_at_a, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.read_at(k0, block, 2, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_ab) + verify_trace_roundtrip(sch, cuda_matmul_read_at_a) + + +def test_read_at_local_to_shared_c(): + sch = tir.Schedule(cuda_matmul_read_at_ab, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, tx, _k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.write_at(tx, block, 0, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_write_at_c) + verify_trace_roundtrip(sch, cuda_matmul_read_at_ab) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py new file mode 100644 index 0000000000..19b19c2e09 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -0,0 +1,394 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import numpy as np +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def desc_func(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + vkk = T.axis.R(16, vk + k) + C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + + +@T.prim_func +def intrin_func(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + # These access region must be explicitly stated. 
Otherwise the auto-completed region starts from (0, 0) instead of (vi, vj) + T.reads([A[vi: vi+16, vk: vk+16], B[vj: vj+16, vk: vk+16], C[vi:vi+16, vj:vj+16]]) + T.writes([C[vi: vi+16, vj: vj+16]]) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] = C[vii, vjj] + B[vjj, vkk] * A[vii, vkk] + + + +@T.prim_func +def lower_intrin_func(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + T.reads([C[vi:vi + 16, vj:vj + 16], A[vi:vi + 16, vk:vk + 16], B[vj:vj + 16, vk:vk + 16]]) + T.writes(C[vi:vi + 16, vj:vj + 16]) + T.evaluate(T.tvm_mma_sync(C.data, C.elem_offset // 256, + A.data, A.elem_offset // 256, + B.data, B.elem_offset // 256, + C.data, C.elem_offset // 256, + dtype="handle")) + + +@T.prim_func +def tensorized_func(a: T.handle, b: T.handle, c: T.handle) -> None: + # function attr dict + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + for i_outer, j_outer in T.grid(8, 8): + for i_inner_init, j_inner_init in T.grid(16, 16): + with T.block("init"): + vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init)) + vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init)) + C[vi_init, vj_init] = T.float32(0) + for k_outer in T.grid(8): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer]) + T.reads([C[vi*16:vi*16 + 16, vj*16:vj*16 + 16], A[vi*16:vi*16 + 16, vk*16:vk*16 + 16], B[vj*16:vj*16 + 16, vk*16:vk*16 + 16]]) + T.writes(C[vi*16:vi*16 + 16, vj*16:vj*16 + 16]) + A_elem_offset = T.var('int32') + B_elem_offset = T.var('int32') + C_elem_offset = T.var('int32') + A_sub = T.match_buffer(A[vi*16:vi*16+16, vk*16:vk*16+16], [16, 16], elem_offset=A_elem_offset) + B_sub = T.match_buffer(B[vj*16:vj*16+16, vk*16:vk*16+16], [16, 16], elem_offset=B_elem_offset) + C_sub = T.match_buffer(C[vi*16:vi*16+16, vj*16:vj*16+16], [16, 16], elem_offset=C_elem_offset) + T.evaluate( + T.tvm_mma_sync(C_sub.data, T.floordiv(C_sub.elem_offset, 256), + A_sub.data, T.floordiv(A_sub.elem_offset, 256), + B_sub.data, T.floordiv(B_sub.elem_offset, 256), + C_sub.data, T.floordiv(C_sub.elem_offset, 256), + dtype="handle")) + + +@T.prim_func +def batch_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 128, 128]) + B = T.match_buffer(b, [16, 128, 128]) + C = T.match_buffer(c, [16, 128, 128]) + + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + C[vn, vi, vj] = T.float32(0) + + for n, i, j, k in T.grid(16, 128, 128, 128): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + + +@T.prim_func +def tensorized_batch_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + # function attr dict + C = T.match_buffer(c, [16, 128, 128]) + B = T.match_buffer(b, [16, 128, 128]) + A = T.match_buffer(a, [16, 128, 128]) + + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + C[vn, vi, vj] = T.float32(0) + # body + for n in range(0, 
16): + for i, j, k in T.grid(8, 8, 8): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + T.reads([C[vn:vn + 1, vi*16:vi*16 + 16, vj*16:vj*16 + 16], A[vn:vn + 1, vi*16:vi*16 + 16, vk*16:vk*16 + 16], + B[vn:vn + 1, vj*16:vj*16 + 16, vk*16:vk*16 + 16]]) + T.writes(C[vn:vn + 1, vi*16:vi*16 + 16, vj*16:vj*16 + 16]) + A_elem_offset = T.var('int32') + B_elem_offset = T.var('int32') + C_elem_offset = T.var('int32') + A_sub = T.match_buffer(A[vn:vn + 1, vi*16:vi*16+16,vk*16:vk*16+16], (16, 16), elem_offset=A_elem_offset) + B_sub = T.match_buffer(B[vn:vn + 1, vj*16:vj*16+16,vk*16:vk*16+16], (16, 16), elem_offset=B_elem_offset) + C_sub = T.match_buffer(C[vn:vn + 1, vi*16:vi*16+16,vj*16:vj*16+16], (16, 16), elem_offset=C_elem_offset) + T.evaluate( + T.tvm_mma_sync(C_sub.data, T.floordiv(C_sub.elem_offset, 256), + A_sub.data, T.floordiv(A_sub.elem_offset, 256), + B_sub.data, T.floordiv(B_sub.elem_offset, 256), + C_sub.data, T.floordiv(C_sub.elem_offset, 256), + dtype="handle")) + + +@T.prim_func +def batch_matmul_dot_product(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [1, 4, 4], "float32") + B = T.match_buffer(b, [1, 4, 4], "float32") + C = T.match_buffer(c, [1, 4, 4], "float32") + + t = T.var("int32") + T.attr(T.iter_var(t, None, "DataPar", ""), "pragma_import_llvm", + "; ModuleID = '/tmp/tmpur44d1nu/input0.cc'\n\ +source_filename = \"/tmp/tmpur44d1nu/input0.cc\"\n\ +target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"\n\ +target triple = \"x86_64-pc-linux-gnu\"\n\ +\n\ +; Function Attrs: noinline nounwind optnone uwtable\n\ +define dso_local i32 @vec4add(float* %0, i32 %1, float* %2, i32 %3, float* %4, i32 %5) #0 {\n\ + %7 = alloca float*, align 8\n\ + %8 = alloca i32, align 4\n\ + %9 = alloca float*, align 8\n\ + %10 = alloca i32, align 4\n\ + %11 = alloca float*, align 8\n\ + %12 = alloca i32, align 4\n\ + %13 = alloca i32, align 4\n\ + store float* %0, float** %7, align 8\n\ + store i32 %1, i32* %8, align 4\n\ + store float* %2, float** %9, align 8\n\ + store i32 %3, i32* %10, align 4\n\ + store float* %4, float** %11, align 8\n\ + store i32 %5, i32* %12, align 4\n\ + store i32 0, i32* %13, align 4\n\ + br label %14\n\ +\n\ +14: ; preds = %39, %6\n\ + %15 = load i32, i32* %13, align 4\n\ + %16 = icmp slt i32 %15, 4\n\ + br i1 %16, label %17, label %42\n\ +\n\ +17: ; preds = %14\n\ + %18 = load float*, float** %9, align 8\n\ + %19 = load i32, i32* %13, align 4\n\ + %20 = load i32, i32* %10, align 4\n\ + %21 = add nsw i32 %19, %20\n\ + %22 = sext i32 %21 to i64\n\ + %23 = getelementptr inbounds float, float* %18, i64 %22\n\ + %24 = load float, float* %23, align 4\n\ + %25 = load float*, float** %11, align 8\n\ + %26 = load i32, i32* %13, align 4\n\ + %27 = load i32, i32* %12, align 4\n\ + %28 = add nsw i32 %26, %27\n\ + %29 = sext i32 %28 to i64\n\ + %30 = getelementptr inbounds float, float* %25, i64 %29\n\ + %31 = load float, float* %30, align 4\n\ + %32 = fmul float %24, %31\n\ + %33 = load float*, float** %7, align 8\n\ + %34 = load i32, i32* %8, align 4\n\ + %35 = sext i32 %34 to i64\n\ + %36 = getelementptr inbounds float, float* %33, i64 %35\n\ + %37 = load float, float* %36, align 4\n\ + %38 = fadd float %37, %32\n\ + store float %38, float* %36, align 4\n\ + br label %39\n\ +\n\ +39: ; preds = %17\n\ + %40 = load i32, i32* %13, align 4\n\ + %41 = add nsw i32 %40, 1\n\ + store i32 %41, i32* %13, align 4\n\ + br label %14\n\ +\n\ +42: ; preds = %14\n\ + ret i32 0\n\ +}\n\ +\n\ +attributes #0 
= { noinline nounwind optnone uwtable \"correctly-rounded-divide-sqrt-fp-math\"=\"false\" \"disable-tail-calls\"=\"false\" \"frame-pointer\"=\"all\" \"less-precise-fpmad\"=\"false\" \"min-legal-vector-width\"=\"0\" \"no-infs-fp-math\"=\"false\" \"no-jump-tables\"=\"false\" \"no-nans-fp-math\"=\"false\" \"no-signed-zeros-fp-math\"=\"false\" \"no-trapping-math\"=\"true\" \"stack-protector-buffer-size\"=\"8\" \"target-cpu\"=\"x86-64\" \"target-features\"=\"+cx8,+fxsr,+mmx,+sse,+sse2,+x87\" \"unsafe-fp-math\"=\"false\" \"use-soft-float\"=\"false\" }\n\ +\n\ +!llvm.module.flags = !{!0}\n\ +!llvm.ident = !{!1}\n\ +\n\ +!0 = !{i32 1, !\"wchar_size\", i32 4}\n\ +!1 = !{!\"Ubuntu clang version 11.0.0-++20200928083541+eb83b551d3e-1~exp1~20200928184208.110\"}\n\ +\n\ + ") + + for n, i, j in T.grid(1, 4, 4): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + C[vn, vi, vj] = T.float32(0) + + for n, i, j, k in T.grid(1, 4, 4, 4): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + + +@T.prim_func +def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,)) + B = T.match_buffer(b, (4,)) + C = T.match_buffer(c, (1,)) + + with T.block("root"): + v0 = T.axis.R(4, 0) + for i in range(0, 4): + with T.block("update"): + vi = T.axis.R(4, v0 + i) + C[0] = C[0] + A[vi] * B[vi] + + +@T.prim_func +def dot_product_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,), offset_factor=1) + B = T.match_buffer(b, (4,), offset_factor=1) + C = T.match_buffer(c, (1,), offset_factor=1) + + with T.block("root"): + v0 = T.axis.R(4, 0) + T.reads([C[0 : 1], A[v0 : v0 + 4], B[v0 : v0 + 4]]) + T.writes([C[0 : 1]]) + T.evaluate(T.call_extern("vec4add", C.data, C.elem_offset, A.data, A.elem_offset, B.data, B.elem_offset, dtype="int32")) + + +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg +# fmt: on + +# pylint: disable=invalid-name + + +tir.TensorIntrin.register('test_identity_intrin', desc_func, intrin_func) +tir.TensorIntrin.register('test_mma_intrin', desc_func, lower_intrin_func) +tir.TensorIntrin.register('test_dot_product_intrin', dot_product_desc, dot_product_impl) + + +def test_tensorize_gemm(): + func = matmul + # schedule + s = tir.Schedule(func, debug_mask="all") + update = s.get_block("update") + i, j, k = s.get_loops(update) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + ko, ki = s.split(k, factors=[None, 16]) + s.reorder(io, jo, ko, ii, ji, ki) + s.decompose_reduction(update, ko) + s.tensorize(ii, 'test_identity_intrin') + + func = tvm.build(s.mod["main"]) + a_np = np.random.uniform(size=(128, 128)).astype("float32") + b_np = np.random.uniform(size=(128, 128)).astype("float32") + a = tvm.nd.array(a_np) + b = tvm.nd.array(b_np) + c = tvm.nd.array(np.zeros((128, 128)).astype("float32")) + func(a, b, c) + tvm.testing.assert_allclose(c.numpy(), np.dot(a_np, b_np.transpose()), rtol=1e-6) + + +def test_tensorize_buffer_bind(): + func = matmul + # schedule + s = tir.Schedule(func, debug_mask="all") + update = s.get_block("update") + i, j, k = s.get_loops(update) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + ko, ki = s.split(k, factors=[None, 16]) + s.reorder(io, jo, ko, ii, ji, ki) + s.decompose_reduction(update, ko) + s.tensorize(ii, 'test_mma_intrin') + tvm.ir.assert_structural_equal(tensorized_func, 
s.mod["main"])
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
+def test_high_dim_tensorize():
+    func = batch_matmul
+    s = tir.Schedule(func, debug_mask="all")
+    update = s.get_block("update")
+    _, i, j, k = s.get_loops(update)
+    io, ii = s.split(i, factors=[None, 16])
+    jo, ji = s.split(j, factors=[None, 16])
+    ko, ki = s.split(k, factors=[None, 16])
+    s.reorder(io, jo, ko, ii, ji, ki)
+    s.tensorize(ii, 'test_mma_intrin')
+    tvm.ir.assert_structural_equal(tensorized_batch_matmul, s.mod["main"])
+    verify_trace_roundtrip(sch=s, mod=batch_matmul)
+
+
+@pytest.mark.skip("failed")
+def test_tensorize_dot_product():
+    func = batch_matmul_dot_product
+    s = tir.Schedule(func, debug_mask="all")
+    C = s.get_block("update")
+    _, _, _, k = s.get_loops(C)
+    _, ki = s.split(k, factors=[None, 4])
+    s.tensorize(ki, 'test_dot_product_intrin')
+    target = "llvm"
+    ctx = tvm.device(target, 0)
+    a_np = np.random.uniform(size=(1, 4, 4)).astype("float32")
+    b_np = np.random.uniform(size=(1, 4, 4)).astype("float32")
+    a = tvm.nd.array(a_np)
+    b = tvm.nd.array(b_np)
+    c = tvm.nd.array(np.zeros((1, 4, 4), dtype="float32"), ctx)
+    func = tvm.build(s.mod["main"], target=target)
+    func(a, b, c)
+    tvm.testing.assert_allclose(
+        c.numpy(),
+        np.matmul(a.numpy(), b.numpy().transpose(0, 2, 1)),
+        rtol=1e-5,
+    )
+    # Round-trip the trace against the original prim func, not the built module.
+    verify_trace_roundtrip(sch=s, mod=batch_matmul_dot_product)
+
+
+if __name__ == "__main__":
+    test_tensorize_gemm()
+    test_tensorize_buffer_bind()
+    test_high_dim_tensorize()
+    # test_tensorize_dot_product()
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py
index d75bc1461c..e01d469d8e 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -61,6 +61,46 @@ def matmul_relu(a: T.handle, b: T.handle, d: T.handle) -> None:
             D[vi, vj] = T.max(C[vi, vj], 0.0)
+
+
+@T.prim_func
+def matmul_relu_ann1(a: T.handle, b: T.handle, d: T.handle) -> None:
+    A = T.match_buffer(a, (1024, 1024))
+    B = T.match_buffer(b, (1024, 1024))
+    C = T.alloc_buffer((1024, 1024))
+    D = T.match_buffer(d, (1024, 1024))
+    for i in T.serial(0, 1024, annotations={"test1": "aaa"}):
+        for j in T.serial(0, 1024, annotations={"test2": 612}):
+            for k in T.serial(0, 1024):
+                with T.block("matmul"):
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        C[vi, vj] = 0.0
+                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+    for i, j in T.grid(1024, 1024):
+        with T.block("relu"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = T.max(C[vi, vj], 0.0)
+
+
+@T.prim_func
+def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None:
+    A = T.match_buffer(a, (1024, 1024))
+    B = T.match_buffer(b, (1024, 1024))
+    C = T.alloc_buffer((1024, 1024))
+    D = T.match_buffer(d, (1024, 1024))
+    for i, j, k in T.grid(1024, 1024, 1024):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                C[vi, vj] = 0.0
+            T.block_attr({"test1": "aaa"})
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+    for i, j in T.grid(1024, 1024):
+        with T.block("relu"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.block_attr({"test2": 0.22})
+            D[vi, vj] = T.max(C[vi, vj], 0.0)
+
+
 # pylint: enable=no-member,invalid-name,unused-variable
@@ -199,5 +239,31 @@ def test_get_consumers():
     verify_trace_roundtrip(sch, mod=matmul_relu)
+
+
+def test_annotate_unannotate_loop():
+    sch = tir.Schedule(mod=matmul_relu, debug_mask="all")
+    matmul = sch.get_block("matmul")
+    relu = sch.get_block("relu")
+    sch.annotate(sch.get_loops(matmul)[0], "test1", "aaa")
+
sch.annotate(sch.get_loops(matmul)[1], "test2", 612) + tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann1) + verify_trace_roundtrip(sch=sch, mod=matmul_relu) + sch.unannotate(sch.get_loops(matmul)[0], "test1") + sch.unannotate(sch.get_loops(matmul)[1], "test2") + verify_trace_roundtrip(sch=sch, mod=matmul_relu) + + +def test_annotate_unannotate_block(): + sch = tir.Schedule(mod=matmul_relu, debug_mask="all") + matmul = sch.get_block("matmul") + relu = sch.get_block("relu") + sch.annotate(matmul, "test1", "aaa") + sch.annotate(relu, "test2", 0.22) + tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann2) + verify_trace_roundtrip(sch=sch, mod=matmul_relu) + sch.unannotate(matmul, "test1") + sch.unannotate(relu, "test2") + verify_trace_roundtrip(sch=sch, mod=matmul_relu) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))