Skip to content

Commit

Permalink
[AutoScheduler] Print the time used for measurement (apache#6972)
Browse files Browse the repository at this point in the history
* [AutoScheduler] Print the time used for measurement

* address comments
  • Loading branch information
merrymercy authored and trevor-m committed Dec 4, 2020
1 parent 4e2b7ce commit a76292b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
6 changes: 5 additions & 1 deletion src/auto_scheduler/measure.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ Array<MeasureResult> ProgramMeasurerNode::Measure(const SearchTask& task,
const SearchPolicy& policy,
const Array<MeasureInput>& inputs,
int batch_size) {
auto t_begin = std::chrono::high_resolution_clock::now();

Array<MeasureResult> results;
results.reserve(inputs.size());

Expand All @@ -220,7 +222,7 @@ Array<MeasureResult> ProgramMeasurerNode::Measure(const SearchTask& task,

int old_verbosity = verbose;

StdCout(verbose) << "Get " << inputs.size() << " programs to measure." << std::endl;
StdCout(verbose) << "Get " << inputs.size() << " programs to measure:" << std::endl;

for (size_t i = 0; i < inputs.size(); i += batch_size) {
Array<MeasureInput> input_batch(inputs.begin() + i,
Expand Down Expand Up @@ -280,6 +282,8 @@ Array<MeasureResult> ProgramMeasurerNode::Measure(const SearchTask& task,
}
}

PrintTimeElapsed(t_begin, "measurement", verbose);

return results;
}

Expand Down
16 changes: 4 additions & 12 deletions src/auto_scheduler/search_policy/sketch_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,13 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
Array<MeasureResult> results;
while (ct < n_trials) {
if (!inputs.empty()) {
auto tic_begin = std::chrono::high_resolution_clock::now();
auto t_begin = std::chrono::high_resolution_clock::now();

// Retrain the cost model before the next search round
PrintTitle("Train cost model", verbose);
program_cost_model->Update(inputs, results);

double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::high_resolution_clock::now() - tic_begin)
.count();
StdCout(verbose) << "Time elapsed: " << std::fixed << std::setprecision(2) << duration
<< " s" << std::endl;
PrintTimeElapsed(t_begin, "training", verbose);
}

// Search one round to get promising states
Expand Down Expand Up @@ -258,17 +254,13 @@ std::pair<Array<MeasureInput>, Array<MeasureResult>> SketchPolicyNode::ContinueS
measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs));
}

auto tic_begin = std::chrono::high_resolution_clock::now();
auto t_begin = std::chrono::high_resolution_clock::now();

// Update the cost model
PrintTitle("Train cost model", verbose);
program_cost_model->Update(inputs, results);

double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::high_resolution_clock::now() - tic_begin)
.count();
StdCout(verbose) << "Time elapsed: " << std::fixed << std::setprecision(2) << duration << " s"
<< std::endl;
PrintTimeElapsed(t_begin, "training", verbose);

return std::make_pair(std::move(inputs), std::move(results));
}
Expand Down
11 changes: 11 additions & 0 deletions src/auto_scheduler/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <deque>
#include <exception>
#include <future>
#include <iomanip>
#include <numeric>
#include <random>
#include <string>
Expand Down Expand Up @@ -253,6 +254,16 @@ inline std::string Chars(const char& str, int times) {
return ret.str();
}

/*! \brief Print the time elapsed */
inline void PrintTimeElapsed(std::chrono::time_point<std::chrono::high_resolution_clock> t_begin,
const std::string& info, int verbose) {
double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::high_resolution_clock::now() - t_begin)
.count();
StdCout(verbose) << "Time elapsed for " << info << ": " << std::fixed << std::setprecision(2)
<< duration << " s" << std::endl;
}

/*!
* \brief Parse shape and axis names from layout string
*/
Expand Down

0 comments on commit a76292b

Please sign in to comment.