diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index 884741bd08cc..de11fc1b5b11 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -22,7 +22,7 @@ L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor Programs for Deep Learning." (OSDI 2020). """ - +import os import time import math import logging @@ -168,6 +168,9 @@ class TaskScheduler: The parameter used for 'gradient' strategy backward_window_size: int = 3 The parameter used for 'gradient' strategy + callbacks: Optional[List[TaskSchedulerCallback]] + The task scheduler callbacks that will be called before and after tuning a task. + If None, then PrintTableInfo callback will be used. """ def __init__( @@ -182,6 +185,7 @@ def __init__( beta: float = 2, gamma: float = 0.5, backward_window_size: int = 3, + callbacks=None, ): self.tasks = tasks if objective_func: # use custom objective function @@ -199,6 +203,7 @@ def __init__( self.beta = beta self.gamma = gamma self.backward_window_size = backward_window_size + self.callbacks = callbacks if callbacks is not None else [PrintTableInfo()] assert len(self.tasks) != 0, "No tasks" assert self.strategy in ["round-robin", "gradient"] @@ -374,39 +379,12 @@ def tune(self, tune_option, search_policy="default"): ) break - def _print_table_info(self, next_task_idx): - # table header - _ffi_api.PrintTitle("Task Scheduler") - print("| ID | Latency (ms) | Speed (GFLOPS) | Trials |") - print("-------------------------------------------------") - - # content - for i in range(len(self.tasks)): - id_str = "%d" % i - latency_str = "%.3f" % (1e3 * self.best_costs[i]) if self.best_costs[i] < 1e9 else "-" - speed_str = ( - "%.2f" % (self.tasks[i].compute_dag.flop_ct / self.best_costs[i] / 1e9) - if self.best_costs[i] < 1e9 - else "-" - ) - trials_str = "%d" % (self.task_cts[i] * self.num_measures_per_round) - print("| %4s | %12s | % 14s | %6s |" % (id_str, latency_str, speed_str, trials_str)) - print("-------------------------------------------------") - - # overall info - if all(cost < 1e9 for cost in self.best_costs): - total_latency_str = "%.3f" % (self.cur_score * 1e3) - else: - total_latency_str = "-" - print( - "Estimated total latency: %s ms\tTrials: %d\tUsed time : %.0f s\tNext ID: %d\t" - % (total_latency_str, self.ct, time.time() - self.tic, next_task_idx) - ) - def _tune_task(self, task_idx): """Tune the select task for one round""" - if self.tune_option.verbose >= 1: - self._print_table_info(task_idx) + + # Run pre-tune callbacks + for callback in self.callbacks: + callback.pre_tune(self, task_idx) measure_inputs, measure_results = self.search_policies[task_idx].continue_search_one_round( self.num_measures_per_round, self.measurer @@ -426,6 +404,10 @@ def _tune_task(self, task_idx): self.ct += len(measure_inputs) self.cur_score = self._compute_score(self.best_costs) + # Run post-tune callbacks + for callback in self.callbacks: + callback.post_tune(self, task_idx) + def _compute_score(self, costs): """compute the objective function""" return self.objective_func(costs) @@ -478,3 +460,109 @@ def _restore_status(self, log_file, num_measures_per_round): self.cur_score = self._compute_score(self.best_costs) logger.info("TaskScheduler: Loaded %d measurement records from %s", total_ct + 1, log_file) + + +class TaskSchedulerCallback: + """The base class of task scheduler callback functions. """ + + def pre_tune(self, task_scheduler, task_id): + """The callback before tuning each task. + + Parameters + ---------- + task_scheduler: TaskScheduler + The task scheduler. + task_id: int + The task ID going to be tuned. + """ + # Do nothing by default + + def post_tune(self, task_scheduler, task_id): + """The callback after tuning each task. + + Parameters + ---------- + task_scheduler: TaskScheduler + The task scheduler. + task_id: int + The task ID be tuned. + """ + # Do nothing by default + + +class PrintTableInfo(TaskSchedulerCallback): + """The callback that prints a table of current progress.""" + + def pre_tune(self, task_scheduler, task_id): + if task_scheduler.tune_option.verbose < 1: + return + + _ffi_api.PrintTitle("Task Scheduler") + print("| ID | Latency (ms) | Speed (GFLOPS) | Trials |") + print("-------------------------------------------------") + + # content + for i in range(len(task_scheduler.tasks)): + id_str = "%d" % i + latency_str = ( + "%.3f" % (1e3 * task_scheduler.best_costs[i]) + if task_scheduler.best_costs[i] < 1e9 + else "-" + ) + speed_str = ( + "%.2f" + % (task_scheduler.tasks[i].compute_dag.flop_ct / task_scheduler.best_costs[i] / 1e9) + if task_scheduler.best_costs[i] < 1e9 + else "-" + ) + trials_str = "%d" % (task_scheduler.task_cts[i] * task_scheduler.num_measures_per_round) + print("| %4s | %12s | % 14s | %6s |" % (id_str, latency_str, speed_str, trials_str)) + print("-------------------------------------------------") + + # overall info + if all(cost < 1e9 for cost in task_scheduler.best_costs): + total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3) + else: + total_latency_str = "-" + print( + "Estimated total latency: %s ms\tTrials: %d\tUsed time : %.0f s\tNext ID: %d\t" + % ( + total_latency_str, + task_scheduler.ct, + time.time() - task_scheduler.tic, + task_id, + ) + ) + + +class LogEstimatedLatency(TaskSchedulerCallback): + """Log the estimated latency to the file after tuning a task. + + Parameters + ---------- + log_file: str + The log file path. + """ + + def __init__(self, log_file): + if os.path.exists(log_file): # Remove existing log + os.remove(log_file) + + self.log_file = log_file + + def post_tune(self, task_scheduler, task_id): + if all(cost < 1e9 for cost in task_scheduler.best_costs): + total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3) + else: + total_latency_str = "N/A" + + with open(self.log_file, "a") as filep: + filep.write( + "ElapsedTime(s)\t%.0f\tEstimatedLatency(ms)\t%s\tTrials\t%d\n" + % ( + time.time() - task_scheduler.tic, + total_latency_str, + task_scheduler.ct, + ) + ) + filep.flush()