diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc index 03585ea40c03..5b7e886f073c 100755 --- a/src/auto_scheduler/measure.cc +++ b/src/auto_scheduler/measure.cc @@ -210,6 +210,8 @@ Array ProgramMeasurerNode::Measure(const SearchTask& task, const SearchPolicy& policy, const Array& inputs, int batch_size) { + auto t_begin = std::chrono::high_resolution_clock::now(); + Array results; results.reserve(inputs.size()); @@ -220,7 +222,7 @@ Array ProgramMeasurerNode::Measure(const SearchTask& task, int old_verbosity = verbose; - StdCout(verbose) << "Get " << inputs.size() << " programs to measure." << std::endl; + StdCout(verbose) << "Get " << inputs.size() << " programs to measure:" << std::endl; for (size_t i = 0; i < inputs.size(); i += batch_size) { Array input_batch(inputs.begin() + i, @@ -280,6 +282,8 @@ Array ProgramMeasurerNode::Measure(const SearchTask& task, } } + PrintTimeElapsed(t_begin, "measurement", verbose); + return results; } diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 4c3e8ac5593d..07d2837ab994 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -162,17 +162,13 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure Array results; while (ct < n_trials) { if (!inputs.empty()) { - auto tic_begin = std::chrono::high_resolution_clock::now(); + auto t_begin = std::chrono::high_resolution_clock::now(); // Retrain the cost model before the next search round PrintTitle("Train cost model", verbose); program_cost_model->Update(inputs, results); - double duration = std::chrono::duration_cast>( - std::chrono::high_resolution_clock::now() - tic_begin) - .count(); - StdCout(verbose) << "Time elapsed: " << std::fixed << std::setprecision(2) << duration - << " s" << std::endl; + PrintTimeElapsed(t_begin, "training", verbose); } // Search one round to get promising states @@ -258,17 +254,13 @@ std::pair, Array> SketchPolicyNode::ContinueS measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs)); } - auto tic_begin = std::chrono::high_resolution_clock::now(); + auto t_begin = std::chrono::high_resolution_clock::now(); // Update the cost model PrintTitle("Train cost model", verbose); program_cost_model->Update(inputs, results); - double duration = std::chrono::duration_cast>( - std::chrono::high_resolution_clock::now() - tic_begin) - .count(); - StdCout(verbose) << "Time elapsed: " << std::fixed << std::setprecision(2) << duration << " s" - << std::endl; + PrintTimeElapsed(t_begin, "training", verbose); return std::make_pair(std::move(inputs), std::move(results)); } diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h index 88c649c6f919..bc29a3761129 100755 --- a/src/auto_scheduler/utils.h +++ b/src/auto_scheduler/utils.h @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -253,6 +254,16 @@ inline std::string Chars(const char& str, int times) { return ret.str(); } +/*! \brief Print the time elapsed */ +inline void PrintTimeElapsed(std::chrono::time_point t_begin, + const std::string& info, int verbose) { + double duration = std::chrono::duration_cast>( + std::chrono::high_resolution_clock::now() - t_begin) + .count(); + StdCout(verbose) << "Time elapsed for " << info << ": " << std::fixed << std::setprecision(2) + << duration << " s" << std::endl; +} + /*! * \brief Parse shape and axis names from layout string */