Skip to content

Commit

Permalink
[AutoScheduler] Task scheduler callbacks (#6945)
Browse files Browse the repository at this point in the history
* [AutoScheduler] Task scheduler callbacks

* docstring

* address comments

* Delete the explaination of callback in the tutorial

* fix

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
  • Loading branch information
comaniac and merrymercy authored Nov 24, 2020
1 parent 48f9135 commit 9d71cea
Showing 1 changed file with 120 additions and 32 deletions.
152 changes: 120 additions & 32 deletions python/tvm/auto_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor
Programs for Deep Learning." (OSDI 2020).
"""

import os
import time
import math
import logging
Expand Down Expand Up @@ -168,6 +168,9 @@ class TaskScheduler:
The parameter used for 'gradient' strategy
backward_window_size: int = 3
The parameter used for 'gradient' strategy
callbacks: Optional[List[TaskSchedulerCallback]]
The task scheduler callbacks that will be called before and after tuning a task.
If None, then PrintTableInfo callback will be used.
"""

def __init__(
Expand All @@ -182,6 +185,7 @@ def __init__(
beta: float = 2,
gamma: float = 0.5,
backward_window_size: int = 3,
callbacks=None,
):
self.tasks = tasks
if objective_func: # use custom objective function
Expand All @@ -199,6 +203,7 @@ def __init__(
self.beta = beta
self.gamma = gamma
self.backward_window_size = backward_window_size
self.callbacks = callbacks if callbacks is not None else [PrintTableInfo()]

assert len(self.tasks) != 0, "No tasks"
assert self.strategy in ["round-robin", "gradient"]
Expand Down Expand Up @@ -374,39 +379,12 @@ def tune(self, tune_option, search_policy="default"):
)
break

def _print_table_info(self, next_task_idx):
# table header
_ffi_api.PrintTitle("Task Scheduler")
print("| ID | Latency (ms) | Speed (GFLOPS) | Trials |")
print("-------------------------------------------------")

# content
for i in range(len(self.tasks)):
id_str = "%d" % i
latency_str = "%.3f" % (1e3 * self.best_costs[i]) if self.best_costs[i] < 1e9 else "-"
speed_str = (
"%.2f" % (self.tasks[i].compute_dag.flop_ct / self.best_costs[i] / 1e9)
if self.best_costs[i] < 1e9
else "-"
)
trials_str = "%d" % (self.task_cts[i] * self.num_measures_per_round)
print("| %4s | %12s | % 14s | %6s |" % (id_str, latency_str, speed_str, trials_str))
print("-------------------------------------------------")

# overall info
if all(cost < 1e9 for cost in self.best_costs):
total_latency_str = "%.3f" % (self.cur_score * 1e3)
else:
total_latency_str = "-"
print(
"Estimated total latency: %s ms\tTrials: %d\tUsed time : %.0f s\tNext ID: %d\t"
% (total_latency_str, self.ct, time.time() - self.tic, next_task_idx)
)

def _tune_task(self, task_idx):
"""Tune the select task for one round"""
if self.tune_option.verbose >= 1:
self._print_table_info(task_idx)

# Run pre-tune callbacks
for callback in self.callbacks:
callback.pre_tune(self, task_idx)

measure_inputs, measure_results = self.search_policies[task_idx].continue_search_one_round(
self.num_measures_per_round, self.measurer
Expand All @@ -426,6 +404,10 @@ def _tune_task(self, task_idx):
self.ct += len(measure_inputs)
self.cur_score = self._compute_score(self.best_costs)

# Run post-tune callbacks
for callback in self.callbacks:
callback.post_tune(self, task_idx)

def _compute_score(self, costs):
"""compute the objective function"""
return self.objective_func(costs)
Expand Down Expand Up @@ -478,3 +460,109 @@ def _restore_status(self, log_file, num_measures_per_round):
self.cur_score = self._compute_score(self.best_costs)

logger.info("TaskScheduler: Loaded %d measurement records from %s", total_ct + 1, log_file)


class TaskSchedulerCallback:
"""The base class of task scheduler callback functions. """

def pre_tune(self, task_scheduler, task_id):
"""The callback before tuning each task.
Parameters
----------
task_scheduler: TaskScheduler
The task scheduler.
task_id: int
The task ID going to be tuned.
"""
# Do nothing by default

def post_tune(self, task_scheduler, task_id):
"""The callback after tuning each task.
Parameters
----------
task_scheduler: TaskScheduler
The task scheduler.
task_id: int
The task ID be tuned.
"""
# Do nothing by default


class PrintTableInfo(TaskSchedulerCallback):
"""The callback that prints a table of current progress."""

def pre_tune(self, task_scheduler, task_id):
if task_scheduler.tune_option.verbose < 1:
return

_ffi_api.PrintTitle("Task Scheduler")
print("| ID | Latency (ms) | Speed (GFLOPS) | Trials |")
print("-------------------------------------------------")

# content
for i in range(len(task_scheduler.tasks)):
id_str = "%d" % i
latency_str = (
"%.3f" % (1e3 * task_scheduler.best_costs[i])
if task_scheduler.best_costs[i] < 1e9
else "-"
)
speed_str = (
"%.2f"
% (task_scheduler.tasks[i].compute_dag.flop_ct / task_scheduler.best_costs[i] / 1e9)
if task_scheduler.best_costs[i] < 1e9
else "-"
)
trials_str = "%d" % (task_scheduler.task_cts[i] * task_scheduler.num_measures_per_round)
print("| %4s | %12s | % 14s | %6s |" % (id_str, latency_str, speed_str, trials_str))
print("-------------------------------------------------")

# overall info
if all(cost < 1e9 for cost in task_scheduler.best_costs):
total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3)
else:
total_latency_str = "-"
print(
"Estimated total latency: %s ms\tTrials: %d\tUsed time : %.0f s\tNext ID: %d\t"
% (
total_latency_str,
task_scheduler.ct,
time.time() - task_scheduler.tic,
task_id,
)
)


class LogEstimatedLatency(TaskSchedulerCallback):
"""Log the estimated latency to the file after tuning a task.
Parameters
----------
log_file: str
The log file path.
"""

def __init__(self, log_file):
if os.path.exists(log_file): # Remove existing log
os.remove(log_file)

self.log_file = log_file

def post_tune(self, task_scheduler, task_id):
if all(cost < 1e9 for cost in task_scheduler.best_costs):
total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3)
else:
total_latency_str = "N/A"

with open(self.log_file, "a") as filep:
filep.write(
"ElapsedTime(s)\t%.0f\tEstimatedLatency(ms)\t%s\tTrials\t%d\n"
% (
time.time() - task_scheduler.tic,
total_latency_str,
task_scheduler.ct,
)
)
filep.flush()

0 comments on commit 9d71cea

Please sign in to comment.