Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Dec 12, 2021
1 parent 972c45a commit d753e87
Show file tree
Hide file tree
Showing 23 changed files with 594 additions and 427 deletions.
22 changes: 10 additions & 12 deletions include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,24 +265,22 @@ class SearchStrategy : public runtime::ObjectRef {
* \brief Constructor of evolutionary search strategy.
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
* \param num_trials_total The total number of trials for evolutionary search.
* \param population The initial sample population.
* \param max_replay_fail_cnt The maximum number to fail trace replaying.
* \param population_size The initial sample population.
* \param init_measured_ratio The ratio of measures samples in initial population.
* \param genetic_algo_iters The iterations to run the genetic algorithm.
* \param max_evolve_fail_cnt The maximum number to try evolving the given trace.
* \param p_mutate The probability of mutation.
* \param init_max_fail_count The maximum number to fail trace replaying.
* \param genetic_num_iters The iterations to run the genetic algorithm.
* \param genetic_mutate_prob The probability of mutation.
* \param genetic_max_fail_count The maximum number to try evolving the given trace.
* \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score.
* \param database The database to use.
* \param cost_model The cost model to use.
*/
TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, //
int num_trials_total, //
int population, //
int max_replay_fail_cnt, //
int population_size, //
double init_measured_ratio, //
int genetic_algo_iters, //
int max_evolve_fail_cnt, //
double p_mutate, //
int init_max_fail_count, //
int genetic_num_iters, //
double genetic_mutate_prob, //
int genetic_max_fail_count, //
double eps_greedy);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode);
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/builder/local_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def _worker_func(


@register_func("meta_schedule.builder.default_build")
def default_build(mod: IRModule, target: Target, params: Optional[Dict[str, NDArray]]) -> Module:
def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, NDArray]]) -> Module:
"""Default build function.
Parameters
Expand All @@ -222,7 +222,7 @@ def default_build(mod: IRModule, target: Target, params: Optional[Dict[str, NDAr
The IRModule to be built.
target : Target
The target to be built.
params : Optional[Dict[str, NDArray]]
_params : Optional[Dict[str, NDArray]]
The parameters to be used for the build. Must be None.
Returns
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def update(
def _mean_cost(x: RunnerResult) -> float:
if not x.run_secs:
return 1e10
return sum(float(s) for s in x.run_secs) / len(x.run_secs)
return float(np.median([float(s) for s in x.run_secs]))

new_features = [
x.numpy().astype("float32")
Expand Down
72 changes: 36 additions & 36 deletions python/tvm/meta_schedule/search_strategy/evolutionary_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,56 +36,56 @@ class EvolutionarySearch(SearchStrategy):
Number of trials per iteration.
num_trials_total : int
Total number of trials.
population : int
population_size : int
The initial population of traces from measured samples and randomly generated samples.
max_replay_fail_cnt : int
The maximum number to fail trace replaying.
init_measured_ratio : int
The ratio of measured samples in the initial population.
genetic_algo_iters : int
init_max_fail_count : int
The maximum number to fail trace replaying.
genetic_num_iters : int
The number of iterations for genetic algorithm.
max_evolve_fail_cnt : int
The maximum number to retry mutation.
p_mutate : float
genetic_mutate_prob : float
The probability of mutation.
genetic_max_fail_count : int
The maximum number to retry mutation.
eps_greedy : float
The ratio of greedy selected samples in the final picks.
"""

num_trials_per_iter: int
num_trials_total: int
population: int
population_size: int
init_measured_ratio: int
genetic_algo_iters: int
max_replay_fail_cnt: int
max_evolve_fail_cnt: int
p_mutate: float
init_max_fail_count: int
genetic_num_iters: int
genetic_mutate_prob: float
genetic_max_fail_count: int
eps_greedy: float

def __init__(
self,
*,
num_trials_per_iter: int,
num_trials_total: int,
population: int = 2048,
max_replay_fail_cnt: int = 64,
init_measured_ratio: float = 0.2,
genetic_algo_iters: int = 10,
max_evolve_fail_cnt: int = 10,
p_mutate: float = 0.85,
eps_greedy: float = 0.25,
):
population_size: int,
init_measured_ratio: float,
init_max_fail_count: int,
genetic_num_iters: int,
genetic_mutate_prob: float,
genetic_max_fail_count: int,
eps_greedy: float,
) -> None:
"""Constructor"""
self.__init_handle_by_constructor__(
_ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member
num_trials_per_iter,
num_trials_total,
population,
max_replay_fail_cnt,
population_size,
init_measured_ratio,
genetic_algo_iters,
max_evolve_fail_cnt,
p_mutate,
init_max_fail_count,
genetic_num_iters,
genetic_mutate_prob,
genetic_max_fail_count,
eps_greedy,
)

Expand All @@ -95,23 +95,23 @@ class EvolutionarySearchConfig(NamedTuple):

num_trials_per_iter: int
num_trials_total: int
population: int = 2048
max_replay_fail_cnt: int = 64
population_size: int = 2048
init_measured_ratio: float = 0.2
genetic_algo_iters: int = 10
max_evolve_fail_cnt: int = 10
p_mutate: float = 0.85
eps_greedy: float = 0.25
init_max_fail_count: int = 64
genetic_num_iters: int = 4
genetic_mutate_prob: float = 0.85
genetic_max_fail_count: int = 10
eps_greedy: float = 0.05

def create_strategy(self) -> EvolutionarySearch:
return EvolutionarySearch(
num_trials_per_iter=self.num_trials_per_iter,
num_trials_total=self.num_trials_total,
population=self.population,
max_replay_fail_cnt=self.max_replay_fail_cnt,
population_size=self.population_size,
init_measured_ratio=self.init_measured_ratio,
genetic_algo_iters=self.genetic_algo_iters,
max_evolve_fail_cnt=self.max_evolve_fail_cnt,
p_mutate=self.p_mutate,
init_max_fail_count=self.init_max_fail_count,
genetic_num_iters=self.genetic_num_iters,
genetic_mutate_prob=self.genetic_mutate_prob,
genetic_max_fail_count=self.genetic_max_fail_count,
eps_greedy=self.eps_greedy,
)
40 changes: 33 additions & 7 deletions src/meta_schedule/feature_extractor/per_store_feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,31 @@ inline double slog(double x) { return x >= 0 ? std::log2(x + 1) : std::log2(-x +

namespace utils {

/*!
* \brief Get the shape of the buffer
* \param buffer The buffer
* \param analyzer The analyzer
* \return The shape of the buffer
*/
std::vector<int64_t> GetBufferShape(const Buffer& buffer, arith::Analyzer* analyzer) {
int ndim = buffer->shape.size();
std::vector<int64_t> result;
result.reserve(ndim);
for (const PrimExpr& i : buffer->shape) {
if (const IntImmNode* int_imm = i.as<IntImmNode>()) {
result.push_back(int_imm->value);
continue;
}
arith::ConstIntBound bound = analyzer->const_int_bound(i);
if (0 <= bound->max_value && bound->max_value < arith::ConstIntBound::kPosInf) {
result.push_back(bound->max_value);
} else {
result.push_back(1);
}
}
return result;
}

/*!
* \brief Given a loop, return its `pragma_auto_unroll_max_step` annotation if it exists
* \param loop The loop to be checked
Expand Down Expand Up @@ -655,7 +680,7 @@ struct Feature {

static void Pad(std::vector<double>* v) { v->insert(v->end(), 18, 0.0); }

void SetStride(const LoopNest& loop_nest);
void SetStride(const LoopNest& loop_nest, arith::Analyzer* analyzer);

void SetReuse(const LoopNest& loop_nest, //
int64_t top_loop_touch_bytes, //
Expand Down Expand Up @@ -775,13 +800,13 @@ void Feature::SetRegion(const LoopNest& loop_nest, IntVec* for_touched_bytes,
}
}

void Feature::SubFeature::SetStride(const LoopNest& loop_nest) {
void Feature::SubFeature::SetStride(const LoopNest& loop_nest, arith::Analyzer* analyzer) {
int n_loops = loop_nest.loops.size();
const std::vector<const ForNode*>& loops = loop_nest.loops;
// For each buffer, we find the loop stride on it
const BufferNode* buffer = this->buffer;
int ndim = this->buffer->shape.size();
IntVec buffer_shape = support::AsVector<PrimExpr, int64_t>(buffer->shape);
IntVec buffer_shape = utils::GetBufferShape(GetRef<Buffer>(buffer), analyzer);
// Calculate the buffer's stride from its shape
IntVec buffer_stride(ndim);
if (ndim >= 1) {
Expand Down Expand Up @@ -934,7 +959,7 @@ Feature::Feature(const BufferStoreNode* store, const LoopNest& loop_nest, int64_
this->SetRegion(loop_nest, for_touched_bytes, buffer_touched_under_loop, analyzer);
// Step 2. Calculate stride-related feature
for (auto& feature : sub_features) {
feature.SetStride(loop_nest);
feature.SetStride(loop_nest, analyzer);
}
// Step 3. Calculate reuse-related feature
int64_t top_loop_touch_bytes = 0.0;
Expand Down Expand Up @@ -1056,9 +1081,10 @@ struct Feature {

Feature() = default;

explicit Feature(const LoopNest& loop_nest, const Buffer& buffer) {
explicit Feature(const LoopNest& loop_nest, const Buffer& buffer, arith::Analyzer* analyzer) {
std::vector<int64_t> shape = utils::GetBufferShape(buffer, analyzer);
int64_t numel = 1;
for (int64_t x : support::AsVector<PrimExpr, int64_t>(buffer->shape)) {
for (int64_t x : shape) {
numel *= x;
}
alloc_size = numel * buffer->dtype.bytes();
Expand Down Expand Up @@ -1181,7 +1207,7 @@ class PerStoreFeatureCollector : private StmtVisitor {

void HandleBufferAlloc(const Buffer& buffer) {
Feature& feature = buffer_features_[buffer.get()];
feature.group4 = std::make_unique<group4::Feature>(loop_nest_, buffer);
feature.group4 = std::make_unique<group4::Feature>(loop_nest_, buffer, &analyzer_);
}

explicit PerStoreFeatureCollector(bool is_gpu, int64_t cache_line_bytes,
Expand Down
4 changes: 2 additions & 2 deletions src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
sch->Vectorize(split[3]);
sch->Bind(split[2], "threadIdx.x");
sch->Bind(split[1], "threadIdx.y");
sch->StorageAlign(block, 0, -2, 32, 8);
}
sch->StorageAlign(block, 0, -2, 32, 8);
});
} else {
tasks.push_back(
Expand All @@ -162,8 +162,8 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
Integer(thread_extent_x)});
sch->Bind(split[2], "threadIdx.x");
sch->Bind(split[1], "threadIdx.y");
sch->StorageAlign(block, 0, -2, 32, 8);
}
sch->StorageAlign(block, 0, -2, 32, 8);
});
}
}
Expand Down
Loading

0 comments on commit d753e87

Please sign in to comment.