Skip to content

Commit

Permalink
[AutoScheduler] Enable schedule sharing in dispatch context (apache#7344)
Browse files Browse the repository at this point in the history

* [AutoScheduler] Enable schedule sharing in dispatch context

* Update python/tvm/auto_scheduler/dispatcher.py
  • Loading branch information
comaniac authored and trevor-m committed Mar 2, 2021
1 parent 94d12fd commit ee8b49c
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 105 deletions.
135 changes: 102 additions & 33 deletions python/tvm/auto_scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from tvm.tir.expr import FloatImm
from .measure_record import load_records
from .utils import calc_workload_dis_factor, decode_workload_key

logger = logging.getLogger("auto_scheduler")

Expand Down Expand Up @@ -126,18 +127,53 @@ class ApplyHistoryBest(DispatchContext):
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair. Otherwise, it is an iterator.
n_lines: Optional[int]
if it is not None, only load the first `n_lines` lines of log
if it is not None, only load the first `n_lines` lines of log.
include_compatible: bool
When set to True, compatible records will also be considered.
"""

def __init__(self, records, n_lines=None):
def __init__(self, records, n_lines=None, include_compatible=False):
    """Build the best-record maps from the given records.

    Parameters
    ----------
    records : str or iterator of (auto_scheduler.measure.MeasureInput,
                                  auto_scheduler.measure.MeasureResult)
        Collection of tuning records. If str, it is the filename of a
        records log file (see the class docstring).
    n_lines : Optional[int]
        If not None, only load the first `n_lines` lines of the log.
    include_compatible : bool
        When True, compatible records are also considered during queries.
    """
    super(ApplyHistoryBest, self).__init__()
    self.include_compatible = include_compatible

    # Dict[str (target key),
    #   Dict[str (workload hash),
    #     Dict[tuple (workload args), tuple (State, cost)]]]
    self.best_by_targetkey = {}
    self.best_by_model = {}
    # Entries written via update(); they take priority over loaded records.
    self._best_user_defined = {}

    self.load(records, n_lines)

@staticmethod
def get_workload_entry(best_records, target_key, workload_key):
    """Fetch (creating on demand) the args-level dict for a workload.

    Parameters
    ----------
    best_records: Dict[str, Dict[str, Dict[str, Any]]]
        The best record map, keyed by target key, then workload hash.
    target_key: str
        The first-level key into best_records.
    workload_key: str
        The workload key; decoded into a hash and an argument tuple.

    Returns
    -------
    entry: Dict[str, Any]
        The (possibly freshly created) dict under target key and workload hash.
    workload_hash: str
        The workload hash decoded from workload_key.
    workload_args: Tuple[Any, ...]
        The hashable tuple of workload args decoded from workload_key.
    """
    workload_hash, workload_args = decode_workload_key(workload_key)
    # setdefault creates the intermediate dicts the first time a
    # (target_key, workload_hash) pair is seen.
    per_target = best_records.setdefault(target_key, {})
    entry = per_target.setdefault(workload_hash, {})
    return entry, workload_hash, workload_args

def load(self, records, n_lines=None):
"""Load records to this dispatch context
Expand Down Expand Up @@ -171,29 +207,32 @@ def load(self, records, n_lines=None):
if res.error_no != 0:
continue

costs = [x.value for x in res.costs if isinstance(x, FloatImm)]
cost = np.mean(costs)

# use target keys in tvm target system as key to build best map
for k in inp.task.target.keys:
key = (k, inp.task.workload_key)
if key not in best_by_targetkey:
best_by_targetkey[key] = (inp, res)
entry, _, workload_args = self.get_workload_entry(
best_by_targetkey, k, inp.task.workload_key
)
if workload_args not in entry:
entry[workload_args] = (inp.state, cost)
else:
_, other_res = best_by_targetkey[key]
other_costs = [x.value for x in other_res.costs if isinstance(x, FloatImm)]
costs = [x.value for x in res.costs if isinstance(x, FloatImm)]
if np.mean(other_costs) > np.mean(costs):
best_by_targetkey[key] = (inp, res)
_, other_cost = entry[workload_args]
if other_cost > cost:
entry[workload_args] = (inp.state, cost)

# use model as key to build best map
key = (inp.task.target.model, inp.task.workload_key)
if key not in best_by_model:
entry, _, workload_args = self.get_workload_entry(
best_by_model, inp.task.target.model, inp.task.workload_key
)
if workload_args not in entry:
if inp.task.target.model != "unknown":
best_by_model[key] = (inp, res)
entry[workload_args] = (inp.state, cost)
else:
_, other_res = best_by_model[key]
other_costs = [x.value for x in other_res.costs if isinstance(x, FloatImm)]
costs = [x.value for x in res.costs if isinstance(x, FloatImm)]
if np.mean(other_costs) > np.mean(costs):
best_by_model[key] = (inp, res)
_, other_cost = entry[workload_args]
if other_cost > cost:
entry[workload_args] = (inp.state, cost)

logger.debug("Finish loading %d records", counter)

Expand All @@ -205,31 +244,61 @@ def _query_inside(self, target, workload_key):
" above the dispatcher call. So does other target. "
)

def match_record(best_records, target_key, workload_key):
    """Find the best state for workload_key under target_key.

    Exact argument matches are returned directly. Otherwise, when
    self.include_compatible is set, the compatible record with the
    lowest distance-factor-scaled cost is chosen. Returns None when
    nothing matches.
    """
    entry, workload_hash, workload_args = self.get_workload_entry(
        best_records, target_key, workload_key
    )
    if workload_args in entry:
        return entry[workload_args][0]
    if not self.include_compatible:
        return None

    # Scan compatible records; pick the one with the smallest scaled cost.
    best_state = None
    best_cost = float("inf")
    for cand_args, (cand_state, cand_cost) in entry.items():
        factor = calc_workload_dis_factor(
            (workload_hash, workload_args), (workload_hash, cand_args)
        )
        if factor == float("inf"):
            continue
        scaled_cost = cand_cost * factor
        if scaled_cost < best_cost:
            best_cost = scaled_cost
            best_state = cand_state
    return best_state

# first try matching by model
key = (target.model, workload_key)
if key in self._best_user_defined:
return self._best_user_defined[key]
if key in self.best_by_model:
return self.best_by_model[key][0].state
ret = match_record(self._best_user_defined, target.model, workload_key)
if ret is not None:
return ret
ret = match_record(self.best_by_model, target.model, workload_key)
if ret is not None:
return ret

# then try matching by target key
for k in target.keys:
key = (k, workload_key)
if key in self._best_user_defined:
return self._best_user_defined[key]
if key in self.best_by_targetkey:
return self.best_by_targetkey[key][0].state
ret = match_record(self._best_user_defined, k, workload_key)
if ret is not None:
return ret
ret = match_record(self.best_by_targetkey, k, workload_key)
if ret is not None:
return ret

return None

def update(self, target, workload_key, state):
model = target.model
key = (model, workload_key)
self._best_user_defined[key] = state
entry, _, workload_args = self.get_workload_entry(
self._best_user_defined, target.model, workload_key
)
entry[workload_args] = (state, 1)

for k in target.keys:
key = (k, workload_key)
self._best_user_defined[key] = state
entry, _, _ = self.get_workload_entry(self._best_user_defined, k, workload_key)
entry[workload_args] = (state, 1)


class FallbackContext(DispatchContext):
Expand Down
65 changes: 4 additions & 61 deletions python/tvm/auto_scheduler/measure_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import tvm._ffi
from tvm.runtime import Object
from .measure import MeasureErrorNo, MeasureCallback
from .utils import decode_workload_key
from .utils import calc_workload_dis_factor, decode_workload_key
from . import _ffi_api

logger = logging.getLogger("auto_scheduler")
Expand Down Expand Up @@ -130,65 +130,6 @@ def __iter__(self):
yield ret[0], ret[1] # (input, result)


def calc_workload_dis_factor(target_workload_key, workload_key):
    """Calculate the distance factor of the workload to the target workload.
    If two workloads are not compatible at all (i.e., different compute DAG or function),
    then the distance factor is "inf". Otherwise, we calculate the factor by traversing
    the workload arguments, which are the arguments of the compute function,
    or the output shapes for the ComputeDAG. The factor is calculated by the following rules:
    1. For non-zero integer values: `product(target_arg / candidate_arg)`.
    2. For non-integer or zero values: "inf" if not equal else 1.
    As a result, factor=1 is the optimal when two workloads are identical.
    Parameters
    ----------
    target_workload_key: str
        The target workload key in JSON string.
    workload_key: str
        The candidate workload key in JSON string.
    Returns
    -------
    dis_f: float
        The distance factor.
    """

    def _flatten(seq):
        # Recursively flatten nested lists into a single flat list.
        flat = []
        for item in seq:
            if isinstance(item, list):
                flat.extend(_flatten(item))
            else:
                flat.append(item)
        return flat

    target_key, target_args = decode_workload_key(target_workload_key)
    key, args = decode_workload_key(workload_key)
    target_args = _flatten(target_args) if target_args is not None else []
    args = _flatten(args) if args is not None else []

    # Not even the same func/DAG.
    if key != target_key or len(target_args) != len(args):
        return float("inf")

    dis_f = 1
    for t_arg, c_arg in zip(target_args, args):
        if isinstance(t_arg, int):
            if t_arg == 0 or c_arg == 0:
                # Zeros are only compatible with themselves.
                if t_arg != c_arg:
                    return float("inf")
            elif t_arg % c_arg != 0:
                return float("inf")
            else:
                dis_f *= t_arg / c_arg
        elif t_arg != c_arg:
            return float("inf")
    return dis_f


def load_record_from_string(record):
"""
Load the measure record from string.
Expand Down Expand Up @@ -304,7 +245,9 @@ def load_best_record(filename, workload_key=None, target=None, include_compatibl
cost = np.mean(costs)

if workload_key is not None:
dis_f = calc_workload_dis_factor(workload_key, inp.task.workload_key)
dis_f = calc_workload_dis_factor(
decode_workload_key(workload_key), decode_workload_key(inp.task.workload_key)
)
if dis_f == float("inf"):
continue
if not include_compatible and dis_f != 1:
Expand Down
65 changes: 62 additions & 3 deletions python/tvm/auto_scheduler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,77 @@ def decode_workload_key(workload_key):
-------
name: str
The workload function name or the DAG hash.
args: Optional[List[Any]]
The arguments of the workload, or None if the workload key format is not decodeable.
args: Optional[Tuple[Any, ...]]
The flatten arguments in a tuple, or None if the workload key format is not decodeable.
"""

def flatten_list(inp):
ret = []
for elt in inp:
if isinstance(elt, list):
ret += flatten_list(elt)
else:
ret.append(elt)
return ret

try:
key_list = json.loads(workload_key)
if isinstance(key_list, list) and len(key_list) >= 1:
return key_list[0], key_list[1:]
return key_list[0], tuple(flatten_list(key_list[1:]))
except json.decoder.JSONDecodeError:
pass
return workload_key, None


def calc_workload_dis_factor(target_workload_pair, workload_pair):
    """Calculate the distance factor of the workload to the target workload.
    If two workloads are not compatible at all (i.e., different compute DAG or function),
    then the distance factor is "inf". Otherwise, we calculate the factor by traversing
    the workload arguments, which are the arguments of the compute function,
    or the output shapes for the ComputeDAG. The factor is calculated by the following rules:
    1. For non-zero integer values: `product(target_arg / candidate_arg)`.
    2. For non-integer or zero values: "inf" if not equal else 1.
    As a result, factor=1 is the optimal when two workloads are identical.
    Parameters
    ----------
    target_workload_pair: Tuple[str, Optional[Tuple[Any, ...]]]
        The target workload pair: (hash, argument tuple).
    workload_pair: Tuple[str, Optional[Tuple[Any, ...]]]
        The candidate workload pair: (hash, argument tuple).
    Returns
    -------
    dis_f: float
        The distance factor.
    """
    target_key, target_args = target_workload_pair
    key, args = workload_pair
    # Treat a missing argument tuple as empty.
    if target_args is None:
        target_args = []
    if args is None:
        args = []

    # Not even the same func/DAG.
    if key != target_key or len(target_args) != len(args):
        return float("inf")

    factor = 1
    for t_arg, c_arg in zip(target_args, args):
        if not isinstance(t_arg, int):
            # Non-integer arguments must match exactly.
            if t_arg != c_arg:
                return float("inf")
            continue
        if t_arg == 0 or c_arg == 0:
            # Zeros are only compatible with themselves.
            if t_arg != c_arg:
                return float("inf")
        elif t_arg % c_arg != 0:
            return float("inf")
        else:
            factor *= t_arg / c_arg
    return factor


def get_func_name(func):
"""Get name of a function.
Expand Down
Loading

0 comments on commit ee8b49c

Please sign in to comment.