Skip to content

Commit

Permalink
Checkpoint.
Browse files Browse the repository at this point in the history
Fix cost model comment.

Finish evolutionary seaarch.

Remove  extra code.

Fix compile.

Add comments.

Add python part.

Ad test.

Update other files & comments.
  • Loading branch information
junrushao committed Nov 29, 2021
1 parent 9a3b851 commit a244e90
Show file tree
Hide file tree
Showing 13 changed files with 1,038 additions and 21 deletions.
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ class CostModelNode : public runtime::Object {
const Array<RunnerResult>& results) = 0;

/*!
* \brief Predict the running results of given measure candidates.
* \brief Predict the normalized score (the larger the better) of given measure candidates.
* \param tune_context The tuning context.
* \param candidates The measure candidates.
* \return The predicted running results.
* \return The predicted normalized score.
*/
virtual std::vector<double> Predict(const TuneContext& tune_context,
const Array<MeasureCandidate>& candidates) = 0;
Expand Down
14 changes: 14 additions & 0 deletions include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_

#include <tvm/meta_schedule/arg_info.h>
#include <tvm/meta_schedule/database.h>
#include <tvm/meta_schedule/mutator.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/tir/schedule/schedule.h>
Expand All @@ -29,6 +31,7 @@ namespace meta_schedule {

// Forward declaration
class TuneContext;
class CostModel;

/*! \brief The schedule (with input shapes) to be measured. */
class MeasureCandidateNode : public runtime::Object {
Expand Down Expand Up @@ -255,6 +258,17 @@ class SearchStrategy : public runtime::ObjectRef {
*/
TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total);

TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, //
int num_trials_total, //
int population, //
double init_measured_ratio, //
int genetic_algo_iters, //
double p_mutate, //
double eps_greedy, //
Map<Mutator, FloatImm> mutator_probs, //
Database database, //
CostModel cost_model);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode);
};

Expand Down
3 changes: 1 addition & 2 deletions python/tvm/meta_schedule/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate])
Return
------
result : bool
The predicted running results.
The predicted normalized score.
"""
n = len(candidates)
results = np.zeros(shape=(n,), dtype="float64")
Expand Down Expand Up @@ -127,7 +127,6 @@ def f_save(file_location: str) -> bool:

@check_override(self.__class__, CostModel)
def f_update(
self,
tune_context: TuneContext,
candidates: List[MeasureCandidate],
results: List[RunnerResult],
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/search_strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate
from .replay_trace import ReplayTrace
from .replay_func import ReplayFunc
from .evolutionary_search import EvolutionarySearch
101 changes: 101 additions & 0 deletions python/tvm/meta_schedule/search_strategy/evolutionary_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Evolutionary Search Strategy"""

from typing import TYPE_CHECKING, Dict

from tvm._ffi import register_object
from ...tir import FloatImm

from .search_strategy import SearchStrategy
from ..mutator import Mutator
from ..database import Database

from .. import _ffi_api

if TYPE_CHECKING:
from ..cost_model import CostModel


@register_object("meta_schedule.EvolutionarySearch")
class EvolutionarySearch(SearchStrategy):
"""
Replay Trace Search Strategy is a search strategy that always replays the trace by removing its
decisions so that the decisions would be randomly re-generated.
Parameters
----------
num_trials_per_iter : int
Number of trials per iteration.
num_trials_total : int
Total number of trials.
population : int
The initial population of traces from measured samples and randomly generated samples.
init_measured_ratio : int
The ratio of measured samples in the initial population.
genetic_algo_iters : int
The number of iterations for genetic algorithm.
p_mutate : float
The probability of mutation.
eps_greedy : float
The ratio of greedy selected samples in the final picks.
mutator_probs: Dict[Mutator, FloatImm]
The probability contribution of all mutators.
database : Database
The database used in the search.
cost_model : CostModel
The cost model used in the search.
"""

num_trials_per_iter: int
num_trials_total: int
population: int
init_measured_ratio: int
genetic_algo_iters: int
p_mutate: float
eps_greedy: float
mutator_probs: Dict[Mutator, FloatImm]
database: Database
cost_model: "CostModel"

def __init__(
self,
num_trials_per_iter: int,
num_trials_total: int,
population: int,
init_measured_ratio: float,
genetic_algo_iters: int,
p_mutate: float,
eps_greedy: float,
mutator_probs: Dict[Mutator, FloatImm],
database: Database,
cost_model: "CostModel",
):
"""Constructor"""
self.__init_handle_by_constructor__(
_ffi_api.SearchStrategyEvolutionarySearch, # pylint: disable=no-member
num_trials_per_iter,
num_trials_total,
population,
init_measured_ratio,
genetic_algo_iters,
p_mutate,
eps_greedy,
mutator_probs,
database,
cost_model,
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.tir.schedule import Schedule, Trace
from tvm.tir.schedule import Schedule

from .. import _ffi_api
from ..arg_info import ArgInfo
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def sample_compute_location(
decision: Optional[int] = None,
) -> LoopRV:
"""Sample a compute-at location on a BlockRV so that its producer can compute at that loop
Parameters
----------
block : BlockRV
Expand Down
Loading

0 comments on commit a244e90

Please sign in to comment.