[AutoScheduler] Separate shapes from DAG hash and enable schedule sharing (#7317)

* [AutoScheduler] Separate shapes from DAG hash and enable schedule sharing

* Update CI logs

* lint

* fix registry

* add message; fix layout rewrite mismatch

* update message

* support other formats
comaniac authored Jan 25, 2021
1 parent 5d33491 commit e6d5318
Showing 11 changed files with 342 additions and 149 deletions.
7 changes: 7 additions & 0 deletions include/tvm/auto_scheduler/compute_dag.h
@@ -262,6 +262,13 @@ class ComputeDAG : public ObjectRef {
*/
String PrintStepsAsPython(const Array<Step>& transform_steps) const;

/*!
* \brief Print the compute DAG to a string. This is also used to generate the ComputeDAG hash.
* \param simple_mode Simple mode only includes the op names and a brief compute expression.
* \return The ComputeDAG in a string.
*/
String PrintDAG(bool simple_mode = false) const;

/*!
* \brief Fill the correct bound information for a given state by calling ir_pass::InferBound.
* The states can lose complete bound information after some transform steps (e.g., compute_at).
35 changes: 13 additions & 22 deletions python/tvm/auto_scheduler/compute_dag.py
@@ -19,11 +19,11 @@
""" The auto-scheduler's computational graph and related program analyses. """

import hashlib
import json

import tvm._ffi
from tvm.runtime import Object
from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON
from tvm.te import ComputeOp, PlaceholderOp

from . import _ffi_api
from .loop_state import State, StateObject
@@ -220,32 +220,23 @@ def rewrite_layout_from_state(self, state):
state_obj = state if isinstance(state, StateObject) else state.state_object
return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state_obj)

def hash_key(self):
"""Return the hash key of this compute DAG.
def workload_key(self):
"""Return the workload key of this compute DAG.
The workload key is a JSON string from a tuple of (hash-key, tensor shapes...)
Returns
-------
key: str
The hash key of this compute DAG
The workload key of this compute DAG
"""
# TODO(merrymercy): Implement this more carefully and move this to c++ as a member function
# of ComputeDAG
str_key = ""
for op in self.ops:
t = op.output(0)
if isinstance(op, PlaceholderOp):
str_key += "placeholder,"
str_key += str(get_const_tuple(t.shape)) + ","
str_key += t.dtype + ";"
elif isinstance(op, ComputeOp):
str_key += str(t.op.body) + ","
str_key += str(get_const_tuple(t.shape)) + ","
str_key += t.dtype + ";"
else:
raise ValueError("Invalid op: " + op)

str_key = str_key.encode(encoding="utf-8")
return hashlib.md5(str_key).hexdigest()
str_dag = _ffi_api.ComputeDAGPrintDAG(self, True)
str_dag = str_dag.encode(encoding="utf-8")
hash_key = hashlib.md5(str_dag).hexdigest()

io_shapes = []
for tensor in self.tensors:
io_shapes += get_const_tuple(tensor.shape)
return json.dumps([hash_key] + io_shapes)

def __str__(self):
# pretty print
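For illustration, a minimal sketch of what the new key looks like for a small matmul DAG; the shapes are arbitrary and the hash placeholder stands in for the real MD5 digest, whereas the removed hash_key() returned only a bare digest with the shapes folded into the hashed string:

from tvm import auto_scheduler, te

# Build a small matmul compute DAG and inspect its workload key.
A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")
k = te.reduce_axis((0, 128), name="k")
C = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
dag = auto_scheduler.ComputeDAG([A, B, C])

print(dag.workload_key())
# e.g. '["<md5 of PrintDAG(simple_mode=True)>", 128, 128, 128, 128, 128, 128]'
#        DAG hash                               flattened I/O tensor shapes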
126 changes: 117 additions & 9 deletions python/tvm/auto_scheduler/measure_record.py
@@ -27,6 +27,7 @@
import tvm._ffi
from tvm.runtime import Object
from .measure import MeasureErrorNo, MeasureCallback
from .utils import decode_workload_key
from . import _ffi_api

logger = logging.getLogger("auto_scheduler")
@@ -59,8 +60,37 @@ class RecordReader(Object):
"""

def __init__(self, filename):
# a set to prevent printing duplicate messages
self.messages = set()

self.__init_handle_by_constructor__(_ffi_api.RecordReader, filename)

def check_workload_key(self, inputs):
"""Check and throw warnings for records with old format workload key.
Parameters
----------
inputs: List[MeasureInput]
The measure inputs to be checked.
Notes
-----
This checker could be deprecated in the future.
"""
for inp in inputs:
_, args = decode_workload_key(inp.task.workload_key)
if args is None:
continue
if not args:
msg = (
"MeasureInput with old format workload key %s should be updated "
"using the script from https://github.com/apache/tvm/pull/7317."
% inp.task.workload_key
)
if msg not in self.messages:
self.messages.add(msg)
logger.warning(msg)
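A hedged sketch of when this check fires; the log file name is hypothetical, and only keys that decode to an empty argument list (the old DAG-hash-only format, e.g. '["<md5>"]') trigger the warning:

from tvm import auto_scheduler

reader = auto_scheduler.RecordReader("old_records.json")  # hypothetical log file
inputs, results = reader.read_lines()  # calls check_workload_key(inputs) internally
# Old-format records log the update warning once per unique workload key;
# keys that are not JSON at all decode to (name, None) and are skipped.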

def read_lines(self, max_lines=None, skip_lines=0):
"""Read multiple lines from the log file.
@@ -88,16 +118,77 @@ def read_lines(self, max_lines=None, skip_lines=0):
inputs, results = _ffi_api.RecordReaderReadLines(
self, max_lines if max_lines else -1, skip_lines
)
self.check_workload_key(inputs)
return inputs, results

def __iter__(self):
while True:
ret = _ffi_api.RecordReaderReadNext(self)
if not ret:
break
self.check_workload_key([ret[0]])
yield ret[0], ret[1] # (input, result)


def calc_workload_dis_factor(target_workload_key, workload_key):
"""Calculate the distance factor of the workload to the target workload.
If two workloads are not compatible at all (i.e., different compute DAG or function),
then the distance factor is "inf". Otherwise, we calculate the factor by traversing
the workload arguments, which are the arguments of the compute function,
or the output shapes for the ComputeDAG. The factor is calculated by the following rules:
1. For non-zero integer values: `product(target_arg / candidate_arg)`.
2. For non-integer or zero values: "inf" if not equal else 1.
As a result, the factor is 1 (the optimum) when the two workloads are identical.
Parameters
----------
target_workload_key: str
The target workload key in JSON string.
workload_key: str
The candidate workload key in JSON string.
Returns
-------
dis_f: float
The distance factor.
"""

def flatten_list(inp):
ret = []
for elt in inp:
if isinstance(elt, list):
ret += flatten_list(elt)
else:
ret.append(elt)
return ret

target_key, target_args = decode_workload_key(target_workload_key)
target_args = flatten_list(target_args) if target_args is not None else []
key, args = decode_workload_key(workload_key)
args = flatten_list(args) if args is not None else []

# Not even the same func/DAG.
if key != target_key or len(target_args) != len(args):
return float("inf")

dis_f = 1
for target_arg, arg in zip(target_args, args):
if isinstance(target_arg, int):
if target_arg == 0 or arg == 0:
if target_arg != arg:
return float("inf")
elif target_arg % arg != 0:
return float("inf")
else:
dis_f *= target_arg / arg
elif target_arg != arg:
return float("inf")
return dis_f
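A quick worked example of the rules above; the DAG hash and shapes are made up:

from tvm.auto_scheduler.measure_record import calc_workload_dis_factor

# Same DAG hash, and the target does twice the work along the first dimension:
#   dis_f = 256/128 * 64/64 * 64/64 = 2.0
target = '["fake_dag_hash", 256, 64, 64]'
candidate = '["fake_dag_hash", 128, 64, 64]'
assert calc_workload_dis_factor(target, candidate) == 2.0

# A different hash or a non-divisible integer argument means incompatible workloads.
assert calc_workload_dis_factor(target, '["other_hash", 256, 64, 64]') == float("inf")
assert calc_workload_dis_factor(target, '["fake_dag_hash", 100, 64, 64]') == float("inf")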


def load_record_from_string(record):
"""
Load the measure record from string.
@@ -174,7 +265,7 @@ def save_records(filename, inputs, results):
_ffi_api.SaveRecords(filename, inputs, results)


def load_best_record(filename, workload_key=None, target=None):
def load_best_record(filename, workload_key=None, target=None, include_compatible=False):
"""Return the best measurement pair form a log file. This may return none results if
there is no legal measure pair with the specified workload_key/target found from the log file.
@@ -188,6 +279,8 @@ def load_best_record(filename, workload_key=None, target=None):
target : Optional[tvm.target.Target]
The target device.
With `None`, this returns the best measure pair of all target devices.
include_compatible: bool
When set to True, all compatible records in the log file will be considered.
Returns
-------
@@ -204,13 +297,23 @@ def load_best_record(filename, workload_key=None, target=None):
for inp, res in log_reader:
if res.error_no != MeasureErrorNo.NO_ERROR:
continue
if workload_key and inp.task.workload_key != workload_key:
continue
if target and inp.task.target.kind.name != target.kind.name:
continue

costs = [v.value for v in res.costs]
cost = np.mean(costs)

if workload_key is not None:
dis_f = calc_workload_dis_factor(workload_key, inp.task.workload_key)
if dis_f == float("inf"):
continue
if not include_compatible and dis_f != 1:
continue

# Since different workloads have different FLOP counts, we multiply the cost by the
# distance factor to normalize it, which effectively compares throughput.
cost *= dis_f

if cost < best_cost:
best_cost = cost
best_inp = inp
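A hedged usage sketch of the extended lookup; the log file name and the task object are hypothetical:

from tvm.auto_scheduler.measure_record import load_best_record

# Also consider compatible records (same DAG, different shapes); their costs are
# scaled by the distance factor so they compete with exact matches on throughput.
inp, res = load_best_record("matmul.json", task.workload_key, include_compatible=True)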
@@ -267,12 +370,8 @@ def measure_input_str_key(inp):
logger.info("Extract %d best records from %s to %s", len(inputs), in_file, out_file)


"""
Usage:
* Distill the best entries from a large log file
e.g. python -m tvm.auto_scheduler.measure_record --mode distill --i input.json
"""
if __name__ == "__main__":
def main():
"""The main function for CLI."""
parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=["distill"], required=True)
parser.add_argument("--i", type=str, help="input file")
Expand All @@ -285,3 +384,12 @@ def measure_input_str_key(inp):
if args.mode == "distill":
args.o = args.o or args.i + ".best.json"
distill_record_file(args.i, args.o)


"""
Usage:
* Distill the best entries from a large log file
e.g. python -m tvm.auto_scheduler.measure_record --mode distill --i input.json
"""
if __name__ == "__main__":
main()
6 changes: 2 additions & 4 deletions python/tvm/auto_scheduler/relay_integration.py
@@ -22,7 +22,6 @@
2. Provide auto-scheduling for all TOPI compute functions
"""

import json
import logging
import threading

@@ -281,7 +280,7 @@ def auto_schedule_topi(outs):
logger.info("Failed to create a ComputeDAG for auto_scheduler: %s", str(err))
return None

key = register_workload_tensors(dag.hash_key(), io_tensors)
key = register_workload_tensors(dag.workload_key(), io_tensors)
target = tvm.target.Target.current()

env = TracingEnvironment.current
@@ -310,9 +309,8 @@ def auto_schedule_topi(outs):
return None

# rewrite the layout and update the context for the new dag
dag = ComputeDAG(outs)
new_dag = dag.rewrite_layout_from_state(state)
new_key = json.dumps((new_dag.hash_key(),))
new_key = new_dag.workload_key()
if new_key != key:
dispatch_ctx.update(target, new_key, state)
else:
8 changes: 6 additions & 2 deletions python/tvm/auto_scheduler/search_task.py
@@ -257,13 +257,15 @@ def tune(self, tuning_options, search_policy=None):

_ffi_api.AutoSchedule(search_policy, tuning_options)

def apply_best(self, log_file, layout_rewrite_option=None):
def apply_best(self, log_file, include_compatible=False, layout_rewrite_option=None):
"""Apply the history best from a log file and return the schedule.
Parameters
----------
log_file : str
The name of the log file.
include_compatible: bool
When set to True, all compatible records in the log file will be considered.
layout_rewrite_option : Optional[LayoutRewriteOption]
The layout rewrite option.
@@ -272,7 +274,9 @@ def apply_best(self, log_file, layout_rewrite_option=None):
-------
A `te.Schedule` and a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
"""
inp, _ = load_best_record(log_file, self.workload_key)
inp, _ = load_best_record(
log_file, self.workload_key, include_compatible=include_compatible
)
if inp is None:
raise RuntimeError(
"Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file)
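A hedged usage sketch of the new flag; the task, log file, and target are hypothetical:

import tvm

# Reuse a schedule tuned for a compatible workload (same DAG, different shapes)
# when the log has no record for this exact shape.
sch, args = task.apply_best("matmul.json", include_compatible=True)
mod = tvm.build(sch, args, target="llvm")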
27 changes: 27 additions & 0 deletions python/tvm/auto_scheduler/utils.py
@@ -19,6 +19,7 @@
""" Common utilities for auto_scheduler. """

from typing import Hashable
import json
import multiprocessing
import multiprocessing.pool
import queue
@@ -42,6 +43,32 @@
from ..te import Tensor, placeholder


def decode_workload_key(workload_key):
"""Decode the workload key from a string to the name and arguments. The wokrload key
is expected to be a list of "[func_name/hash, args ...]" in a JSON string. If not,
then simply return the workload key as the name without arguments.
Parameters
----------
workload_key: str
The workload key in string. Format: "[func_name/hash, args ...]".
Returns
-------
name: str
The workload function name or the DAG hash.
args: Optional[List[Any]]
The arguments of the workload, or None if the workload key format is not decodable.
"""
try:
key_list = json.loads(workload_key)
if isinstance(key_list, list) and len(key_list) >= 1:
return key_list[0], key_list[1:]
except json.decoder.JSONDecodeError:
pass
return workload_key, None
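For illustration, the three kinds of key this function handles; the hash placeholder is made up:

decode_workload_key('["<dag-hash>", 128, 128, 128]')  # -> ("<dag-hash>", [128, 128, 128])
decode_workload_key('["<dag-hash>"]')                 # -> ("<dag-hash>", [])   old-format key
decode_workload_key("bare_name")                      # -> ("bare_name", None)  not decodable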


def get_func_name(func):
"""Get name of a function.