Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Dec 12, 2021
1 parent 5d4c5b4 commit 526b92d
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 37 deletions.
25 changes: 15 additions & 10 deletions src/meta_schedule/search_strategy/evolutionary_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -422,15 +422,17 @@ std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int nu
measured_traces.push_back(record->trace);
}
int actual_num = measured_traces.size();
ThreadedTraceApply pp(self->postprocs_);
std::vector<Schedule> results(actual_num, Schedule{nullptr});
auto f_proc_measured = [this, &measured_traces, &results](int thread_id, int trace_id) -> void {
auto f_proc_measured = [this, &measured_traces, &results, &pp](int thread_id,
int trace_id) -> void {
PerThreadData& data = self->per_thread_data_.at(thread_id);
TRandState* rand_state = &data.rand_state;
const IRModule& mod = data.mod;
tir::Trace trace = measured_traces.at(trace_id);
Schedule& result = results.at(trace_id);
ICHECK(!result.defined());
if (Optional<Schedule> sch = ApplyTrace(mod, trace, rand_state, self->postprocs_)) {
if (Optional<Schedule> sch = pp.Apply(mod, trace, rand_state)) {
result = sch.value();
} else {
LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace;
Expand All @@ -442,8 +444,9 @@ std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int nu
}

std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int num) {
ThreadedTraceApply pp(self->postprocs_);
std::vector<Schedule> results(num, Schedule{nullptr});
auto f_proc_unmeasured = [this, &results](int thread_id, int trace_id) -> void {
auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void {
PerThreadData& data = self->per_thread_data_.at(thread_id);
TRandState* rand_state = &data.rand_state;
const IRModule& mod = data.mod;
Expand All @@ -452,16 +455,18 @@ std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int nu
for (int fail_count = 0; fail_count <= self->init_max_fail_count; ++fail_count) {
int design_space_index = tir::SampleInt(rand_state, 0, design_spaces.size());
tir::Trace trace(design_spaces[design_space_index]->insts, {});
if (Optional<Schedule> sch = ApplyTrace(mod, trace, rand_state, self->postprocs_)) {
if (Optional<Schedule> sch = pp.Apply(mod, trace, rand_state)) {
result = sch.value();
break;
}
}
if (!result.defined()) {
LOG(FATAL) << "Sample-Init-Population failed over the maximum limit!";
LOG(FATAL) << "Sample-Init-Population failed over the maximum limit! Summary:\n"
<< pp.SummarizeFailures();
}
};
support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured);
LOG(INFO) << "Sample-Init-Population summary:\n" << pp.SummarizeFailures();
return results;
}

Expand Down Expand Up @@ -495,11 +500,12 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
for (PerThreadData& data : self->per_thread_data_) {
data.Set(scores, self->genetic_mutate_prob, self->mutator_probs_);
}
ThreadedTraceApply pp(self->postprocs_);
ConcurrentBitmask cbmask(self->population_size);
std::vector<Schedule> next_population(self->population_size, Schedule{nullptr});
// The worker function
auto f_find_candidate = [&cbmask, &population, &next_population, this](int thread_id,
int trace_id) {
auto f_find_candidate = [&cbmask, &population, &next_population, &pp, this](int thread_id,
int trace_id) {
// Prepare samplers
PerThreadData& data = self->per_thread_data_.at(thread_id);
TRandState* rand_state = &data.rand_state;
Expand All @@ -516,8 +522,7 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
// Decision: mutate
Mutator mutator = opt_mutator.value();
if (Optional<tir::Trace> new_trace = mutator->Apply(trace, rand_state)) {
if (Optional<Schedule> sch =
ApplyTrace(mod, new_trace.value(), rand_state, self->postprocs_)) {
if (Optional<Schedule> sch = pp.Apply(mod, new_trace.value(), rand_state)) {
// note that sch's trace is different from new_trace
// because it contains post-processing information
result = sch.value();
Expand All @@ -536,7 +541,7 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
};
support::parallel_for_dynamic(0, self->population_size, self->num_threads_, f_find_candidate);
population.swap(next_population);
LOG(INFO) << "Evolve iter #" << iter << " done";
LOG(INFO) << "Evolve iter #" << iter << " done. Summary:\n" << pp.SummarizeFailures();
}
// Return the best states from the heap, sorting from higher score to lower ones
std::sort(heap.heap.begin(), heap.heap.end());
Expand Down
7 changes: 4 additions & 3 deletions src/meta_schedule/search_strategy/replay_trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,16 @@ inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasure
ICHECK_LT(st, ed);
std::vector<TRandState> per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_);
Array<MeasureCandidate> per_task_result(ed - st, MeasureCandidate{nullptr});
auto f_worker = [this, &per_thread_rand_state, &per_task_result](int thread_id,
int task_id) -> void {
ThreadedTraceApply pp(self->postprocs_);
auto f_worker = [this, &per_thread_rand_state, &per_task_result, &pp](int thread_id,
int task_id) -> void {
TRandState& rand_state = per_thread_rand_state[thread_id];
IRModule mod = self->per_thread_mod_[thread_id];
for (;;) {
int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size());
tir::Trace trace = design_spaces[design_space_index];
tir::Trace new_trace = tir::Trace(trace->insts, {});
if (Optional<tir::Schedule> sch = ApplyTrace(mod, new_trace, &rand_state, self->postprocs_)) {
if (Optional<tir::Schedule> sch = pp.Apply(mod, new_trace, &rand_state)) {
per_task_result.Set(task_id, MeasureCandidate(sch.value(), self->args_info_));
break;
}
Expand Down
67 changes: 44 additions & 23 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <tvm/meta_schedule/tune_context.h>
#include <tvm/support/parallel_for.h>

#include <memory>
#include <sstream>
#include <string>
#include <vector>
Expand Down Expand Up @@ -277,31 +278,51 @@ inline int GetTargetNumCores(const Target& target) {
return num_cores;
}

/*!
* \brief Apply the trace and postprocessors to an IRModule
* \param mod The IRModule to be applied
* \param trace The trace to apply to the IRModule
* \param rand_state The random seed
* \param postprocs The postprocessors to apply to the IRModule
* \return The schedule created, or NullOpt if any postprocessor fails
*/
inline Optional<tir::Schedule> ApplyTrace(const IRModule& mod, const tir::Trace& trace,
TRandState* rand_state,
const Array<Postproc>& postprocs) {
tir::Schedule sch =
tir::Schedule::Traced(mod,
/*rand_state=*/ForkSeed(rand_state),
/*debug_mode=*/0,
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
sch->EnterPostproc();
for (const Postproc& proc : postprocs) {
if (!proc->Apply(sch)) {
return NullOpt;
struct ThreadedTraceApply {
const Array<Postproc>& postprocs;
std::vector<std::unique_ptr<std::atomic<int>>> fail_counter;

/*! \brief Create one zero-initialized atomic failure counter per postprocessor. */
explicit ThreadedTraceApply(const Array<Postproc>& postprocs)
    : postprocs(postprocs), fail_counter(postprocs.size()) {
  // std::atomic is neither copyable nor movable, so each counter is held behind a
  // unique_ptr rather than stored in the vector directly.
  for (size_t i = 0; i < fail_counter.size(); ++i) {
    fail_counter[i] = std::make_unique<std::atomic<int>>(0);
  }
}
return sch;
}

/*!
 * \brief Apply the trace and postprocessors to an IRModule
 * \param mod The IRModule to be applied
 * \param trace The trace to apply to the IRModule
 * \param rand_state The random seed
 * \return The schedule created, or NullOpt if any postprocessor fails
 */
Optional<tir::Schedule> Apply(const IRModule& mod, const tir::Trace& trace,
                              TRandState* rand_state) {
  // Materialize a traced schedule and replay the trace, stripping any recorded
  // postprocessing steps so they can be re-run below.
  tir::Schedule sch =
      tir::Schedule::Traced(mod,
                            /*rand_state=*/ForkSeed(rand_state),
                            /*debug_mode=*/0,
                            /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
  trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
  sch->EnterPostproc();
  // Run every postprocessor in order; the first rejection aborts and is tallied
  // per-postproc (atomically, since this is invoked from parallel workers).
  int num_postprocs = postprocs.size();
  for (int idx = 0; idx < num_postprocs; ++idx) {
    if (!postprocs[idx]->Apply(sch)) {
      fail_counter[idx]->fetch_add(1);
      return NullOpt;
    }
  }
  return sch;
}

std::string SummarizeFailures() const {
std::ostringstream os;
for (int i = 0, n = postprocs.size(); i < n; ++i) {
os << "Postproc #" << i << " [" << postprocs[i] //
<< "]: " << *fail_counter[i] << " failure(s)\n";
}
return os.str();
}
};

} // namespace meta_schedule
} // namespace tvm
Expand Down
3 changes: 2 additions & 1 deletion tests/python/meta_schedule/test_meta_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _parse_args():
)
parsed = args.parse_args()
parsed.target = tvm.target.Target(parsed.target)
if parsed.target.attrs["mtriple"] == "aarch64-linux-gnu":
if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu":
parsed.alloc_repeat = 3
else:
parsed.alloc_repeat = 1
Expand Down Expand Up @@ -92,6 +92,7 @@ def main():
config=ms.EvolutionarySearchConfig(
num_trials_per_iter=64,
num_trials_total=ARGS.num_trials,
init_max_fail_count=1024,
),
runner=runner,
task_name=ARGS.workload,
Expand Down

0 comments on commit 526b92d

Please sign in to comment.