Skip to content

Commit

Permalink
fixed lint
Browse files: browse the repository at this point in the history
  • Loading branch information
canesche committed Feb 11, 2024
1 parent 34cba65 commit 94967e4
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 74 deletions.
1 change: 0 additions & 1 deletion python/tvm/auto_scheduler/search_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import logging
from collections import OrderedDict
import numpy as np
from collections import OrderedDict

import tvm._ffi
from tvm.runtime import Object, ndarray
Expand Down
78 changes: 51 additions & 27 deletions python/tvm/auto_scheduler/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Space:
"""

def __init__(self, cfg, task):
self.jfile, self.cfg = cfg, cfg["i"][1][1]
self.cfg = deepcopy(cfg)
self.total_dims, self.dims, self.task = 0, [], task
self.config_space = {}
self.create_space()
Expand All @@ -61,15 +61,18 @@ def create_space(self):
"""Create the space using Ansor's space"""
sp_space = [4, 8, 16, 24, 32, 48, 64]
pr_space = [64, 128, 256, 512]
for i in range(len(self.cfg)):
f = self.cfg[i]
if f[0] == "SP" and f[3] != 1:
for j in range(len(f[4])):
self.config_space[f"{f[0]}_{i}_{j}"] = self.add_space(sp_space, [f[4][j]], f[3])
elif f[0] == "PR":
start_value = int(f[3].split("$")[-1])
config = self.cfg["i"][1][1]
for i in range(len(config)):
opt = config[i]
if opt[0] == "SP" and opt[3] != 1:
for j in range(len(opt[4])):
self.config_space[f"{opt[0]}_{i}_{j}"] = self.add_space(
sp_space, [opt[4][j]], opt[3]
)
elif opt[0] == "PR":
start_value = int(opt[3].split("$")[-1])
if start_value != 0:
self.config_space[f"{f[0]}_{i}"] = [
self.config_space[f"{opt[0]}_{i}"] = [
f"auto_unroll_max_step${v}" for v in self.add_space(pr_space, [start_value])
]
self.dims = []
Expand All @@ -82,32 +85,53 @@ def create_space(self):

def apply_opt(self, vals):
"""Apply the space using Ansor's space"""
jfile = deepcopy(self.jfile)
cfg = jfile["i"][1][1]
index = 0
for i in range(len(cfg)):
f = cfg[i]
if f[0] == "SP" and f[3] != 1:
new_f = []
for j in range(len(f[4])):
new_f.append(self.get_value(f"{f[0]}_{i}_{j}", vals[index]))
index, config = 0, self.cfg["i"][1][1]
for i in range(len(config)):
opt = config[i]
if opt[0] == "SP" and opt[3] != 1:
new_value = []
for j in range(len(opt[4])):
new_value.append(self.get_value(f"{opt[0]}_{i}_{j}", vals[index]))
index += 1
cfg[i] = ["SP", f[1], f[2], f[3], new_f, f[5]]
elif f[0] == "PR":
if f[3] != "auto_unroll_max_step$0":
cfg[i] = ["PR", f[1], f[2], self.get_value(f"{f[0]}_{i}", vals[index])]
config[i][4] = new_value
elif opt[0] == "PR":
if opt[3] != "auto_unroll_max_step$0":
config[i] = ["PR", opt[1], opt[2], self.get_value(f"{opt[0]}_{i}", vals[index])]
index += 1
return jfile

def run(self, log, final_log):
return self.cfg

def run(
self,
log,
final_log,
timeout=20,
verbose=0,
number=3,
repeat=3,
min_repeat_ms=0,
cooldown_interval=0,
cache=False,
dev=0,
):
"""Execute a log file and save"""
readlines, _ = tvm.auto_scheduler.RecordReader(log).read_lines()
inputs, results = [], []
for i in range(len(readlines)):
state = self.task.compute_dag.infer_bound_from_state(readlines[i].state)
inp = [tvm.auto_scheduler.MeasureInput(self.task, state)]
build_res = local_builder_build(inp, 20, os.cpu_count(), "default", 0)
res = local_run(inp, build_res, 20, 3, 3, 0, 0, False, 0, 0)
build_res = local_builder_build(inp, timeout, os.cpu_count(), "default", verbose)
res = local_run(
inp,
build_res,
timeout,
number,
repeat,
min_repeat_ms,
cooldown_interval,
cache,
verbose,
dev,
)
tvm.auto_scheduler._ffi_api.SaveRecords(final_log, inp, res)
inputs.append(inp[0])
results.append(MeasureResultSpace(res))
Expand Down
50 changes: 25 additions & 25 deletions python/tvm/auto_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def derive_similarity_tag(dag, log_base=1.618):
if tag:
ret += op.attrs["auto_scheduler_task_scheduler_tag"] + "_"
if ret:
ret += "%d" % int(math.log(dag.flop_ct + 1, log_base))
ret += f"{int(math.log(dag.flop_ct + 1, log_base))}"
return ret


Expand Down Expand Up @@ -539,7 +539,7 @@ def _restore_status(self, log_file, num_measures_per_round):

self.cur_score = self._compute_score(self.best_costs)

logger.info("TaskScheduler: Loaded %d measurement records from %s", total_ct + 1, log_file)
logger.info(f"TaskScheduler: Loaded {total_ct + 1} measurement records from {log_file}")


class TaskSchedulerCallback:
Expand Down Expand Up @@ -592,21 +592,21 @@ def pre_tune(self, task_scheduler, task_id):
for i in range(len(task_scheduler.tasks)):
id_str = f"{i}"
latency_str = (
"%.3f" % (1e3 * task_scheduler.best_costs[i])
f"{1e3 * task_scheduler.best_costs[i]:.3f}"
if task_scheduler.best_costs[i] < 1e9
else "-"
)
task_desc = task_scheduler.tasks[i].desc
best_cost = task_scheduler.best_costs[i]
speed_str = (
"%.2f"
% (task_scheduler.tasks[i].compute_dag.flop_ct / task_scheduler.best_costs[i] / 1e9)
f"{task_scheduler.tasks[i].compute_dag.flop_ct / best_cost / 1e9:.2f}"
if task_scheduler.best_costs[i] < 1e9
else "-"
)
trials_str = "%d" % (task_scheduler.task_cts[i] * task_scheduler.num_measures_per_round)
trials_str = f"{(task_scheduler.task_cts[i] * task_scheduler.num_measures_per_round)}"
print(
"| %4s | %61s | %12s | % 14s | %6s |"
% (id_str, task_desc, latency_str, speed_str, trials_str)
f"| {id_str:4s} | {task_desc:61s} | {latency_str:12s} | "
f"{speed_str:14s} | {trials_str:6s} |"
)
print(
"----------------------------------------------------------------"
Expand All @@ -615,12 +615,12 @@ def pre_tune(self, task_scheduler, task_id):

# overall info
if all(cost < 1e9 for cost in task_scheduler.best_costs):
total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3)
total_latency_str = f"{task_scheduler.cur_score * 1e3:.3f}"
else:
total_latency_str = "-"
print(
"Estimated total latency: %s ms\tTrials: %d\tUsed time : %.0f s\tNext ID: %d\t"
% (total_latency_str, task_scheduler.ct, time.time() - task_scheduler.tic, task_id)
f"Estimated total latency: {total_latency_str} ms\tTrials: {task_scheduler.ct}\t"
f"Used time : {time.time() - task_scheduler.tic:.0f} s\tNext ID: {task_id}\t"
)


Expand All @@ -641,19 +641,19 @@ def __init__(self, log_file):

def post_tune(self, task_scheduler, task_id):
if all(cost < 1e9 for cost in task_scheduler.best_costs):
total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3)
total_latency_str = f"{task_scheduler.cur_score * 1e3:.3f}"
else:
total_latency_str = "N/A"

with open(self.log_file, "a") as filep:
filep.write(
"ElapsedTime(s)\t%.0f\tEstimatedLatency(ms)\t%s\tTrials\t%d\n"
% (time.time() - task_scheduler.tic, total_latency_str, task_scheduler.ct)
f"ElapsedTime(s)\t{time.time() - task_scheduler.tic:.0f}\tEstimatedLatency(ms)\t"
f"{total_latency_str}\tTrials\t{task_scheduler.ct}\n"
)
filep.flush()


def droplet_exploitation(log_file, target="x86", verbose=True):
def droplet_exploitation(log_file, target="llvm", verbose=True):
"""optimization of the model after execution of the Ansor method, using
Droplet algorithm."""
cfg = get_multilayers(log_file)
Expand All @@ -675,24 +675,24 @@ def droplet_exploitation(log_file, target="x86", verbose=True):
best_time, time_total, best_cfg = get_time(log)
time_droplet += time_total

droplet_avg_time = np.mean(best_time)
ansor_avg_time = np.mean(ansor_time)
speedup = ansor_avg_time / droplet_avg_time

# Append the best solution in the same Ansor's log
# solutions that are invalid or same are not saved
if np.mean(best_time) != 1e10:
write_file([best_cfg], log_file, "a")
if verbose:
print(
"%d, %.6f, %.6f, %.2f, %.2f"
% (
layer,
np.mean(best_time),
np.mean(ansor_time),
time_total,
np.mean(ansor_time) / np.mean(best_time),
)
f"{layer}, {droplet_avg_time:.6f}, {ansor_avg_time:.6f},"
f" {time_total:.2f}, {speedup:.2f}"
)

if verbose:
time_total = time_total_ansor + time_droplet
print(
"Time Ansor (s): %.2f, Time Droplet (s): %.2f, Time Total (s): %.2f"
% (time_total_ansor, time_droplet, time_total_ansor + time_droplet)
f"Time Ansor (s): {time_total_ansor:.2f},"
f"Time Droplet (s): {time_droplet:.2f},"
f"Time Total (s): {time_total:.2f}"
)
34 changes: 16 additions & 18 deletions python/tvm/auto_scheduler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,14 +424,13 @@ def get_multilayers(log):
A dictionary with a tuple
"""
hash_map = dict()
file = open(log, "r")
for line in file.readlines():
data = json.loads(line)
if "i" in data:
res, key = data["r"][0], data["i"][0][0]
if key not in hash_map or np.mean(hash_map[key][0]) > np.mean(res):
hash_map[key] = (res, data)
file.close()
with open(log, "r", encoding="utf-8") as log_file:
for line in log_file.readlines():
data = json.loads(line)
if "i" in data:
res, key = data["r"][0], data["i"][0][0]
if key not in hash_map or np.mean(hash_map[key][0]) > np.mean(res):
hash_map[key] = (res, data)
return hash_map


Expand Down Expand Up @@ -464,21 +463,20 @@ def get_time(log):
Parameters
----------
log: str
The input log path
The input log path with the Ansor parameter
Returns
-------
ret: Union[float, float, dict]
Returns the best time, total time, and data
"""
time_total, best_time, best_cfg = 0, 1e10, {}
f = open(log, "r")
for line in f.readlines():
data = json.loads(line)
if "r" in data:
res = data["r"][0]
time_total += data["r"][2]
if np.mean(res) < np.mean(best_time):
best_time, best_cfg = res, data
f.close()
with open(log, "r", encoding="utf-8") as log_file:
for line in log_file.readlines():
data = json.loads(line)
if "r" in data:
res = data["r"][0]
time_total += data["r"][2]
if np.mean(res) < np.mean(best_time):
best_time, best_cfg = res, data
return best_time, time_total, best_cfg
8 changes: 5 additions & 3 deletions tests/python/auto_scheduler/test_auto_scheduler_droplet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from tvm.testing.auto_scheduler import matmul_auto_scheduler_test


@tvm.testing.requires_llvm
def test_task_scheduler_gradient_droplet():
tasks = []
Expand Down Expand Up @@ -63,7 +64,7 @@ def objective_func(costs):
task_scheduler.tune(tune_option, search_policy="sketch.random")

# Use the droplet algorithm to optimize the kernel
auto_scheduler.task_scheduler.droplet_exploitation(log_file)
auto_scheduler.task_scheduler.droplet_exploitation(log_file, tasks[0].target)

# Check the allocation results
counters = {}
Expand All @@ -73,8 +74,9 @@ def objective_func(costs):
for inp, _ in auto_scheduler.load_records(log_file):
counters[inp.task.workload_key] += 1

assert counters[tasks[0].workload_key] == n_trials - 1
assert counters[tasks[1].workload_key] == 1
# droplet adds an optimized solution at the end
assert counters[tasks[0].workload_key] == n_trials
assert counters[tasks[1].workload_key] == 2
del measure_ctx


Expand Down

0 comments on commit 94967e4

Please sign in to comment.