Skip to content

Commit

Permalink
Minor fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Mar 2, 2022
1 parent 2d667c6 commit 166834c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
6 changes: 2 additions & 4 deletions python/tvm/meta_schedule/task_scheduler/gradient_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@


@register_func("meta_schedule.task_scheduler.derive_similarity_tag")
def derive_similarity_tag(mod: IRModule, log_base: float = 1.618) -> str:
def derive_similarity_tag(mod: IRModule) -> str:
"""Get the tags for smilarity group creation
Parameters
----------
mod : IRModule
The input workload.
log_base : float
The log base to normalize the flop count. Default natural (1.618).
Return
------
Expand All @@ -57,7 +55,7 @@ def derive_similarity_tag(mod: IRModule, log_base: float = 1.618) -> str:
ret += mod[var].attrs.meta_scheduler_task_scheduler_tag + "_"
if ret:
flop_count = _ffi_api.TaskSchedulerFlopCount(mod) # type: ignore # pylint: disable=no-member
ret += "%d" % int(math.log(flop_count + 1, log_base))
ret += "%d" % int(math.log(flop_count + 1, 1.618))
return ret


Expand Down
4 changes: 2 additions & 2 deletions src/meta_schedule/task_scheduler/gradient_based.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ class GradientBasedNode final : public TaskSchedulerNode {
for (int i : task_groups[group_id]) {
best_flops = std::max(best_flops, task_flop_counts[i] / task_best_latencies[i]);
max_cnt[0] = task_cnts[i];
std::sort(max_cnt, max_cnt + 3); // place the 2nd largest to #1
std::sort(max_cnt, max_cnt + 3); // place the 2nd largest to middle position
}
double cur_flops = task_flop_counts[task_id] / task_best_latencies[task_id];
// if we tune a task for many times but it still cannot achieve
// a similar speed to the fastest one in its group, this means this task
// is actually not similar to other tasks in its group.
// So we will remove it from its original group.
// So we will move it from the current group to a standalone group.

if (cur_flops < best_flops / beta && task_cnts[task_id] > 5 + max_cnt[1]) {
task_groups[group_id].erase(task_id);
Expand Down

0 comments on commit 166834c

Please sign in to comment.