Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Meta Schedule] Add XGBoost Model & Random Model #519

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 12 additions & 17 deletions include/tvm/meta_schedule/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,15 @@ class CostModelNode : public runtime::Object {

/*!
* \brief Load the cost model from given file location.
* \param file_location The file location.
* \return Whether cost model was loaded successfully.
* \param path The file path.
*/
virtual bool Load(const String& file_location) = 0;
virtual void Load(const String& path) = 0;

/*!
* \brief Save the cost model to given file location.
* \param file_location The file location.
* \return Whether cost model was saved successfully.
* \param path The file path.
*/
virtual bool Save(const String& file_location) = 0;
virtual void Save(const String& path) = 0;

/*!
* \brief Update the cost model given running results.
Expand Down Expand Up @@ -78,16 +76,14 @@ class PyCostModelNode : public CostModelNode {
public:
/*!
* \brief Load the cost model from given file location.
* \param file_location The file location.
* \return Whether cost model was loaded successfully.
* \param path The file path.
*/
using FLoad = runtime::TypedPackedFunc<bool(String)>;
using FLoad = runtime::TypedPackedFunc<void(String)>;
/*!
* \brief Save the cost model to given file location.
* \param file_location The file location.
* \return Whether cost model was saved successfully.
* \param path The file path.
*/
using FSave = runtime::TypedPackedFunc<bool(String)>;
using FSave = runtime::TypedPackedFunc<void(String)>;
/*!
* \brief Update the cost model given running results.
* \param tune_context The tuning context.
Expand Down Expand Up @@ -130,16 +126,15 @@ class PyCostModelNode : public CostModelNode {
// `f_as_string` is not visited
}

bool Load(const String& file_location) {
void Load(const String& path) {
ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!";
return f_load(file_location);
f_load(path);
}

bool Save(const String& file_location) {
void Save(const String& path) {
ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!";
return f_save(file_location);
f_save(path);
}

void Update(const TuneContext& tune_context, const Array<MeasureCandidate>& candidates,
const Array<RunnerResult>& results) {
ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!";
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/cost_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@
The tvm.meta_schedule.cost_model package.
"""
from .cost_model import CostModel, PyCostModel
from .random_model import RandomModel
from .xgb_model import XGBModel
38 changes: 14 additions & 24 deletions python/tvm/meta_schedule/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,35 +35,25 @@
class CostModel(Object):
"""Cost model."""

def load(self, file_location: str) -> bool:
def load(self, path: str) -> None:
"""Load the cost model from given file location.

Parameters
----------
file_location : str
The file location.

Return
------
result : bool
Whether cost model was loaded successfully.
path : str
The file path.
"""
return bool(_ffi_api.CostModelLoad(self, file_location)) # type: ignore # pylint: disable=no-member
_ffi_api.CostModelLoad(self, path) # type: ignore # pylint: disable=no-member

def save(self, file_location: str) -> bool:
def save(self, path: str) -> None:
"""Save the cost model to given file location.

Parameters
----------
file_location : str
The file location.

Return
------
result : bool
Whether cost model was saved successfully.
path : str
The file path.
"""
return bool(_ffi_api.CostModelSave(self, file_location)) # type: ignore # pylint: disable=no-member
_ffi_api.CostModelSave(self, path) # type: ignore # pylint: disable=no-member

def update(
self,
Expand Down Expand Up @@ -96,7 +86,7 @@ def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate])

Return
------
result : bool
result : np.ndarray
The predicted running results.
"""
n = len(candidates)
Expand All @@ -118,20 +108,20 @@ def __init__(self):
"""Constructor."""

@check_override(self.__class__, CostModel)
def f_load(file_location: str) -> bool:
return self.load(file_location)
def f_load(path: str) -> None:
self.load(path)

@check_override(self.__class__, CostModel)
def f_save(file_location: str) -> bool:
return self.save(file_location)
def f_save(path: str) -> None:
self.save(path)

@check_override(self.__class__, CostModel)
def f_update(
self,
tune_context: TuneContext,
candidates: List[MeasureCandidate],
results: List[RunnerResult],
) -> bool:
) -> None:
self.update(tune_context, candidates, results)

@check_override(self.__class__, CostModel)
Expand Down
40 changes: 40 additions & 0 deletions python/tvm/meta_schedule/cost_model/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Cost model metrics for meta schedule"""
from typing import List
import numpy as np


def max_curve(trial_scores: np.ndarray) -> List[float]:
"""f(n) = max([s[i] fo i < n])

Parameters
----------
trial_scores : List[float]
the score of i-th trial

Returns
-------
curve : List[float]
function values
"""
ret = np.empty(len(trial_scores))
keep = -1e9
for i, score in enumerate(trial_scores):
keep = max(keep, score)
ret[i] = keep
return ret
123 changes: 123 additions & 0 deletions python/tvm/meta_schedule/cost_model/random_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Random cost model
"""
from typing import List, Union, Tuple, Optional

import numpy as np

from ..runner import RunnerResult
from ..tune_context import TuneContext
from ..search_strategy import MeasureCandidate
from ..cost_model import PyCostModel


class RandomModel(PyCostModel):
"""Random cost model

Parameters
----------
random_state : Union[Tuple[str, np.ndarray, int, int, float], dict]
The random state of the random number generator.
path : Optional[str]
The path of the random cost model.
max_range : Optional[int]
The maximum range of random results, [0, max_range].

Reference
---------
https://numpy.org/doc/stable/reference/random/generated/numpy.random.get_state.html
"""

random_state: Union[Tuple[str, np.ndarray, int, int, float], dict]
path: Optional[str]

def __init__(
self,
*,
seed: Optional[int] = None,
path: Optional[str] = None,
max_range: Optional[int] = 100,
):
super().__init__()
if path is not None:
self.load(path)
else:
np.random.seed(seed)
self.random_state = np.random.get_state()
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
self.max_range = max_range

def load(self, path: str) -> None:
"""Load the cost model from given file location.

Parameters
----------
path : str
The file path.
"""
self.random_state = tuple(np.load(path, allow_pickle=True))

def save(self, path: str) -> None:
"""Save the cost model to given file location.

Parameters
----------
path : str
The file path.
"""
np.save(path, np.array(self.random_state, dtype=object), allow_pickle=True)

def update(
self,
tune_context: TuneContext,
candidates: List[MeasureCandidate],
results: List[RunnerResult],
) -> None:
"""Update the cost model given running results.

Parameters
----------
tune_context : TuneContext,
The tuning context.
candidates : List[MeasureCandidate]
The measure candidates.
results : List[RunnerResult]
The running results of the measure candidates.
"""

def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
"""Update the cost model given running results.

Parameters
----------
tune_context : TuneContext,
The tuning context.
candidates : List[MeasureCandidate]
The measure candidates.

Return
------
result : np.ndarray
The predicted running results.
"""
np.random.set_state(self.random_state)
# todo(@zxybazh): Use numpy's RandState object:
# https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html#numpy.random.RandomState
result = np.random.rand(len(candidates)) * self.max_range
self.random_state = np.random.get_state()
return result
Loading