diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h index 8e016f059a..adc1955a1e 100644 --- a/include/tvm/meta_schedule/cost_model.h +++ b/include/tvm/meta_schedule/cost_model.h @@ -39,17 +39,15 @@ class CostModelNode : public runtime::Object { /*! * \brief Load the cost model from given file location. - * \param file_location The file location. - * \return Whether cost model was loaded successfully. + * \param path The file path. */ - virtual bool Load(const String& file_location) = 0; + virtual void Load(const String& path) = 0; /*! * \brief Save the cost model to given file location. - * \param file_location The file location. - * \return Whether cost model was saved successfully. + * \param path The file path. */ - virtual bool Save(const String& file_location) = 0; + virtual void Save(const String& path) = 0; /*! * \brief Update the cost model given running results. @@ -78,16 +76,14 @@ class PyCostModelNode : public CostModelNode { public: /*! * \brief Load the cost model from given file location. - * \param file_location The file location. - * \return Whether cost model was loaded successfully. + * \param path The file path. */ - using FLoad = runtime::TypedPackedFunc; + using FLoad = runtime::TypedPackedFunc; /*! * \brief Save the cost model to given file location. - * \param file_location The file location. - * \return Whether cost model was saved successfully. + * \param path The file path. */ - using FSave = runtime::TypedPackedFunc; + using FSave = runtime::TypedPackedFunc; /*! * \brief Update the cost model given running results. * \param tune_context The tuning context. @@ -130,16 +126,15 @@ class PyCostModelNode : public CostModelNode { // `f_as_string` is not visited } - bool Load(const String& file_location) { + void Load(const String& path) { ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!"; - return f_load(file_location); + f_load(path); } - bool Save(const String& file_location) { + void Save(const String& path) { ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!"; - return f_save(file_location); + f_save(path); } - void Update(const TuneContext& tune_context, const Array& candidates, const Array& results) { ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!"; diff --git a/python/tvm/meta_schedule/cost_model/__init__.py b/python/tvm/meta_schedule/cost_model/__init__.py index 7267c5ae54..8fc6f04ac9 100644 --- a/python/tvm/meta_schedule/cost_model/__init__.py +++ b/python/tvm/meta_schedule/cost_model/__init__.py @@ -18,3 +18,5 @@ The tvm.meta_schedule.cost_model package. """ from .cost_model import CostModel, PyCostModel +from .random_model import RandomModel +from .xgb_model import XGBModel diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index da29c7db66..0cbba42a31 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -35,35 +35,25 @@ class CostModel(Object): """Cost model.""" - def load(self, file_location: str) -> bool: + def load(self, path: str) -> None: """Load the cost model from given file location. Parameters ---------- - file_location : str - The file location. - - Return - ------ - result : bool - Whether cost model was loaded successfully. + path : str + The file path. """ - return bool(_ffi_api.CostModelLoad(self, file_location)) # type: ignore # pylint: disable=no-member + _ffi_api.CostModelLoad(self, path) # type: ignore # pylint: disable=no-member - def save(self, file_location: str) -> bool: + def save(self, path: str) -> None: """Save the cost model to given file location. Parameters ---------- - file_location : str - The file location. - - Return - ------ - result : bool - Whether cost model was saved successfully. + path : str + The file path. """ - return bool(_ffi_api.CostModelSave(self, file_location)) # type: ignore # pylint: disable=no-member + _ffi_api.CostModelSave(self, path) # type: ignore # pylint: disable=no-member def update( self, @@ -96,7 +86,7 @@ def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) Return ------ - result : bool + result : np.ndarray The predicted running results. """ n = len(candidates) @@ -118,12 +108,12 @@ def __init__(self): """Constructor.""" @check_override(self.__class__, CostModel) - def f_load(file_location: str) -> bool: - return self.load(file_location) + def f_load(path: str) -> None: + self.load(path) @check_override(self.__class__, CostModel) - def f_save(file_location: str) -> bool: - return self.save(file_location) + def f_save(path: str) -> None: + self.save(path) @check_override(self.__class__, CostModel) def f_update( @@ -131,7 +121,7 @@ def f_update( tune_context: TuneContext, candidates: List[MeasureCandidate], results: List[RunnerResult], - ) -> bool: + ) -> None: self.update(tune_context, candidates, results) @check_override(self.__class__, CostModel) diff --git a/python/tvm/meta_schedule/cost_model/metric.py b/python/tvm/meta_schedule/cost_model/metric.py new file mode 100644 index 0000000000..7eb6da6f07 --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/metric.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Cost model metrics for meta schedule""" +from typing import List +import numpy as np + + +def max_curve(trial_scores: np.ndarray) -> List[float]: + """f(n) = max([s[i] fo i < n]) + + Parameters + ---------- + trial_scores : List[float] + the score of i-th trial + + Returns + ------- + curve : List[float] + function values + """ + ret = np.empty(len(trial_scores)) + keep = -1e9 + for i, score in enumerate(trial_scores): + keep = max(keep, score) + ret[i] = keep + return ret diff --git a/python/tvm/meta_schedule/cost_model/random_model.py b/python/tvm/meta_schedule/cost_model/random_model.py new file mode 100644 index 0000000000..56c65f64af --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/random_model.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Random cost model +""" +from typing import List, Union, Tuple, Optional + +import numpy as np + +from ..runner import RunnerResult +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate +from ..cost_model import PyCostModel + + +class RandomModel(PyCostModel): + """Random cost model + + Parameters + ---------- + random_state : Union[Tuple[str, np.ndarray, int, int, float], dict] + The random state of the random number generator. + path : Optional[str] + The path of the random cost model. + max_range : Optional[int] + The maximum range of random results, [0, max_range]. + + Reference + --------- + https://numpy.org/doc/stable/reference/random/generated/numpy.random.get_state.html + """ + + random_state: Union[Tuple[str, np.ndarray, int, int, float], dict] + path: Optional[str] + + def __init__( + self, + *, + seed: Optional[int] = None, + path: Optional[str] = None, + max_range: Optional[int] = 100, + ): + super().__init__() + if path is not None: + self.load(path) + else: + np.random.seed(seed) + self.random_state = np.random.get_state() + self.max_range = max_range + + def load(self, path: str) -> None: + """Load the cost model from given file location. + + Parameters + ---------- + path : str + The file path. + """ + self.random_state = tuple(np.load(path, allow_pickle=True)) + + def save(self, path: str) -> None: + """Save the cost model to given file location. + + Parameters + ---------- + path : str + The file path. + """ + np.save(path, np.array(self.random_state, dtype=object), allow_pickle=True) + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + results : List[RunnerResult] + The running results of the measure candidates. + """ + + def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + + Return + ------ + result : np.ndarray + The predicted running results. + """ + np.random.set_state(self.random_state) + # todo(@zxybazh): Use numpy's RandState object: + # https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html#numpy.random.RandomState + result = np.random.rand(len(candidates)) * self.max_range + self.random_state = np.random.get_state() + return result diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py new file mode 100644 index 0000000000..441cf1cbbc --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -0,0 +1,665 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +XGBoost-based cost model +""" +from typing import NamedTuple, Optional, Tuple, Callable, List, TYPE_CHECKING + +import os +import logging +import tempfile +from itertools import chain as itertools_chain +import numpy as np + +from ..runner import RunnerResult +from ..search_strategy import MeasureCandidate +from ..feature_extractor import FeatureExtractor +from ..cost_model import PyCostModel +from ..utils import cpu_count +from .metric import max_curve +from ...contrib.tar import tar, untar + +if TYPE_CHECKING: + from ..tune_context import TuneContext + import xgboost as xgb + + +logger = logging.getLogger(__name__) + + +def make_metric_sorter(focused_metric): + """ Make sure the focused metric is the first one. """ + + def metric_name_for_sort(name): + if focused_metric == name: + return "!" + name + return name + + def sort_key(key): + key, _ = key + return metric_name_for_sort(key) + + return sort_key + + +class PackSum: + """The pack-sum format + + Parameters + ---------- + dmatrix : xgb.DMatrix + A float64 array of shape [n, m], + where `n` is the packed number of blocks, + and `m` is the length of feature vector on each block + ids : np.ndarray + An int64 array of shape [n] containing nonnegative integers, + indicating which the index of a sample that a block belongs to + """ + + dmatrix: "xgb.DMatrix" # type: ignore # pylint: disable=invalid-name + ids: np.ndarray + + def __init__( + self, + xs: List[np.ndarray], + ys: Optional[List[float]], + ): + """Create PackSum format given a batch of samples + + Parameters + ---------- + xs : List[np.ndarray] + A batch of input samples + ys : Optional[List[float]] + A batch of labels. None means no lables available. + """ + import xgboost as xgb # type: ignore # pylint: disable=import-outside-toplevel + + repeats = [x.shape[0] for x in xs] + xs = np.concatenate(xs, axis=0) + self.ids = np.concatenate([[i] * repeat for i, repeat in enumerate(repeats)], axis=0) + if ys is None: + self.dmatrix = xgb.DMatrix(data=xs, label=None) + else: + ys = np.concatenate([[y] * repeat for y, repeat in zip(ys, repeats)], axis=0) + self.dmatrix = xgb.DMatrix(data=xs, label=ys) + self.dmatrix.set_weight(ys) + + def predict_with_score(self, pred: np.ndarray) -> np.ndarray: + """Predict the labels given the block level prediction scores. + + Parameters + ---------- + pred : np.ndarray + The block level predictions + + Returns + ------- + result : np.ndarray + The predictions for each candidate. + """ + return np.bincount(self.ids, weights=pred) + + def obj_square_error(self, ys_pred: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Implement square error loss on pack-sum format as + a custom objective function for xgboost. + + Parameters + ---------- + ys_pred: np.ndarray + The predictions + + Returns + ------- + gradient: np.ndarray + The gradient according to the xgboost format + hessian: np.ndarray + The hessian according to the xgboost format + """ + # Making prediction + ys_pred = self.predict_with_score(ys_pred) + # Propagate prediction to each block + ys_pred = ys_pred[self.ids] + # The gradient and hessian + ys = self.dmatrix.get_label() # type: ignore # pylint: disable=invalid-name + gradient = ys_pred - ys + hessian = np.ones_like(gradient) + return gradient * ys, hessian * ys + + def rmse(self, ys_pred: np.ndarray) -> Tuple[str, float]: + """Evaluate RMSE (rooted mean square error) in the pack-sum format + + Parameters + ---------- + ys_pred: np.ndarray + The raw predictions + + Returns + ------- + name: str + The name of the metric + score: float + The score of the metric + """ + # Making prediction + ys_pred = self.predict_with_score(ys_pred) + # Propagate prediction to each block + ys_pred = ys_pred[self.ids] + # The RMSE + ys = self.dmatrix.get_label() # type: ignore # pylint: disable=invalid-name + square_error = np.square(ys_pred - ys) + rmse = np.sqrt(square_error.mean()) + return "p-rmse", rmse + + def average_peak_score( + self, + ys_pred: np.ndarray, + n: int, + ) -> Tuple[str, float]: + """Evaluate average-peak-score@N in the pack-sum format + + Parameters + ---------- + ys_pred: np.ndarray + The raw prediction + n : int + The N in average-peak-score@N + + Returns + ------- + name: str + The name of the metric + score: float + The score of the metric + """ + ys = self.dmatrix.get_label() # type: ignore # pylint: disable=invalid-name + ys = self.predict_with_score(ys) # type: ignore # pylint: disable=invalid-name + ys = ys / np.unique(self.ids, return_counts=True)[1] # type: ignore # pylint: disable=invalid-name + ys_pred = self.predict_with_score(ys_pred) + trials = np.argsort(ys_pred)[::-1][:n] + trial_scores = ys[trials] + curve = max_curve(trial_scores) / np.max(ys) + score = np.mean(curve) + return f"a-peak@{n}", score + + +class XGBConfig(NamedTuple): + """XGBoost model configuration + + Parameters + ---------- + max_depth : int + The maximum depth. + gamma : float + The gamma. + min_child_weight : float + The minimum child weight. + eta : float + The eta, learning rate. + seed : int + The random seed. + nthread : Optional[int], + The number of threads to use. + Default is None, which means to use physical number of cores. + """ + + def to_dict(self): + xgb_params = { + "max_depth": self.max_depth, + "gamma": self.gamma, + "min_child_weight": self.min_child_weight, + "eta": self.eta, + "seed": self.seed, + "nthread": self.nthread, + } + return xgb_params + + max_depth: int = 10 + gamma: float = 0.001 + min_child_weight: float = 0 + eta: float = 0.2 + seed: int = 43 + nthread: Optional[int] = None + + +class XGBModel(PyCostModel): + """XGBoost model + + Parameters + ---------- + extractor : FeatureExtractor + The feature extractor for the model. + config : XGBConfig + The XGBoost model config. + num_warmup_samples : int + The number of samples that are used for warmup, i.e., the first few samples are predicted + with random results. + early_stopping_rounds : int + The number of rounds for early stopping. + verbose_eval : int + The verbose level when doing evaluation. + average_peak_n : int + The number to calculate average peak score. + """ + + # feature extractor + extractor: FeatureExtractor + # xgboost model config + config: XGBConfig + # behavior of randomness + num_warmup_samples: int + # evaluation + early_stopping_rounds: int + verbose_eval: int + average_peak_n: int + # states + cached_features: List[np.ndarray] + cached_mean_costs: np.ndarray + cached_normalizer: Optional[float] + booster: Optional["xgb.Booster"] + + def __init__( + self, + *, + # feature extractor + extractor: FeatureExtractor, + # xgboost model config + config: XGBConfig = XGBConfig(), + # load from disk + path: Optional[str] = None, + # behavior of randomness + num_warmup_samples: int = 100, + # evaluation + early_stopping_rounds: int = 50, + verbose_eval: int = 25, + average_peak_n: int = 32, + ): + super().__init__() + # feature extractor + self.extractor = extractor + # model-related + if config.nthread is None: + # use physical core number + config._replace(nthread=cpu_count(logical=False)) + self.config = config + # serialization-related + if path is not None: + self.load(path) + # behavior of randomness + self.num_warmup_samples = num_warmup_samples + # evaluation + self.early_stopping_rounds = early_stopping_rounds + self.verbose_eval = verbose_eval + self.average_peak_n = average_peak_n + # states + self.cached_features = [] + self.cached_mean_costs = np.empty((0,), dtype="float64") + self.cached_normalizer = None + self.booster = None + + def load(self, path: str) -> None: + """Load the cost model from given file location. + + Parameters + ---------- + path : str + The file path. + + Note + ---- + Since XGBoost model trains from scratch, each time we can only load the model without the + previous cached features / results so any call of update won't use previous training data. + """ + with tempfile.TemporaryDirectory() as tmpdirname: + untar(path, tmpdirname) + self.booster.load_model(os.path.join(tmpdirname, "model.bin")) + self.cached_features = list( + np.load(os.path.join(tmpdirname, "cached_features.npy"), allow_pickle=True) + ) + self.cached_mean_costs = np.load( + os.path.join(tmpdirname, "cached_mean_costs.npy"), allow_pickle=True + ) + self.cached_normalizer = np.min(self.cached_mean_costs) + if self.cached_normalizer <= 0: + raise ValueError("The minimum mean cost must be greater than 0!") + + def save(self, path: str) -> None: + """Save the cost model to given file location. + + Parameters + ---------- + path : str + The file path. + + Note + ---- + Since XGBoost model trains from scratch, each time we can only save the model without the + previous cached features / results so any call of update won't use previous training data. + """ + import xgboost as xgb # pylint: disable=import-outside-toplevel + + if self.booster is None: + # save all the paramaters + self.booster = xgb.Booster(self.config.to_dict()) + with tempfile.TemporaryDirectory() as tmpdirname: + self.booster.save_model(os.path.join(tmpdirname, "model.bin")) + np.save( + os.path.join(tmpdirname, "cached_features.npy"), + np.array(self.cached_features, dtype=object), + ) + np.save(os.path.join(tmpdirname, "cached_mean_costs.npy"), self.cached_mean_costs) + tar( + path, + [ + os.path.join(tmpdirname, "model.bin"), + os.path.join(tmpdirname, "cached_features.npy"), + os.path.join(tmpdirname, "cached_mean_costs.npy"), + ], + ) + + def update( + self, + tune_context: "TuneContext", + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + results : List[RunnerResult] + The running results of the measure candidates. + """ + assert len(candidates) == len(results) + if len(candidates) == 0: + return + # extract feature and do validation + new_features = [ + x.numpy().astype("float32") + for x in self.extractor.extract_from(tune_context, candidates) + ] + new_mean_costs = [float(sum(x.run_secs) / len(x.run_secs)) for x in results] + if self.booster is not None and self.cached_normalizer is not None: + logger.debug( + "XGB validation: %s", + "\t".join( + f"{key}: {score:.6f}" + for key, score in self._validate( + xs=new_features, + ys=new_mean_costs, + ) + ), + ) + # use together with previous features + self.cached_features.extend(new_features) + self.cached_mean_costs = np.append(self.cached_mean_costs, new_mean_costs) + self.cached_normalizer = np.min(self.cached_mean_costs) + if self.cached_normalizer <= 0: + raise ValueError("The minimum mean cost must be greater than 0!") + # train xgb model + self._train( + xs=self.cached_features, + ys=self.cached_mean_costs, + ) + + def predict( + self, tune_context: "TuneContext", candidates: List[MeasureCandidate] + ) -> np.ndarray: + """Predict the normalized score using the cost model. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + + Return + ------ + result : np.ndarray + The predicted normalized score. + """ + n_measured = len(self.cached_features) + if self.booster is not None and n_measured >= self.num_warmup_samples: + features = self.extractor.extract_from(tune_context, candidates) + ret = self._predict(xs=[x.numpy().astype("float32") for x in features]) + else: + ret = np.random.uniform( + low=0, + high=1, + size=(len(candidates),), + ) + return ret.astype("float64") + + def _train( # type: ignore # pylint: disable=invalid-name + self, + xs: List[np.ndarray], + ys: List[float], + ) -> None: + import xgboost as xgb # type: ignore # pylint: disable=import-outside-toplevel + + self.d_train = PackSum( + xs=xs, + ys=self.cached_normalizer / ys, + ) + + def obj(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument + return self.d_train.obj_square_error(ys_pred) + + def rmse(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument + return self.d_train.rmse(ys_pred) + + def average_peak_score( + ys_pred: np.ndarray, d_train: "xgb.DMatrix" # type: ignore # pylint: disable = unused-argument + ): + return self.d_train.average_peak_score(ys_pred, self.average_peak_n) + + self.booster = xgb.train( + self.config.to_dict(), + self.d_train.dmatrix, + num_boost_round=10000, + obj=obj, + callbacks=[ + custom_callback( + early_stopping_rounds=self.early_stopping_rounds, + verbose_eval=self.verbose_eval, + fevals=[ + rmse, + average_peak_score, + ], + evals=[(self.d_train.dmatrix, "tr")], + ) + ], + ) + + del self.d_train + # todo(zxybazh): measure callback to save the model + + def _predict( # type: ignore # pylint: disable=invalid-name + self, + xs: List[np.ndarray], + ) -> np.ndarray: + d_test = PackSum(xs=xs, ys=None) + pred = self.booster.predict(d_test.dmatrix) + ret = d_test.predict_with_score(pred) + return ret + + def _validate( # type: ignore # pylint: disable=invalid-name + self, + xs: List[np.ndarray], + ys: List[float], + ) -> List[Tuple[str, float]]: + """Evaluate the score of inputs. + + Parameters + ---------- + xs : List[np.ndarray] + A batch of input samples + ys : List[float] + A batch of labels + + Returns + ------- + scores: np.ndarray + The predicted result for all inputs. + """ + if self.booster is None or self.cached_normalizer is None: + return [] + + d_valid = PackSum( + xs=xs, + ys=self.cached_normalizer / ys, + ) + + def average_peak_score(ys_pred: np.ndarray): + return d_valid.average_peak_score(ys_pred, n=self.average_peak_n) + + ys_pred = self.booster.predict(d_valid.dmatrix) + eval_result: List[Tuple[str, float]] = [ + feval(ys_pred) + for feval in ( + average_peak_score, + d_valid.rmse, + ) + ] + eval_result.sort(key=make_metric_sorter("p-rmse")) + return eval_result + + +def custom_callback( + early_stopping_rounds: int, + verbose_eval: int, + fevals: List[Callable], + evals: List[Tuple["xgb.DMatrix", str]], + focused_metric: str = "tr-p-rmse", +): + """Callback function for xgboost to support multiple custom evaluation functions""" + sort_key = make_metric_sorter(focused_metric=focused_metric) + + state = {} + + def init(env: "xgb.core.CallbackEnv"): + """Internal function""" + booster: "xgb.Booster" = env.model + + state["best_iteration"] = 0 + state["best_score"] = float("inf") + if booster is None: + assert env.cvfolds is not None + return + if booster.attr("best_score") is not None: + state["best_score"] = float(booster.attr("best_score")) + state["best_iteration"] = int(booster.attr("best_iteration")) + state["best_msg"] = booster.attr("best_msg") + else: + booster.set_attr(best_iteration=str(state["best_iteration"])) + booster.set_attr(best_score=str(state["best_score"])) + + def callback(env: "xgb.core.CallbackEnv"): + # pylint:disable = import-outside-toplevel + import xgboost as xgb + from xgboost.callback import _fmt_metric + from xgboost.core import EarlyStopException + + try: + from xgboost.training import aggcv + except ImportError: + from xgboost.callback import _aggcv as aggcv + # pylint:enable = import-outside-toplevel + + if not state: + init(env) + booster: xgb.Booster = env.model + iteration: int = env.iteration + cvfolds: List[xgb.training.CVPack] = env.cvfolds + ##### Evaluation ##### + # `eval_result` is a list of (key, score) + eval_result: List[Tuple[str, float]] = [] + if cvfolds is None: + eval_result = itertools_chain.from_iterable( + [ + (key, float(value)) + for key, value in map( + lambda x: x.split(":"), + booster.eval_set( + evals=evals, + iteration=iteration, + feval=feval, + ).split()[1:], + ) + ] + for feval in fevals + ) + else: + eval_result = itertools_chain.from_iterable( + [ + (key, score) + for key, score, _std in aggcv( + fold.eval( + iteration=iteration, + feval=feval, + ) + for fold in cvfolds + ) + ] + for feval in fevals + ) + eval_result = list(eval_result) + eval_result.sort(key=sort_key) + + ##### Print eval result ##### + if verbose_eval and iteration % verbose_eval == 0: + info = [] + for key, score in eval_result: + if "null" not in key: + info.append(f"{key}: {score:.6f}") + logger.debug("XGB iter %3d: %s", iteration, "\t".join(info)) + + ##### Choose score and do early stopping ##### + score = None + for key, _score in eval_result: + if key == focused_metric: + score = _score + break + assert score is not None + + best_score = state["best_score"] + best_iteration = state["best_iteration"] + if score < best_score: + tab = "\t" # to work with f-string + msg = f"[{env.iteration}] {tab.join([_fmt_metric(x) for x in eval_result])}" + state["best_msg"] = msg + state["best_score"] = score + state["best_iteration"] = env.iteration + # save the property to attributes, so they will occur in checkpoint. + if env.model is not None: + env.model.set_attr( + best_score=str(state["best_score"]), + best_iteration=str(state["best_iteration"]), + best_msg=state["best_msg"], + ) + elif env.iteration - best_iteration >= early_stopping_rounds: + best_msg = state["best_msg"] + if verbose_eval and env.rank == 0: + logger.debug("XGB stopped. Best iteration: %s ", best_msg) + raise EarlyStopException(best_iteration) + + return callback diff --git a/python/tvm/meta_schedule/feature_extractor/__init__.py b/python/tvm/meta_schedule/feature_extractor/__init__.py index ffe7655a51..83ac7426cc 100644 --- a/python/tvm/meta_schedule/feature_extractor/__init__.py +++ b/python/tvm/meta_schedule/feature_extractor/__init__.py @@ -21,3 +21,4 @@ """ from .feature_extractor import FeatureExtractor, PyFeatureExtractor from .per_store_feature import PerStoreFeature +from .random_feature_extractor import RandomFeatureExtractor diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py index 4126e0fb45..bd7656e5be 100644 --- a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py +++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule FeatureExtractor.""" - from typing import List from tvm._ffi import register_object -from tvm.runtime import Object, NDArray +from tvm.runtime import Object +from tvm.runtime.ndarray import NDArray from .. import _ffi_api from ..utils import _get_hex_address, check_override @@ -46,11 +46,12 @@ def extract_from( Returns ------- features : List[NDArray] - The feature ndarray extracted. + The feature numpy ndarray extracted. """ - return _ffi_api.FeatureExtractorExtractFrom( # type: ignore # pylint: disable=no-member + result = _ffi_api.FeatureExtractorExtractFrom( # type: ignore # pylint: disable=no-member self, tune_context, candidates ) + return result @register_object("meta_schedule.PyFeatureExtractor") @@ -64,7 +65,8 @@ def __init__(self): def f_extract_from( tune_context: TuneContext, candidates: List[MeasureCandidate] ) -> List[NDArray]: - return self.extract_from(tune_context, candidates) + features = self.extract_from(tune_context, candidates) + return features def f_as_string() -> str: return str(self) diff --git a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py new file mode 100644 index 0000000000..f9f2f287fd --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Random Feature Extractor.""" +from typing import List, Union, Tuple + +import numpy as np +from tvm.runtime.ndarray import NDArray, array + +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate +from ..feature_extractor import PyFeatureExtractor + + +class RandomFeatureExtractor(PyFeatureExtractor): + """Random Feature Extractor + + Parameters + ---------- + feature_size : int + The size of each block's feature vector. + max_block_num : int + The maximum number of blocks in each schedule. + random_state : Union[Tuple[str, np.ndarray, int, int, float], dict] + The current random state of the f + """ + + feature_size: int + max_block_num: int + random_state: Union[Tuple[str, np.ndarray, int, int, float], dict] + + def __init__(self, *, feature_size: int = 30, max_block_num: int = 5, seed=0): + super().__init__() + assert max_block_num >= 1, "Max block number must be greater or equal to one!" + self.max_block_num = max_block_num + self.feature_size = feature_size + np.random.seed(seed) + self.random_state = np.random.get_state() + + def extract_from( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> List[NDArray]: + np.random.set_state(self.random_state) + result = [ + np.random.rand(np.random.randint(1, self.max_block_num + 1), self.feature_size) + for candidate in candidates + ] + self.random_state = np.random.get_state() + return [array(x) for x in result] diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index e92bbbefca..25a03aaf87 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -22,7 +22,7 @@ from tvm._ffi import register_object from tvm.runtime import Object -from tvm.tir.schedule import Schedule, Trace +from tvm.tir.schedule import Schedule from .. import _ffi_api from ..arg_info import ArgInfo diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 728733ff13..64a2965479 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -15,13 +15,14 @@ # specific language governing permissions and limitations # under the License. """Utilities for meta schedule""" +from typing import Any, Callable, List, Optional, Union + import ctypes import json import os import shutil -from typing import Any, Callable, List, Optional, Union - import psutil # type: ignore + import tvm from tvm._ffi import get_global_func, register_func from tvm.error import TVMError @@ -32,7 +33,7 @@ @register_func("meta_schedule.cpu_count") -def cpu_count(logical: bool = True) -> int: +def _cpu_count_impl(logical: bool = True) -> int: """Return the number of logical or physical CPUs in the system Parameters @@ -60,6 +61,22 @@ def cpu_count(logical: bool = True) -> int: return psutil.cpu_count(logical=logical) or 1 +def cpu_count(logical: bool = True) -> int: + """Return the number of logical or physical CPUs in the system + + Parameters + ---------- + logical : bool = True + If True, return the number of logical CPUs, otherwise return the number of physical CPUs + + Returns + ------- + cpu_count : int + The number of logical or physical CPUs in the system + """ + return _cpu_count_impl(logical) + + def get_global_func_with_default_on_worker( name: Union[None, str, Callable], default: Callable, diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index e061d1af40..c5871a53eb 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -365,7 +365,7 @@ def sample_compute_location( decision: Optional[int] = None, ) -> LoopRV: """Sample a compute-at location on a BlockRV so that its producer can compute at that loop - + Parameters ---------- block : BlockRV diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index 34c2684e02..4eca068e17 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -104,17 +104,17 @@ bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) { } else if (ann.first == attr::meta_schedule_vectorize) { found = true; if (const auto* imm = ann.second.as()) { - parsed->max_vectorize_extent = imm->value;; + parsed->max_vectorize_extent = imm->value; } } else if (ann.first == attr::meta_schedule_unroll_explicit) { found = true; if (const auto* imm = ann.second.as()) { - parsed->unroll_explicit = imm->value;; + parsed->unroll_explicit = imm->value; } } else if (ann.first == attr::meta_schedule_unroll_implicit) { found = true; if (const auto* imm = ann.second.as()) { - parsed->unroll_implicit = imm->value;; + parsed->unroll_implicit = imm->value; } } } diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py index d1ce14e5a4..cdc72d30b6 100644 --- a/tests/python/unittest/test_meta_schedule_cost_model.py +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -16,17 +16,22 @@ # under the License. from typing import List +import tempfile +import os import re +import sys +import shutil +import pytest import numpy as np import tvm from tvm.script import tir as T -from tvm.meta_schedule import TuneContext +from tvm.tir.schedule.schedule import Schedule from tvm.meta_schedule.search_strategy import MeasureCandidate from tvm.meta_schedule.runner import RunnerResult -from tvm.tir.schedule.schedule import Schedule -from tvm.meta_schedule.cost_model import PyCostModel - +from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor +from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel +from tvm.meta_schedule.tune_context import TuneContext # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @tvm.script.ir_module @@ -50,11 +55,11 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s def test_meta_schedule_cost_model(): class FancyCostModel(PyCostModel): - def load(self, file_location: str) -> bool: - return True + def load(self, path: str) -> None: + pass - def save(self, file_location: str) -> bool: - return True + def save(self, path: str) -> None: + pass def update( self, @@ -70,8 +75,8 @@ def predict( return np.random.rand(10) model = FancyCostModel() - assert model.save("fancy_test_location") - assert model.load("fancy_test_location") + model.save("fancy_test_location") + model.load("fancy_test_location") model.update(TuneContext(), [], []) results = model.predict(TuneContext, [MeasureCandidate(Schedule(mod=Matmul), [])]) assert results.shape == (10,) @@ -79,11 +84,11 @@ def predict( def test_meta_schedule_cost_model_as_string(): class NotSoFancyCostModel(PyCostModel): - def load(self, file_location: str) -> bool: - return True + def load(self, path: str) -> None: + pass - def save(self, file_location: str) -> bool: - return True + def save(self, path: str) -> None: + pass def update( self, @@ -103,6 +108,113 @@ def predict( assert pattern.match(str(cost_model)) +def test_meta_schedule_random_model(): + model = RandomModel() + model.update(TuneContext(), [], []) + res = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(10)]) + assert len(res) == 10 + assert min(res) >= 0 and max(res) <= model.max_range + + +def test_meta_schedule_random_model_reseed(): + model = RandomModel(seed=100) + res = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(20)]) + new_model = RandomModel(seed=100) + new_res = new_model.predict( + TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(20)] + ) + assert (res == new_res).all() + + +def test_meta_schedule_random_model_reload(): + model = RandomModel(seed=25973) + model.predict( + TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(30)] + ) # change state + path = os.path.join(tempfile.mkdtemp(), "test_output_meta_schedule_random_model.npy") + model.save(path) + res1 = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(70)]) + model.load(path) + res2 = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(70)]) + shutil.rmtree(os.path.dirname(path)) + assert (res1 == res2).all() + + +def _dummy_candidate(): + return MeasureCandidate(Schedule(Matmul), []) + + +def _dummy_result(num_samples: int = 4, max_run_sec: int = 10): + return RunnerResult(list(np.random.rand(num_samples) * max_run_sec + 1e-6), None) + + +def test_meta_schedule_xgb_model(): + extractor = RandomFeatureExtractor() + model = XGBModel(extractor=extractor, num_warmup_samples=2) + update_sample_count = 10 + predict_sample_count = 100 + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + + +def test_meta_schedule_xgb_model_reload(): + extractor = RandomFeatureExtractor() + model = XGBModel(extractor=extractor, num_warmup_samples=10) + update_sample_count = 20 + predict_sample_count = 30 + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + random_state = model.extractor.random_state # save feature extractor's random state + path = os.path.join(tempfile.mkdtemp(), "test_output_meta_schedule_xgb_model.bin") + cached = (model.cached_features.copy(), model.cached_mean_costs.copy()) + model.save(path) + res1 = model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + model.extractor.random_state = random_state # load feature extractor's random state + model.cached_features = None + model.cached_mean_costs = None + model.load(path) + new_cached = (model.cached_features.copy(), model.cached_mean_costs.copy()) + res2 = model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + shutil.rmtree(os.path.dirname(path)) + assert (res1 == res2).all() + # cached feature does not change + assert len(cached[0]) == len(new_cached[0]) + for i in range(len(cached[0])): + assert (cached[0][i] == new_cached[0][i]).all() + # cached meaen cost does not change + assert (cached[1] == new_cached[1]).all() + + +def test_meta_schedule_xgb_model_reupdate(): + extractor = RandomFeatureExtractor() + model = XGBModel(extractor=extractor, num_warmup_samples=2) + update_sample_count = 60 + predict_sample_count = 100 + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + + if __name__ == "__main__": - test_meta_schedule_cost_model() - test_meta_schedule_cost_model_as_string() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor.py b/tests/python/unittest/test_meta_schedule_feature_extractor.py index 05b2bae40b..4f068d7a83 100644 --- a/tests/python/unittest/test_meta_schedule_feature_extractor.py +++ b/tests/python/unittest/test_meta_schedule_feature_extractor.py @@ -17,10 +17,9 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring from typing import List -import numpy as np import re +import numpy as np -from tvm.runtime import NDArray from tvm.meta_schedule import TuneContext from tvm.meta_schedule.search_strategy import MeasureCandidate from tvm.meta_schedule.feature_extractor import PyFeatureExtractor @@ -29,8 +28,10 @@ def test_meta_schedule_feature_extractor(): class FancyFeatureExtractor(PyFeatureExtractor): def extract_from( - self, tune_context: TuneContext, candidates: List[MeasureCandidate] - ) -> List[NDArray]: + self, + tune_context: TuneContext, # pylint: disable = unused-argument + candidates: List[MeasureCandidate], # pylint: disable = unused-argument + ) -> List[np.ndarray]: return [np.random.rand(4, 5)] extractor = FancyFeatureExtractor() @@ -42,9 +43,11 @@ def extract_from( def test_meta_schedule_feature_extractor_as_string(): class NotSoFancyFeatureExtractor(PyFeatureExtractor): def extract_from( - self, tune_context: TuneContext, candidates: List[MeasureCandidate] - ) -> List[NDArray]: - return None + self, + tune_context: TuneContext, # pylint: disable = unused-argument + candidates: List[MeasureCandidate], # pylint: disable = unused-argument + ) -> List[np.ndarray]: + return [] feature_extractor = NotSoFancyFeatureExtractor() pattern = re.compile(r"NotSoFancyFeatureExtractor\(0x[a-f|0-9]*\)") diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py index 8a215e8859..b7f8f507d3 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -46,9 +46,6 @@ def main(a: T.handle, b: T.handle) -> None: B[vi, vj, vk] = A[vi, vj, vk] -# fmt: on -# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable - @T.prim_func def Move_PUV0(a: T.handle, b: T.handle) -> None: @@ -74,6 +71,9 @@ def Move_PUV0(a: T.handle, b: T.handle) -> None: T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable + def test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize(): postproc = RewriteParallelVectorizeUnroll()