diff --git a/gallery/how_to/extend_tvm/bring_your_own_datatypes.py b/gallery/how_to/extend_tvm/bring_your_own_datatypes.py index 1cf556ddd0..0182456099 100644 --- a/gallery/how_to/extend_tvm/bring_your_own_datatypes.py +++ b/gallery/how_to/extend_tvm/bring_your_own_datatypes.py @@ -313,7 +313,7 @@ def convert_ndarray(dst_dtype, array): print(str(e).split("\n")[-1]) ###################################################################### -# When we attempt to run the model, we get a familiar error telling us that more funcions need to be registerd for myfloat. +# When we attempt to run the model, we get a familiar error telling us that more functions need to be registerd for myfloat. # # Because this is a neural network, many more operations are required. # Here, we register all the needed functions: diff --git a/include/tvm/auto_scheduler/cost_model.h b/include/tvm/auto_scheduler/cost_model.h index f7a27895a7..a52c6797b6 100755 --- a/include/tvm/auto_scheduler/cost_model.h +++ b/include/tvm/auto_scheduler/cost_model.h @@ -122,11 +122,11 @@ class RandomModel : public CostModel { * This class will call functions defined in the python */ class PythonBasedModelNode : public CostModelNode { public: - /*! \brief Pointer to the update funcion in python */ + /*! \brief Pointer to the update function in python */ PackedFunc update_func; - /*! \brief Pointer to the predict funcion in python */ + /*! \brief Pointer to the predict function in python */ PackedFunc predict_func; - /*! \brief Pointer to the predict funcion in python */ + /*! \brief Pointer to the predict function in python */ PackedFunc predict_stage_func; void Update(const Array& inputs, const Array& results) final; diff --git a/include/tvm/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h index 841b6b9530..20a93e280b 100755 --- a/include/tvm/auto_scheduler/measure.h +++ b/include/tvm/auto_scheduler/measure.h @@ -236,7 +236,7 @@ class MeasureCallback : public ObjectRef { * This class will call functions defined in the python */ class PythonBasedMeasureCallbackNode : public MeasureCallbackNode { public: - /*! \brief Pointer to the callback funcion in python */ + /*! \brief Pointer to the callback function in python */ PackedFunc callback_func; void Callback(const SearchPolicy& policy, const Array& inputs, diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h new file mode 100644 index 0000000000..dfc816bb00 --- /dev/null +++ b/include/tvm/meta_schedule/cost_model.h @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_COST_MODEL_H_ +#define TVM_META_SCHEDULE_COST_MODEL_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Cost model. */ +class CostModelNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~CostModelNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Load the cost model from given file location. + * \param file_location The file location. + * \return Whether cost model was loaded successfully. + */ + virtual bool Load(const String& file_location) = 0; + + /*! + * \brief Save the cost model to given file location. + * \param file_location The file location. + * \return Whether cost model was saved successfully. + */ + virtual bool Save(const String& file_location) = 0; + + /*! + * \brief Update the cost model given running results. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \param results The running results of the measure candidates. + */ + virtual void Update(const TuneContext& tune_context, const Array& candidates, + const Array& results) = 0; + + /*! + * \brief Predict the running results of given measure candidates. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \return The predicted running results. + */ + virtual std::vector Predict(const TuneContext& tune_context, + const Array& candidates) = 0; + + static constexpr const char* _type_key = "meta_schedule.CostModel"; + TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); +}; + +/*! \brief The cost model with customized methods on the python-side. */ +class PyCostModelNode : public CostModelNode { + public: + /*! + * \brief Load the cost model from given file location. + * \param file_location The file location. + * \return Whether cost model was loaded successfully. + */ + using FLoad = runtime::TypedPackedFunc; + /*! + * \brief Save the cost model to given file location. + * \param file_location The file location. + * \return Whether cost model was saved successfully. + */ + using FSave = runtime::TypedPackedFunc; + /*! + * \brief Update the cost model given running results. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \param results The running results of the measure candidates. + * \return Whether cost model was updated successfully. + */ + using FUpdate = runtime::TypedPackedFunc&, + const Array&)>; + /*! + * \brief Predict the running results of given measure candidates. + * \param tune_context The tuning context. + * \param candidates The measure candidates. + * \param p_addr The address to save the the estimated running results. + */ + using FPredict = runtime::TypedPackedFunc&, + void* p_addr)>; + /*! + * \brief Get the cost model as string with name. + * \return The string representation of the cost model. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `Load` function. */ + FLoad f_load; + /*! \brief The packed function to the `Save` function. */ + FSave f_save; + /*! \brief The packed function to the `Update` function. */ + FUpdate f_update; + /*! \brief The packed function to the `Predict` function. */ + FPredict f_predict; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_load` is not visited + // `f_save` is not visited + // `f_update` is not visited + // `f_predict` is not visited + // `f_as_string` is not visited + } + + bool Load(const String& file_location) { + ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!"; + return f_load(file_location); + } + + bool Save(const String& file_location) { + ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!"; + return f_save(file_location); + } + + void Update(const TuneContext& tune_context, const Array& candidates, + const Array& results) { + ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!"; + f_update(tune_context, candidates, results); + } + + std::vector Predict(const TuneContext& tune_context, + const Array& candidates) { + ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!"; + std::vector result(candidates.size(), 0.0); + f_predict(tune_context, candidates, result.data()); + return result; + } + + static constexpr const char* _type_key = "meta_schedule.PyCostModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyCostModelNode, CostModelNode); +}; + +/*! + * \brief Managed reference to CostModelNode + * \sa CostModelNode + */ +class CostModel : public runtime::ObjectRef { + public: + /*! + * \brief Create a feature extractor with customized methods on the python-side. + * \param f_load The packed function of `Load`. + * \param f_save The packed function of `Save`. + * \param f_update The packed function of `Update`. + * \param f_predict The packed function of `Predict`. + * \param f_as_string The packed function of `AsString`. + * \return The feature extractor created. + */ + TVM_DLL static CostModel PyCostModel(PyCostModelNode::FLoad f_load, // + PyCostModelNode::FSave f_save, // + PyCostModelNode::FUpdate f_update, // + PyCostModelNode::FPredict f_predict, // + PyCostModelNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CostModel, ObjectRef, CostModelNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_COST_MODEL_H_ diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h new file mode 100644 index 0000000000..30e2f0fe62 --- /dev/null +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ +#define TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Extractor for features from measure candidates for use in cost model. */ +class FeatureExtractorNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~FeatureExtractorNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Extract features from the given measure candidate. + * \param tune_context The tuning context for feature extraction. + * \param candidates The measure candidates to extract features from. + * \return The feature ndarray extracted. + */ + virtual Array ExtractFrom(const TuneContext& tune_context, + const Array& candidates) = 0; + + static constexpr const char* _type_key = "meta_schedule.FeatureExtractor"; + TVM_DECLARE_BASE_OBJECT_INFO(FeatureExtractorNode, Object); +}; + +/*! \brief The feature extractor with customized methods on the python-side. */ +class PyFeatureExtractorNode : public FeatureExtractorNode { + public: + /*! + * \brief Extract features from the given measure candidate. + * \param tune_context The tuning context for feature extraction. + * \param candidates The measure candidates to extract features from. + * \return The feature ndarray extracted. + */ + using FExtractFrom = runtime::TypedPackedFunc( + const TuneContext& tune_context, const Array& candidates)>; + /*! + * \brief Get the feature extractor as string with name. + * \return The string of the feature extractor. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `ExtractFrom` function. */ + FExtractFrom f_extract_from; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_extract_from` is not visited + // `f_as_string` is not visited + } + + Array ExtractFrom(const TuneContext& tune_context, + const Array& candidates) { + ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!"; + return f_extract_from(tune_context, candidates); + } + + static constexpr const char* _type_key = "meta_schedule.PyFeatureExtractor"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyFeatureExtractorNode, FeatureExtractorNode); +}; + +/*! + * \brief Managed reference to FeatureExtractorNode + * \sa FeatureExtractorNode + */ +class FeatureExtractor : public runtime::ObjectRef { + public: + /*! + * \brief Create a feature extractor with customized methods on the python-side. + * \param f_extract_from The packed function of `ExtractFrom`. + * \param f_as_string The packed function of `AsString`. + * \return The feature extractor created. + */ + TVM_DLL static FeatureExtractor PyFeatureExtractor( + PyFeatureExtractorNode::FExtractFrom f_extract_from, // + PyFeatureExtractorNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FeatureExtractor, ObjectRef, FeatureExtractorNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h index 9ee7039959..67152397dc 100644 --- a/include/tvm/meta_schedule/measure_callback.h +++ b/include/tvm/meta_schedule/measure_callback.h @@ -81,9 +81,9 @@ class PyMeasureCallbackNode : public MeasureCallbackNode { */ using FAsString = runtime::TypedPackedFunc; - /*! \brief The packed function to the `Apply` funcion. */ + /*! \brief The packed function to the `Apply` function. */ FApply f_apply; - /*! \brief The packed function to the `AsString` funcion. */ + /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; void VisitAttrs(tvm::AttrVisitor* v) { diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index 82f5b76834..e1c5ee3be9 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -36,8 +36,9 @@ class MutatorNode : public runtime::Object { void VisitAttrs(tvm::AttrVisitor* v) {} /*! - * \brief The function type of `InitializeWithTuneContext` method. + * \brief Initialize the design space generator with tuning context. * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. */ virtual void InitializeWithTuneContext(const TuneContext& context) = 0; @@ -72,11 +73,11 @@ class PyMutatorNode : public MutatorNode { */ using FAsString = runtime::TypedPackedFunc; - /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `Apply` funcion. */ + /*! \brief The packed function to the `Apply` function. */ FApply f_apply; - /*! \brief The packed function to the `AsString` funcion. */ + /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; void VisitAttrs(tvm::AttrVisitor* v) { @@ -110,6 +111,7 @@ class Mutator : public runtime::ObjectRef { * \brief Create a mutator with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. * \return The mutator created. */ TVM_DLL static Mutator PyMutator( diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 3a489555c4..1a134a35ab 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -38,8 +38,9 @@ class PostprocNode : public runtime::Object { void VisitAttrs(tvm::AttrVisitor* v) {} /*! - * \brief The function type of `InitializeWithTuneContext` method. + * \brief Initialize the design space generator with tuning context. * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. */ virtual void InitializeWithTuneContext(const TuneContext& context) = 0; @@ -112,6 +113,7 @@ class Postproc : public runtime::ObjectRef { * \brief Create a postprocessor with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. * \return The postprocessor created. */ TVM_DLL static Postproc PyPostproc( diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index b9e6d87774..9c0eaa8088 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -36,8 +36,9 @@ class ScheduleRuleNode : public runtime::Object { void VisitAttrs(tvm::AttrVisitor* v) {} /*! - * \brief The function type of `InitializeWithTuneContext` method. + * \brief Initialize the design space generator with tuning context. * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. */ virtual void InitializeWithTuneContext(const TuneContext& context) = 0; @@ -155,6 +156,7 @@ class ScheduleRule : public runtime::ObjectRef { * \brief Create a schedule rule with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. * \return The schedule rule created. */ TVM_DLL static ScheduleRule PyScheduleRule( diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index a0dfede820..c0f37c6037 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -102,7 +102,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { */ using FGenerateDesignSpace = runtime::TypedPackedFunc(const IRModule&)>; - /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ FInitializeWithTuneContext f_initialize_with_tune_context; /*! \brief The packed function to the `GenerateDesignSpace` function. */ FGenerateDesignSpace f_generate_design_space; diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index b8cfe7b381..f0c3a3c208 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -162,9 +162,9 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { */ using FNextTaskId = runtime::TypedPackedFunc; - /*! \brief The packed function to the `Tune` funcion. */ + /*! \brief The packed function to the `Tune` function. */ FTune f_tune; - /*! \brief The packed function to the `InitializeTask` funcion. */ + /*! \brief The packed function to the `InitializeTask` function. */ FInitializeTask f_initialize_task; /*! \brief The packed function to the `SetTaskStopped` function. */ FSetTaskStopped f_set_task_stopped; diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index c57355a391..37e8ffa9d8 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -25,4 +25,6 @@ from . import space_generator from . import search_strategy from . import integration +from . import feature_extractor +from . import cost_model from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/cost_model/__init__.py b/python/tvm/meta_schedule/cost_model/__init__.py new file mode 100644 index 0000000000..7267c5ae54 --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.cost_model package. +""" +from .cost_model import CostModel, PyCostModel diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py new file mode 100644 index 0000000000..da29c7db66 --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule CostModel.""" + +from typing import List +import ctypes + +import numpy as np + +from tvm._ffi import register_object +from tvm.runtime import Object + +from .. import _ffi_api +from ..runner import RunnerResult +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate +from ..utils import _get_hex_address, check_override + + +@register_object("meta_schedule.CostModel") +class CostModel(Object): + """Cost model.""" + + def load(self, file_location: str) -> bool: + """Load the cost model from given file location. + + Parameters + ---------- + file_location : str + The file location. + + Return + ------ + result : bool + Whether cost model was loaded successfully. + """ + return bool(_ffi_api.CostModelLoad(self, file_location)) # type: ignore # pylint: disable=no-member + + def save(self, file_location: str) -> bool: + """Save the cost model to given file location. + + Parameters + ---------- + file_location : str + The file location. + + Return + ------ + result : bool + Whether cost model was saved successfully. + """ + return bool(_ffi_api.CostModelSave(self, file_location)) # type: ignore # pylint: disable=no-member + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + results : List[RunnerResult] + The running results of the measure candidates. + """ + _ffi_api.CostModelUpdate(self, tune_context, candidates, results) # type: ignore # pylint: disable=no-member + + def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + + Return + ------ + result : bool + The predicted running results. + """ + n = len(candidates) + results = np.zeros(shape=(n,), dtype="float64") + _ffi_api.CostModelPredict( # type: ignore # pylint: disable=no-member + self, + tune_context, + candidates, + results.ctypes.data_as(ctypes.c_void_p), + ) + return results + + +@register_object("meta_schedule.PyCostModel") +class PyCostModel(CostModel): + """An abstract CostModel with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, CostModel) + def f_load(file_location: str) -> bool: + return self.load(file_location) + + @check_override(self.__class__, CostModel) + def f_save(file_location: str) -> bool: + return self.save(file_location) + + @check_override(self.__class__, CostModel) + def f_update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> bool: + self.update(tune_context, candidates, results) + + @check_override(self.__class__, CostModel) + def f_predict( + tune_context: TuneContext, candidates: List[MeasureCandidate], return_ptr + ) -> None: + n = len(candidates) + return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_double)) + array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) + array_wrapper[:] = self.predict(tune_context, candidates) + assert ( + array_wrapper.dtype == "float64" + ), "ValueError: Invalid data type returned from CostModel Predict!" + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.CostModelPyCostModel, # type: ignore # pylint: disable=no-member + f_load, + f_save, + f_update, + f_predict, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/feature_extractor/__init__.py b/python/tvm/meta_schedule/feature_extractor/__init__.py new file mode 100644 index 0000000000..49310decf3 --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.feature_extractor package. +Meta Schedule feature extractors that extracts features from +measure candidates for use in cost model. +""" +from .feature_extractor import FeatureExtractor, PyFeatureExtractor diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py new file mode 100644 index 0000000000..4126e0fb45 --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule FeatureExtractor.""" + +from typing import List + +from tvm._ffi import register_object +from tvm.runtime import Object, NDArray + +from .. import _ffi_api +from ..utils import _get_hex_address, check_override +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate + + +@register_object("meta_schedule.FeatureExtractor") +class FeatureExtractor(Object): + """Extractor for features from measure candidates for use in cost model.""" + + def extract_from( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + """Extract features from the given measure candidate. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for feature extraction. + candidates : List[MeasureCandidate] + The measure candidates to extract features from. + + Returns + ------- + features : List[NDArray] + The feature ndarray extracted. + """ + return _ffi_api.FeatureExtractorExtractFrom( # type: ignore # pylint: disable=no-member + self, tune_context, candidates + ) + + +@register_object("meta_schedule.PyFeatureExtractor") +class PyFeatureExtractor(FeatureExtractor): + """An abstract feature extractor with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, FeatureExtractor) + def f_extract_from( + tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + return self.extract_from(tune_context, candidates) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.FeatureExtractorPyFeatureExtractor, # type: ignore # pylint: disable=no-member + f_extract_from, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc new file mode 100644 index 0000000000..5cd32b097c --- /dev/null +++ b/src/meta_schedule/cost_model/cost_model.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +CostModel CostModel::PyCostModel(PyCostModelNode::FLoad f_load, // + PyCostModelNode::FSave f_save, // + PyCostModelNode::FUpdate f_update, // + PyCostModelNode::FPredict f_predict, // + PyCostModelNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_load = std::move(f_load); + n->f_save = std::move(f_save); + n->f_update = std::move(f_update); + n->f_predict = std::move(f_predict); + n->f_as_string = std::move(f_as_string); + return CostModel(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyCostModelNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyCostModel's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(CostModelNode); +TVM_REGISTER_NODE_TYPE(PyCostModelNode); + +TVM_REGISTER_GLOBAL("meta_schedule.CostModelLoad").set_body_method(&CostModelNode::Load); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelSave").set_body_method(&CostModelNode::Save); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelUpdate") + .set_body_method(&CostModelNode::Update); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict") + .set_body_typed([](CostModel model, // + const TuneContext& tune_context, // + Array candidates, // + void* p_addr) -> void { + std::vector result = model->Predict(tune_context, candidates); + std::copy(result.begin(), result.end(), static_cast(p_addr)); + }); +TVM_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel").set_body_typed(CostModel::PyCostModel); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/feature_extractor/feature_extractor.cc b/src/meta_schedule/feature_extractor/feature_extractor.cc new file mode 100644 index 0000000000..84d22493aa --- /dev/null +++ b/src/meta_schedule/feature_extractor/feature_extractor.cc @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +FeatureExtractor FeatureExtractor::PyFeatureExtractor( + PyFeatureExtractorNode::FExtractFrom f_extract_from, // + PyFeatureExtractorNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_extract_from = std::move(f_extract_from); + n->f_as_string = std::move(f_as_string); + return FeatureExtractor(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyFeatureExtractorNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyFeatureExtractor's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(FeatureExtractorNode); +TVM_REGISTER_NODE_TYPE(PyFeatureExtractorNode); + +TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorExtractFrom") + .set_body_method(&FeatureExtractorNode::ExtractFrom); +TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPyFeatureExtractor") + .set_body_typed(FeatureExtractor::PyFeatureExtractor); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 15d0c3f9f8..3a41062be2 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -23,7 +23,9 @@ #include #include #include +#include #include +#include #include #include #include diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index b5780975f4..44dc323186 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -79,7 +79,7 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& void InitContextFunctions(std::function fgetsymbol); /*! - * \brief Type alias for funcion to wrap a TVMBackendPackedCFunc. + * \brief Type alias for function to wrap a TVMBackendPackedCFunc. * \param The function address imported from a module. * \param mptr The module pointer node. * \return Packed function that wraps the invocation of the function at faddr. diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py new file mode 100644 index 0000000000..d1ce14e5a4 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List + +import re +import numpy as np + +import tvm +from tvm.script import tir as T +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.runner import RunnerResult +from tvm.tir.schedule.schedule import Schedule +from tvm.meta_schedule.cost_model import PyCostModel + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,disable=unused-argument + + +def test_meta_schedule_cost_model(): + class FancyCostModel(PyCostModel): + def load(self, file_location: str) -> bool: + return True + + def save(self, file_location: str) -> bool: + return True + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + pass + + def predict( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> np.ndarray: + return np.random.rand(10) + + model = FancyCostModel() + assert model.save("fancy_test_location") + assert model.load("fancy_test_location") + model.update(TuneContext(), [], []) + results = model.predict(TuneContext, [MeasureCandidate(Schedule(mod=Matmul), [])]) + assert results.shape == (10,) + + +def test_meta_schedule_cost_model_as_string(): + class NotSoFancyCostModel(PyCostModel): + def load(self, file_location: str) -> bool: + return True + + def save(self, file_location: str) -> bool: + return True + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + pass + + def predict( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> np.ndarray: + return np.random.rand(10) + + cost_model = NotSoFancyCostModel() + pattern = re.compile(r"NotSoFancyCostModel\(0x[a-f|0-9]*\)") + assert pattern.match(str(cost_model)) + + +if __name__ == "__main__": + test_meta_schedule_cost_model() + test_meta_schedule_cost_model_as_string() diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor.py b/tests/python/unittest/test_meta_schedule_feature_extractor.py new file mode 100644 index 0000000000..05b2bae40b --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_feature_extractor.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +import numpy as np +import re + +from tvm.runtime import NDArray +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.feature_extractor import PyFeatureExtractor + + +def test_meta_schedule_feature_extractor(): + class FancyFeatureExtractor(PyFeatureExtractor): + def extract_from( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + return [np.random.rand(4, 5)] + + extractor = FancyFeatureExtractor() + features = extractor.extract_from(TuneContext(), []) + assert len(features) == 1 + assert features[0].shape == (4, 5) + + +def test_meta_schedule_feature_extractor_as_string(): + class NotSoFancyFeatureExtractor(PyFeatureExtractor): + def extract_from( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + return None + + feature_extractor = NotSoFancyFeatureExtractor() + pattern = re.compile(r"NotSoFancyFeatureExtractor\(0x[a-f|0-9]*\)") + assert pattern.match(str(feature_extractor)) + + +if __name__ == "__main__": + test_meta_schedule_feature_extractor() + test_meta_schedule_feature_extractor_as_string()