From 32cc782cab742c24dbcc7e7205fb46527276e1c5 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Sat, 6 Mar 2021 12:53:23 +0800 Subject: [PATCH] [AutoSchedule] Sparse dense tuning support with custom sketch rule (#7313) * Add sparse dense tuning tutorial * Add sparse input fusion * Update the dag to support output fusion * Update * Add task input to search_task * Update * Add search_inputs to measure * Lint fix * Lint fix * Update * Update * Update * Update * Add file save load support * Update * Update * Update * Remove add_task_inputs API * Update * Update * Update * Lint fix * Lint fix * Lint fix * Lint fix * Update * Add example ci_log * Update * retrigger ci * Update * Update * Update * Lint fix * Lint fix * Lint fix --- include/tvm/auto_scheduler/measure_record.h | 2 +- include/tvm/auto_scheduler/search_task.h | 8 +- python/tvm/auto_scheduler/__init__.py | 1 + python/tvm/auto_scheduler/measure.py | 166 ++++++++- python/tvm/auto_scheduler/search_task.py | 191 +++++++++- python/tvm/auto_scheduler/utils.py | 3 + python/tvm/topi/nn/sparse.py | 118 +++++- src/auto_scheduler/feature.cc | 9 +- src/auto_scheduler/measure_record.cc | 34 ++ src/auto_scheduler/search_policy/utils.cc | 16 + src/auto_scheduler/search_task.cc | 7 +- .../unittest/test_auto_scheduler_measure.py | 32 +- .../test_auto_scheduler_search_task.py | 207 +++++++++++ .../auto_scheduler/ci_logs/sparse_dense.json | 2 + tutorials/auto_scheduler/tune_sparse_x86.py | 339 ++++++++++++++++++ 15 files changed, 1109 insertions(+), 26 deletions(-) create mode 100644 tests/python/unittest/test_auto_scheduler_search_task.py create mode 100644 tutorials/auto_scheduler/ci_logs/sparse_dense.json create mode 100644 tutorials/auto_scheduler/tune_sparse_x86.py diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h index ec40611d49b4..c82ed076eca7 100755 --- a/include/tvm/auto_scheduler/measure_record.h +++ b/include/tvm/auto_scheduler/measure_record.h @@ -34,7 +34,7 @@ namespace tvm { namespace auto_scheduler { -const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5"; // NOLINT(*) +const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.6"; // NOLINT(*) /*! \brief Callback for logging the input and results of measurements to file */ class RecordToFileNode : public MeasureCallbackNode { diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index 9e7d3aa2cd32..14bf55abb447 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -26,6 +26,7 @@ #define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_ #include +#include #include namespace tvm { @@ -120,6 +121,8 @@ class SearchTaskNode : public Object { HardwareParams hardware_params; /*! \brief The layout rewrite option used for measuring programs. */ LayoutRewriteOption layout_rewrite_option; + /*! \brief Names of some user defined input data used in program measuring. */ + Array task_input_names; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("compute_dag", &compute_dag); @@ -128,6 +131,7 @@ class SearchTaskNode : public Object { v->Visit("target_host", &target_host); v->Visit("hardware_params", &hardware_params); v->Visit("layout_rewrite_option", &layout_rewrite_option); + v->Visit("task_input_names", &task_input_names); } static constexpr const char* _type_key = "auto_scheduler.SearchTask"; @@ -148,9 +152,11 @@ class SearchTask : public ObjectRef { * \param target_host The target host device of this search task. * \param hardware_params Hardware parameters used in this search task. 
* \param layout_rewrite_option The layout rewrite option used for measuring programs. + * \param task_input_names Names of some user defined input data used in program measuring. */ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, - Optional hardware_params, LayoutRewriteOption layout_rewrite_option); + Optional hardware_params, LayoutRewriteOption layout_rewrite_option, + Array task_input_names); TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); }; diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index 06ca44d997e5..ff6d82a0242c 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -41,6 +41,7 @@ LocalRunner, RPCRunner, LocalRPCMeasureContext, + register_task_input_check_func, ) from .measure_record import RecordToFile, RecordReader, load_best_record, load_records, save_records from .relay_integration import ( diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 47ffde4327c4..959a9c5da82a 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -36,6 +36,7 @@ import shutil import tempfile import multiprocessing +import logging import tvm._ffi from tvm.runtime import Object, module, ndarray @@ -50,6 +51,7 @@ call_func_with_timeout, check_remote, get_const_tuple, + get_func_name, make_traceback_info, request_remote, ) @@ -58,6 +60,8 @@ deserialize_workload_registry_entry, ) +# pylint: disable=invalid-name +logger = logging.getLogger("auto_scheduler") # The time cost for measurements with errors # We use 1e10 instead of sys.float_info.max for better readability in log @@ -223,6 +227,7 @@ def recover_measure_input(inp, rebuild_state=False): target_host=task.target_host, hardware_params=task.hardware_params, layout_rewrite_option=task.layout_rewrite_option, + task_inputs=list(task.task_input_names), ) if rebuild_state: @@ -719,6 +724,97 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo return results +TASK_INPUT_CHECK_FUNC_REGISTRY = {} + + +def register_task_input_check_func(func_name, f=None, override=False): + """Register a function that checks the input buffer map. + + The input function should take a list of Tensor wich indicate the Input/output Tensor of a TVM + subgraph and return a Map from the input Tensor to its buffer name. + + Parameters + ---------- + func_name : Union[Function, str] + The check function that returns the compute declaration Tensors or its function name. + f : Optional[Function] + The check function to be registered. + override : boolean = False + Whether to override existing entry. + + Examples + -------- + .. 
code-block:: python + + @auto_scheduler.register_task_input_check_func + def check_task_input_by_placeholder_name(args : List[Tensor]): + tensor_input_map = {} + for arg in args: + if isinstance(arg.op, tvm.te.PlaceholderOp): + if arg.op.name != "placeholder": + tensor_input_map[arg] = arg.op.name + return tensor_input_map + """ + global TASK_INPUT_CHECK_FUNC_REGISTRY + + if callable(func_name): + f = func_name + func_name = get_func_name(f) + if not isinstance(func_name, str): + raise ValueError("expect string function name") + + def register(myf): + """internal register function""" + if func_name in TASK_INPUT_CHECK_FUNC_REGISTRY and not override: + raise RuntimeError("%s has been registered already" % func_name) + TASK_INPUT_CHECK_FUNC_REGISTRY[func_name] = myf + return myf + + if f: + return register(f) + return register + + +def _prepare_input_map(args): + """This function deals with special task inputs. Map the input Tensor of a TVM subgraph + to a specific buffer name in the global buffer map. + + Parameters + ---------- + args : List[Tensor] + Input/output Tensor of a TVM subgraph. + + Returns + ------- + Dict[Tensor, str] : + Map from the input Tensor to its buffer name. + + Notes + ----- + The buffer name is specially designed, and these buffer should be provided in + `SearchTask(..., task_inputs={...})`. + """ + # pylint: disable=import-outside-toplevel + + global TASK_INPUT_CHECK_FUNC_REGISTRY + + # A dict that maps the input tensor arg to a buffer name + tensor_input_map = {} + + # Case 0: Check placeholder name + for arg in args: + if isinstance(arg.op, tvm.te.PlaceholderOp): + if arg.op.name != "placeholder": + tensor_input_map[arg] = arg.op.name + + # Case 1: Check specific tensor inputs + for func_name in TASK_INPUT_CHECK_FUNC_REGISTRY: + func = TASK_INPUT_CHECK_FUNC_REGISTRY[func_name] + tensor_input_map.update(func(args)) + + return tensor_input_map + + def _timed_eval_func( inp_serialized, build_res, @@ -729,7 +825,11 @@ def _timed_eval_func( enable_cpu_cache_flush, verbose, ): + # pylint: disable=import-outside-toplevel + from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency + inp = MeasureInput.deserialize(inp_serialized) + task_input_names = inp.task.task_input_names tic = time.time() error_no = 0 error_msg = None @@ -758,11 +858,31 @@ def _timed_eval_func( if error_no == 0: try: - args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True) assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake" - for arg in args: - random_fill(arg) + + tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {} + args = [] + task_inputs_count = 0 + for arg in build_res.args: + if arg in tensor_input_map: + tensor_name = tensor_input_map[arg] + if tensor_name in task_input_names: + args.append(get_task_input_buffer(inp.task.workload_key, tensor_name)) + task_inputs_count += 1 + else: + raise ValueError( + "%s not found in task_inputs, " % (tensor_name) + + "should provide with `SearchTask(..., task_inputs={...})`" + ) + else: + empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) + random_fill(empty_array) + args.append(empty_array) + if task_inputs_count != len(task_input_names): + logger.warning( + "task_inputs not fully matched, check if there's any unexpected error" + ) ctx.sync() costs = time_f(*args).results # pylint: disable=broad-except @@ -911,7 +1031,11 @@ def _timed_rpc_run( 
enable_cpu_cache_flush, verbose, ): + # pylint: disable=import-outside-toplevel + from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency + inp = MeasureInput.deserialize(inp_serialized) + task_input_names = inp.task.task_input_names tic = time.time() error_no = 0 error_msg = None @@ -943,18 +1067,36 @@ def _timed_rpc_run( if error_no == 0: try: - args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] - try: - random_fill = remote.get_function("tvm.contrib.random.random_fill") - except AttributeError: - raise AttributeError( - "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices" + random_fill = remote.get_function("tvm.contrib.random.random_fill") + assert ( + random_fill + ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices" + + tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {} + args = [] + task_inputs_count = 0 + for arg in build_res.args: + if arg in tensor_input_map: + tensor_name = tensor_input_map[arg] + if tensor_name in task_input_names: + args.append(get_task_input_buffer(inp.task.workload_key, tensor_name)) + task_inputs_count += 1 + else: + raise ValueError( + "%s not found in task_inputs, " % (tensor_name) + + "should provide with `SearchTask(..., task_inputs={...})`" + ) + else: + empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) + random_fill(empty_array) + args.append(empty_array) + if task_inputs_count != len(task_input_names): + logger.warning( + "task_inputs not fully matched, check if there's any unexpected error" ) - for arg in args: - random_fill(arg) ctx.sync() - costs = time_f(*args).results + # clean up remote files remote.remove(build_res.filename) remote.remove(os.path.splitext(build_res.filename)[0] + ".so") diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 175c2fa06c39..57e239cf79e8 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -19,8 +19,12 @@ import json +import os +import logging +import numpy as np + import tvm._ffi -from tvm.runtime import Object +from tvm.runtime import Object, ndarray from tvm.driver.build_module import build from tvm.target import Target @@ -33,6 +37,9 @@ from .workload_registry import WORKLOAD_FUNC_REGISTRY, register_workload_tensors from . import _ffi_api +# pylint: disable=invalid-name +logger = logging.getLogger("auto_scheduler") + @tvm._ffi.register_object("auto_scheduler.HardwareParams") class HardwareParams(Object): @@ -157,6 +164,156 @@ def __init__( ) +# The map stores special registered buffer for measurement. +# This can be used for sparse workloads when we cannot use random tensors for measurment. +# { +# "workload_key_0": { +# "task_input_0": Tensor(...), +# "task_input_1": Tensor(...) +# }, +# "workload_key_1": { +# "task_input_2": Tensor(...), +# "task_input_3": Tensor(...) +# }, +# ... +# } +TASK_INPUT_BUFFER_TABLE = {} + + +def _save_buffer_to_file(buffer_name, buffer_data): + """Save the current Tensor buffer to a numpy file. + + File name will be: {buffer_name}.{buffer_shape}_{buffer_data_type}.npy + """ + np_data = buffer_data.asnumpy() + + buffer_name += "." + for i in np_data.shape: + buffer_name += "%d_" % (i) + buffer_name += "%s" % (np_data.dtype) + buffer_name += ".npy" + + np_data.tofile(buffer_name, " ") + + +def _try_load_buffer_from_file(buffer_name): + """Try to load buffer from a numpy file, if not found, return None. 
+ + File name has a same format as `_save_buffer_to_file`. + """ + filelist = os.listdir() + + for file in filelist: + if file.startswith(buffer_name + "."): + meta_info = file.split(".")[-2].split("_") + shape = [int(i) for i in meta_info[:-1]] + dtype = meta_info[-1] + buffer_data = np.fromfile(file, dtype=dtype, sep=" ") + buffer_data = buffer_data.reshape(shape) + return ndarray.array(buffer_data) + + return None + + +def register_task_input_buffer( + workload_key, + input_name, + input_data, + overwrite=False, + save_to_file=False, +): + """Register special buffer for measurement. + + Parameters + ---------- + workload_key : str + The workload key of the SearchTask. + + input_name : str + The name of input buffer. + + input_data : tvm.nd.NDArray + The input Tensor data. + + overwrite : bool = False + Whether to overwrite the data if a name has already registered. + + save_to_file : bool = False + Whether to save the data to a local file as well. This can be reused to resume the last + tuning process. + + Returns + ------- + tvm.nd.NDArray + The actual registered Tensor data of this input_name. With `overwrite` set to False, will + return the original one if the name has already registered before. + """ + global TASK_INPUT_BUFFER_TABLE + + if workload_key not in TASK_INPUT_BUFFER_TABLE: + TASK_INPUT_BUFFER_TABLE[workload_key] = {} + input_table = TASK_INPUT_BUFFER_TABLE[workload_key] + + if not overwrite: + if input_name not in input_table.keys(): + # Try to load buffer data from local file + tensor_from_file = _try_load_buffer_from_file(input_name) + if tensor_from_file: + input_table[input_name] = tensor_from_file + + if input_name in input_table.keys(): + logger.warning( + "Tensor %s exists in TASK_INPUT_BUFFER_TABLE, %s", + input_name, + "set overwrite to True or this Tensor will not be registered", + ) + return input_table[input_name] + + input_table[input_name] = input_data + if save_to_file: + _save_buffer_to_file(input_name, input_data) + return input_data + + +def get_task_input_buffer(workload_key, input_name): + """Get special buffer for measurement. + + The buffers are registered by `register_task_input_buffer`. + + Parameters + ---------- + workload_key : str + The workload key of the SearchTask. + + input_name : str + The name of input buffer. + + Returns + ------- + tvm.nd.NDArray + The registered input buffer. + """ + global TASK_INPUT_BUFFER_TABLE + + if workload_key not in TASK_INPUT_BUFFER_TABLE: + TASK_INPUT_BUFFER_TABLE[workload_key] = {} + input_table = TASK_INPUT_BUFFER_TABLE[workload_key] + + if input_name not in input_table.keys(): + # Try to load buffer data from local file + tensor_from_file = _try_load_buffer_from_file(input_name) + if tensor_from_file: + input_table[input_name] = tensor_from_file + + if input_name in input_table.keys(): + return input_table[input_name] + + raise ValueError( + "%s not found in TASK_INPUT_BUFFER_TABLE, " % (input_name) + + "should provide with `SearchTask(..., task_inputs={...})`" + ) + + @tvm._ffi.register_object("auto_scheduler.SearchTask") class SearchTask(Object): """The computation information and hardware parameters for a schedule search task. @@ -185,6 +342,16 @@ class SearchTask(Object): The NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone op, and the REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a network. + task_inputs : Union[Dict[str, tvm.nd.NDArray], List[str]] + A dict maps the input names to input tensors or a list of input names. 
+ Some special Tensor used as inputs in program measuring. Usually we do not need to care + about it, but for special workloads like Sparse computation the Sparse Tensor input are + meaningful that we cannot use random input directly. + task_inputs_overwrite : bool = False + Whether to overwrite the data if a name has already in the global table. + task_inputs_save_to_file : bool = False + Whether to save the data to a local file as well. This can be reused to resume the last + tuning process. Examples -------- @@ -212,6 +379,9 @@ def __init__( target_host=None, hardware_params=None, layout_rewrite_option=None, + task_inputs=None, + task_inputs_overwrite=False, + task_inputs_save_to_file=False, ): assert ( func is not None or workload_key is not None @@ -231,6 +401,22 @@ def __init__( if layout_rewrite_option is None: layout_rewrite_option = LayoutRewriteOption.get_target_default(target) + task_input_names = [] + if isinstance(task_inputs, list): + task_input_names = task_inputs + elif isinstance(task_inputs, dict): + for input_name in task_inputs: + register_task_input_buffer( + workload_key, + input_name, + task_inputs[input_name], + task_inputs_overwrite, + task_inputs_save_to_file, + ) + task_input_names.append(input_name) + elif task_inputs is not None: + raise ValueError("task_inputs should be a dict or a list.") + self.__init_handle_by_constructor__( _ffi_api.SearchTask, compute_dag, @@ -239,6 +425,7 @@ def __init__( target_host, hardware_params, layout_rewrite_option, + task_input_names, ) def tune(self, tuning_options, search_policy=None): @@ -326,6 +513,7 @@ def __getstate__(self): "target_host": self.target_host, "hardware_params": self.hardware_params, "layout_rewrite_option": self.layout_rewrite_option, + "task_input_names": self.task_input_names, } def __setstate__(self, state): @@ -350,6 +538,7 @@ def __setstate__(self, state): state["target_host"], state["hardware_params"], state["layout_rewrite_option"], + state["task_input_names"], ) diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py index 8aa33e6775f8..14dc5b8984c3 100644 --- a/python/tvm/auto_scheduler/utils.py +++ b/python/tvm/auto_scheduler/utils.py @@ -201,6 +201,9 @@ def serialize_args(args): Currently this is mainly used for tvm.tensor.Tensor """ ret = [] + if args is None: + return tuple(ret) + for t in args: if isinstance(t, Tensor): t = ("TENSOR", get_const_tuple(t.shape), t.dtype) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 8145ed80af47..1bf18df09da3 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -18,7 +18,7 @@ """Sparse operators""" from __future__ import absolute_import import tvm -from tvm import te +from tvm import te, auto_scheduler from ..utils import get_const_tuple @@ -197,7 +197,7 @@ def _compute_block(nb_j, j, i): def _sparse_dense_sp_rhs_bsrmm(data, weight_data, weight_indices, weight_indptr): - (m, _) = get_const_tuple(data.shape) + (m, k) = get_const_tuple(data.shape) (_, bs_r, bs_c) = get_const_tuple(weight_data.shape) (num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape) num_blocks = num_blocks_plus_1 - 1 @@ -218,7 +218,10 @@ def _compute_block(i, nb_j, j): idxm = tvm.tir.indexmod bsrmm_block = te.compute( - (m, num_blocks, bs_r), _compute_block, tag="sparse_dense_sp_rhs_bsrmm_block" + (m, num_blocks, bs_r), + _compute_block, + tag="sparse_dense_sp_rhs_bsrmm_block", + attrs={"FLOP": 2 * m * num_blocks * bs_r * k}, ) return te.compute( (m, num_blocks * bs_r), @@ -356,3 +359,112 @@ def 
sparse_dense_alter_layout(_attrs, _inputs, _tinfos, _out_type): Unlike other TOPI functions, this function operates on both graph level and operator level. """ return None + + +@auto_scheduler.register_task_input_check_func +def try_get_sparse_input(args): + """Analyze the input data from the given args. + + Parameters + ---------- + args : List[Tensor] + Input/output Tensor of a TVM subgraph. + + Returns + ------- + Dict[Tensor, str] : + Map from the input Tensor to its buffer name. + + Notes + ----- + The buffer name is specially designed, and these buffer should be provided in + `SearchTask(..., task_inputs={...})`. + """ + sparse_prefix = sparse_data = sparse_indices = sparse_indptr = None + + def _process_inputs(input_tensors, m, n, prefix_init): + nonlocal sparse_prefix + nonlocal sparse_data + nonlocal sparse_indices + nonlocal sparse_indptr + + assert len(input_tensors) == 4 + unsure_tensors = list(input_tensors) + # Get the Dense data + dense_data = None + for tensor in unsure_tensors: + if len(tensor.shape) == 2: + assert dense_data is None + dense_data = tensor + assert m == dense_data.shape[0] + k = dense_data.shape[1] + unsure_tensors.remove(dense_data) + + # Get the Sparse data + sparse_data = None + for tensor in unsure_tensors: + if len(tensor.shape) == 3: + assert sparse_data is None + sparse_data = tensor + block_size, bs_r, bs_c = sparse_data.shape + unsure_tensors.remove(sparse_data) + + # Get the Sparse indptr & indices + sparse_indices = None + for tensor in unsure_tensors: + assert len(tensor.shape) == 1 + if tensor.shape[0] == block_size: + assert sparse_indices is None + sparse_indices = tensor + unsure_tensors.remove(sparse_indices) + assert len(unsure_tensors) == 1 + sparse_indptr = unsure_tensors[0] + + # Generate the sparse_prefix + density = 1.0 + for i in sparse_data.shape: + density *= i + density /= k * n + density = density.value + sparse_prefix = "%s_%d_%d_%d_%d_%d_%.2f_" % (prefix_init, m, n, k, bs_r, bs_c, density) + + visited = set() + + def _traverse(t): + # We cannot directly add tensors to the set, because the comparison of + # two tensors with ndim=0 is ambiguous. 
+ assert t.handle is not None + if t.handle.value in visited: + return + + if isinstance(t.op, te.ComputeOp): + # TODO(jcf94): Currently only support to one sparse op, add more support here + if t.op.tag == "sparse_dense_sp_rhs_bsrmm": + m, n = t.shape + assert len(t.op.input_tensors) == 1 + block_tensor = t.op.input_tensors[0] + _process_inputs(block_tensor.op.input_tensors, m, n, "sparse_dense_bsr") + if sparse_prefix is not None: + # Early stop if we find a sparse_prefix + # Notice: If any workload has more than one sparse input, this may get problem + return + for x in t.op.input_tensors: + _traverse(x) + visited.add(t.handle.value) + + try: + for arg in args: + _traverse(arg) + # pylint: disable=broad-except + except Exception: + return {} + + if sparse_data is None or sparse_indices is None or sparse_indptr is None: + return {} + + sparse_input_map = {} + sparse_input_map[sparse_data] = sparse_prefix + "W_data" + sparse_input_map[sparse_indices] = sparse_prefix + "W_indices" + sparse_input_map[sparse_indptr] = sparse_prefix + "W_indptr" + + return sparse_input_map diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index cf516d8452e2..d93218c0208c 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1399,7 +1399,7 @@ void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int Array tensors = (*workload_key_to_tensors)(workload_key); task = SearchTask(ComputeDAG(tensors), workload_key, cur_inp->task->target, cur_inp->task->target_host, cur_inp->task->hardware_params, - cur_inp->task->layout_rewrite_option); + cur_inp->task->layout_rewrite_option, cur_inp->task->task_input_names); task_id = task_cache.size(); // compute min cost for each task @@ -1466,9 +1466,10 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array& inputs, // The measure input is incomplete, rebuild task for incomplete measure pairs read from file try { Array tensors = (*workload_key_to_tensors)(workload_key); - task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target, - inputs[i]->task->target_host, inputs[i]->task->hardware_params, - inputs[i]->task->layout_rewrite_option); + task = + SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target, + inputs[i]->task->target_host, inputs[i]->task->hardware_params, + inputs[i]->task->layout_rewrite_option, inputs[i]->task->task_input_names); } catch (std::exception& e) { // Cannot build ComputeDAG from workload key, the task may have not been registered in // this search round diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index 1120f437b176..5dafa8d98702 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -169,6 +169,12 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { writer->WriteArrayItem(std::string("")); } writer->WriteArrayItem(static_cast(data.layout_rewrite_option)); + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (const auto& i : data.task_input_names) { + writer->WriteArrayItem(std::string(i)); + } + writer->EndArray(); writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) { @@ -200,6 +206,17 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { reader->Read(&int_value); data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value); s = reader->NextArrayItem(); + if (s) { + reader->BeginArray(); + s = reader->NextArrayItem(); + while (s) { + 
reader->Read(&str_value); + data->task_input_names.push_back(str_value); + s = reader->NextArrayItem(); + } + // Process the end of array + s = reader->NextArrayItem(); + } ICHECK(!s); } } @@ -444,5 +461,22 @@ TVM_REGISTER_GLOBAL("auto_scheduler.DeserializeMeasureInput").set_body_typed([]( reader.Read(inp.get()); return ObjectRef(inp); }); + +TVM_REGISTER_GLOBAL("auto_scheduler.SerializeSearchTask") + .set_body_typed([](const SearchTask& search_task) { + std::ostringstream os; + dmlc::JSONWriter writer(&os); + writer.Write(*search_task.get()); + return os.str(); + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.DeserializeSearchTask").set_body_typed([](String json) { + std::istringstream ss(json); + dmlc::JSONReader reader(&ss); + auto search_task = make_object(); + reader.Read(search_task.get()); + return ObjectRef(search_task); +}); + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc index d59df6965776..ce8dc39922e0 100644 --- a/src/auto_scheduler/search_policy/utils.cc +++ b/src/auto_scheduler/search_policy/utils.cc @@ -465,6 +465,22 @@ const std::vector& SplitFactorizationMemo::GetFactors(int n) { /********** Utils interface API for ffi **********/ +TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsGetConsumers") + .set_body_typed([](const SearchTask& task, const State& state, int stage_id) { + const std::set& consumers = GetConsumers(task, state, stage_id); + tvm::Map ret; + for (const auto& i : consumers) { + ret.Set(Integer(i), Integer(i)); + } + return ret; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsElementwiseMatch") + .set_body_typed([](const SearchTask& task, const State& state, int stage_id, + int target_stage_id) { + return ElementwiseMatch(task, state, stage_id, target_stage_id); + }); + TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsTiled") .set_body_typed([](const Stage& stage) { return IsTiled(stage); }); diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 0abee16fceab..22c2893141cf 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -114,7 +114,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, - LayoutRewriteOption layout_rewrite_option) { + LayoutRewriteOption layout_rewrite_option, Array task_input_names) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -127,6 +127,7 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host); } node->layout_rewrite_option = layout_rewrite_option; + node->task_input_names = std::move(task_input_names); data_ = std::move(node); } @@ -142,9 +143,9 @@ TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams") TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask") .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, - int layout_rewrite_option) { + int layout_rewrite_option, Array task_input_names) { return SearchTask(compute_dag, workload_key, target, target_host, hardware_params, - LayoutRewriteOption(layout_rewrite_option)); + LayoutRewriteOption(layout_rewrite_option), task_input_names); }); } // namespace auto_scheduler diff --git 
a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index cc9d7a41548d..116981028cc9 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -19,6 +19,7 @@ import json import multiprocessing +import numpy as np import tvm from tvm import topi from tvm import te, auto_scheduler @@ -26,7 +27,7 @@ import tvm.testing import pickle -from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul +from test_auto_scheduler_common import matmul_auto_scheduler_test from tvm.auto_scheduler import workload_registry @@ -355,6 +356,34 @@ def test_measure_target_host(): assert str(recovered_inp.task.target_host) == str(inp.task.target_host) +@tvm.testing.requires_llvm +def test_measure_special_inputs_map_by_name(): + @auto_scheduler.register_workload + def foo(): + X = te.placeholder(shape=[10], dtype="int32") + Index = te.placeholder(shape=[1], dtype="int32", name="Index") + Y = te.compute((1,), lambda i: X[Index[i]]) + return [X, Index, Y] + + # This workload cannot use random input for the `Index` input + task = auto_scheduler.SearchTask( + func=foo, + target="llvm", + task_inputs={ + "Index": tvm.nd.array(np.array([5], dtype="int32")), + }, + ) + + minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state) + local_builder = auto_scheduler.LocalBuilder() + local_runner = auto_scheduler.LocalRunner(timeout=10) + + bress = local_builder.build([minp]) + assert bress[0].error_no == 0 + mress = local_runner.run([minp], bress) + assert mress[0].error_no == 0 + + if __name__ == "__main__": test_record_split_reorder_fuse_annotation() test_record_compute_at_root_inline_cache_read_write() @@ -366,3 +395,4 @@ def test_measure_target_host(): test_dag_measure_local_builder_runner() test_measure_local_builder_rpc_runner() test_measure_target_host() + test_measure_special_inputs_map_by_name() diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py new file mode 100644 index 000000000000..78e85dc213e0 --- /dev/null +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -0,0 +1,207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Test search policy""" + +import numpy as np +import tempfile + +import tvm +import tvm.testing +from tvm import auto_scheduler +from tvm.auto_scheduler.utils import get_const_tuple +from test_auto_scheduler_common import ( + matmul_auto_scheduler_test, + zero_rank_compute_auto_scheduler_test, + zero_rank_reduce_auto_scheduler_test, +) + + +def test_search_task_add_task_input(): + auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear() + N = 64 + target = "llvm" + test_input_0 = tvm.runtime.ndarray.empty((64, 64)) + test_input_1 = tvm.runtime.ndarray.empty((10, 20)) + test_input_2 = tvm.runtime.ndarray.empty((30, 40, 50)) + task = auto_scheduler.SearchTask( + func="matmul_auto_scheduler_test", + args=(N, N, N), + target=target, + task_inputs={ + "test_input_0": test_input_0, + "test_input_1": test_input_1, + "test_input_2": test_input_2, + }, + task_inputs_overwrite=True, + ) + + assert len(task.task_input_names) == 3 + assert task.task_input_names[0] == "test_input_0" + assert task.task_input_names[1] == "test_input_1" + assert task.task_input_names[2] == "test_input_2" + + +def test_search_task_record(): + auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear() + N = 64 + target = "llvm" + + # Log with no task input + task = auto_scheduler.SearchTask( + func="matmul_auto_scheduler_test", args=(N, N, N), target=target + ) + task_record = auto_scheduler._ffi_api.SerializeSearchTask(task) + new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) + # TODO(jcf94): Check the compute dag & hardware parameter + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + + # Log with 1 task input + test_input_0 = tvm.runtime.ndarray.empty((64, 64)) + task = auto_scheduler.SearchTask( + func="matmul_auto_scheduler_test", + args=(N, N, N), + target=target, + task_inputs={"test_input_0": test_input_0}, + task_inputs_overwrite=True, + ) + task_record = auto_scheduler._ffi_api.SerializeSearchTask(task) + new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_input_names) == 1 + assert new_task.task_input_names[0] == "test_input_0" + + # Log with multiple task inputs + test_input_1 = tvm.runtime.ndarray.empty((64, 64)) + task = auto_scheduler.SearchTask( + func="matmul_auto_scheduler_test", + args=(N, N, N), + target=target, + task_inputs={ + "test_input_0": test_input_0, + "test_input_1": test_input_1, + }, + task_inputs_overwrite=True, + ) + task_record = auto_scheduler._ffi_api.SerializeSearchTask(task) + new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_input_names) == 2 + assert new_task.task_input_names[0] == "test_input_0" + assert new_task.task_input_names[1] == "test_input_1" + + # Log with version 0.5 + v5_log = """["[\\\"matmul_auto_scheduler_test\\\", 64, 64, 64]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1]""" + new_task = 
auto_scheduler._ffi_api.DeserializeSearchTask(v5_log) + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_input_names) == 0 + + +def test_recover_measure_input_with_task_input(): + auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear() + + # Since this file is tests for search_task, we only check the search_task here + + # Log with no task input + task = auto_scheduler.SearchTask( + func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm" + ) + inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) + res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) + measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) + measure_log = auto_scheduler.measure_record.load_record_from_string(measure_record) + new_task = measure_log[0].task + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + + # Log with 1 task input + test_input_0 = tvm.runtime.ndarray.empty((64, 64)) + task = auto_scheduler.SearchTask( + func=matmul_auto_scheduler_test, + args=(512, 512, 512), + target="llvm", + task_inputs={ + "test_input_0": test_input_0, + }, + task_inputs_overwrite=True, + ) + inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) + res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) + measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) + measure_log = auto_scheduler.measure_record.load_record_from_string(measure_record) + new_task = measure_log[0].task + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_input_names) == 1 + assert new_task.task_input_names[0] == "test_input_0" + + # Log with multiple task inputs + test_input_1 = tvm.runtime.ndarray.empty((64, 64)) + task = auto_scheduler.SearchTask( + func=matmul_auto_scheduler_test, + args=(512, 512, 512), + target="llvm", + task_inputs={ + "test_input_0": test_input_0, + "test_input_1": test_input_1, + }, + task_inputs_overwrite=True, + ) + inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) + res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) + measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) + measure_log = auto_scheduler.measure_record.load_record_from_string(measure_record) + new_task = measure_log[0].task + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_input_names) == 2 + assert new_task.task_input_names[0] == "test_input_0" + assert new_task.task_input_names[1] == "test_input_1" + + # Log with version 0.5 + v5_log = """{"i": [["[\\\"matmul_auto_scheduler_test\\\", 512, 512, 512]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1], [[], []]], "r": [[0.1], 0, 0.2, 1], "v": "v0.6"}""" + measure_log = 
auto_scheduler.measure_record.load_record_from_string(v5_log) + new_task = measure_log[0].task + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_input_names) == 0 + + +if __name__ == "__main__": + test_search_task_add_task_input() + test_search_task_record() + test_recover_measure_input_with_task_input() diff --git a/tutorials/auto_scheduler/ci_logs/sparse_dense.json b/tutorials/auto_scheduler/ci_logs/sparse_dense.json new file mode 100644 index 000000000000..7c1c100124dc --- /dev/null +++ b/tutorials/auto_scheduler/ci_logs/sparse_dense.json @@ -0,0 +1,2 @@ +# Keep a valid schedule for demonstraction. This is used to prevent flasky errors in CI. +{"i": [["[\"sparse_dense\", 512, 512, 512, [9831, 16, 1], [9831], [33], \"float32\"]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1, ["sparse_dense_bsr_512_512_512_16_1_0.60_W_data", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indices", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indptr"]], [[], [["CI", 8], ["CI", 6], ["SP", 5, 0, 512, [1, 8], 1], ["FSP", 9, 0, 2, 1], ["SP", 5, 3, 32, [32], 1], ["FSP", 9, 2, 4, 1], ["RE", 5, [0, 3, 1, 4, 6, 2, 5, 7]], ["RE", 9, [0, 2, 1, 3]], ["CA", 5, 9, 1], ["CI", 4], ["FU", 9, [0, 1]], ["AN", 9, 0, 3], ["PR", 5, 0, "auto_unroll_max_step$0"], ["AN", 9, 2, 2]]]], "r": [[0.000957008], 0, 0.605709, 1614689820], "v": "v0.6"} diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py new file mode 100644 index 000000000000..ced416f6c500 --- /dev/null +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -0,0 +1,339 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Auto-scheduling Sparse Matrix Multiplication on CPU with Custom Sketch Rule +=========================================================================== +**Author**: `Chengfan Jia `_ + +This is a tutorial on how to use the auto-scheduler to tune a sparse matrix multiplication for +CPUs. + +Auto-scheduler is designed to explore the schedule with best performance for a given computation +declaration automatically. While sometimes, we may have a demand to try some special ops which may +not been well-supported by auto-scheduler's default sketch rules and result in poor performance. +Fortunately, auto-scheduler currently allows user to provide a CustomSketch to cover these cases. + +We use sparse matrix multiplication as an example in this tutorial to demonstrate how to implement +and plug a custom sketch rule to the auto-scheduler's search policy. + +Note that this tutorial will not run on Windows or recent versions of macOS. 
To +get it to run, you will need to wrap the body of this tutorial in a :code:`if +__name__ == "__main__":` block. +""" + +import os +import itertools + +import numpy as np +import tvm +from tvm import te, auto_scheduler, runtime, topi +from tvm.auto_scheduler import _ffi_api +from tvm.topi.utils import get_const_tuple + +import scipy.sparse as sp + +###################################################################### +# Define the computation +# ^^^^^^^^^^^^^^^^^^^^^^ +# To begin with, let us define the computation of a sparse matmul with several relu and bias add. +# The function should return the list of input/output tensors. +# From these tensors, the auto-scheduler can get the whole computational graph. + +# We use this function to generate a random bsr matrix +def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): + import itertools + + Y = np.zeros((M, N), dtype=dtype) + assert M % BS_R == 0 + assert N % BS_C == 0 + nnz = int(density * M * N) + num_blocks = int(nnz / (BS_R * BS_C)) + 1 + candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C)))) + assert candidate_blocks.shape[0] == M // BS_R * N // BS_C + chosen_blocks = candidate_blocks[ + np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False) + ] + for i in range(len(chosen_blocks)): + r, c = chosen_blocks[i] + Y[r : r + BS_R, c : c + BS_C] = np.random.randn(BS_R, BS_C) + s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C)) + assert s.data.shape == (num_blocks, BS_R, BS_C) + assert s.indices.shape == (num_blocks,) + assert s.indptr.shape == (M // BS_R + 1,) + return s + + +@auto_scheduler.register_workload +def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): + X = te.placeholder(shape=(M, K), dtype=dtype) + W_data = te.placeholder(shape=w_data_shape, dtype=dtype) + W_indices = te.placeholder(shape=w_indices_shape, dtype="int32") + W_indptr = te.placeholder(shape=w_indptr_shape, dtype="int32") + B = te.placeholder(shape=(M, N), dtype=dtype) + + out = topi.nn.sparse_dense(topi.nn.relu(X), W_data, W_indices, W_indptr) + out = te.compute((M, N), lambda i, j: out[i, j] + B[i, j], name="BiasAdd") + out = topi.nn.relu(out) + + return [X, W_data, W_indices, W_indptr, B, out] + + +###################################################################### +# Special step for sparse workload +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# During schedule tuning, auto-scheduler will use random inputs to measure the performance of a +# generated schedule. While we cannot directly use a random array as the input of a sparse op, for +# the "indices" and "indptr" array are meaningful for the computation. +# +# To solve this problem, we register these as special buffers, and load them when process program +# measuring. +# See the `tvm.auto_scheduler.measure.py` for more details. 
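For reference, the registration that `SearchTask(..., task_inputs={...})` performs can also be done directly with the helpers this patch adds to `search_task.py`. A minimal sketch, using the workload key and buffer name from the CI log shipped with this tutorial as illustrative values:

.. code-block:: python

    import numpy as np
    from tvm import runtime
    from tvm.auto_scheduler.search_task import (
        register_task_input_buffer,
        get_task_input_buffer,
    )

    # Illustrative values copied from ci_logs/sparse_dense.json; a real run uses
    # the workload key of its own SearchTask and the "sparse_dense_bsr_..." prefix.
    workload_key = '["sparse_dense", 512, 512, 512, [9831, 16, 1], [9831], [33], "float32"]'
    buffer_name = "sparse_dense_bsr_512_512_512_16_1_0.60_W_data"
    w_data = runtime.ndarray.array(np.zeros((9831, 16, 1), dtype="float32"))

    # Register the tensor so program measuring feeds real sparse data instead of
    # a random array; save_to_file=True would also dump it for a resumed run.
    register_task_input_buffer(workload_key, buffer_name, w_data, overwrite=True)

    # At measure time the runner looks the buffer up by name.
    assert get_task_input_buffer(workload_key, buffer_name).shape == (9831, 16, 1)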
+ +# Define the basic shapes of this sparse computation +M = K = N = 512 +BS_R = 16 +BS_C = 1 +density = 0.6 + +# Generate the test data with numpy +X_np = np.random.randn(M, K).astype("float32") +X_np = np.maximum(np.zeros((M, K), dtype="float32"), X_np) # Relu +W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32") +W_np = W_sp_np.todense() +Y_np = X_np @ W_np.T # Process the matrix multiplication +B_np = np.random.randn(M, N).astype("float32") +Y_np = Y_np + B_np # Bias add +Y_np = np.maximum(np.zeros((M, N), dtype="float32"), Y_np) # Relu + +###################################################################### +# Create the search task +# ^^^^^^^^^^^^^^^^^^^^^^ +# We then create a search task with M=N=K=512 and dtype="float32" +# If your machine supports avx instructions, you can +# +# - replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2 +# - replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512 + +target = tvm.target.Target("llvm") + +# Register the sparse data to task inputs +prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density) +task = tvm.auto_scheduler.SearchTask( + func=sparse_dense, + args=(M, N, K, W_sp_np.data.shape, W_sp_np.indices.shape, W_sp_np.indptr.shape, "float32"), + target=target, + task_inputs={ + prefix + "W_data": runtime.ndarray.array(W_sp_np.data), + prefix + "W_indices": runtime.ndarray.array(W_sp_np.indices), + prefix + "W_indptr": runtime.ndarray.array(W_sp_np.indptr), + }, + task_inputs_save_to_file=True, +) + +# Inspect the computational graph +print("Computational DAG:") +print(task.compute_dag) + +###################################################################### +# Write the custom sketch for sparse dense op +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Before tuning, we will need to define the CustomSketchRule for the sparse dense op. +# +# CustomSketchRule consists of two parts: the condition function and the apply function. +# +# - condition function: describe when to apply this sketch rule. For example, we can only apply +# the rule to the sparse ops by matching their name and tag. +# - apply function: describe how to generate the initial sketch. You can implement it using +# auto-scheduler provided loop state APIs. 
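Stripped down to the interface, the two callbacks have the shape sketched below; the `"my_special_op"` tag and the trivial return value are placeholders, and the full rule for the sparse dense op follows right after:

.. code-block:: python

    from tvm import auto_scheduler

    def my_meet_condition(search_policy, state, stage_id):
        # Wrap the raw state so its stages can be inspected with the Python API.
        state = auto_scheduler.loop_state.State(
            state, search_policy.search_task.compute_dag
        )
        if state.stages[stage_id].op.tag == "my_special_op":  # placeholder tag
            # Use this custom sketch and skip the built-in sketch rules.
            return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST
        # Otherwise fall back to the default sketch rules.
        return auto_scheduler.PreloadCustomSketchRule.PASS

    def my_apply(search_policy, state, stage_id):
        s0 = auto_scheduler.loop_state.State(
            state, search_policy.search_task.compute_dag
        )
        # ... transform s0 with split / reorder / compute_at / compute_inline ...
        # Return a list of [state, next_stage_id] pairs for the search policy to
        # continue sketch generation from.
        return [[s0.state_object, stage_id - 1]]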
+ + +def meet_condition_func(search_policy, state, stage_id): + state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) + if state.stages[stage_id].op.tag in [ + "sparse_dense_sp_rhs_bsrmm", + "sparse_dense_sp_rhs_bsrmm_block", + ]: + return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST + else: + return auto_scheduler.PreloadCustomSketchRule.PASS + + +def apply_func(search_policy, state, stage_id): + ret = [] + s0 = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) + if s0.stages[stage_id].op.tag == "sparse_dense_sp_rhs_bsrmm_block": + return [s0.state_object, stage_id - 1] + + sparse_dense = s0.stages[stage_id].op + sparse_dense_block = s0.stages[stage_id - 1].op + assert sparse_dense.tag == "sparse_dense_sp_rhs_bsrmm" + assert sparse_dense_block.tag == "sparse_dense_sp_rhs_bsrmm_block" + + # Set the default consumer of compute block + consumer = sparse_dense + + # If sparse dense has a single elementwise consumer + # We can compute inline the sparse_dense output stage + consumers = _ffi_api.SearchPolicyUtilsGetConsumers( + search_policy.search_task, s0.state_object, stage_id + ) + if len(consumers) == 1: + consumer_id = int(consumers.items()[0][0]) + if _ffi_api.SearchPolicyUtilsIsElementwiseMatch( + search_policy.search_task, s0.state_object, stage_id, consumer_id + ): + consumer = s0.stages[consumer_id].op + s0.compute_inline(sparse_dense) + + i, nb_j, j, row_offset, c = s0[sparse_dense_block].iters + m, n = s0[consumer].iters + i0, i1, i2 = s0.split(sparse_dense_block, i, [None, None]) + m0, m1 = s0.follow_split(consumer, m, len(s0.transform_steps) - 1, 1) + j0, j1 = s0.split(sparse_dense_block, nb_j, [None]) + n0, n1 = s0.follow_split(consumer, n, len(s0.transform_steps) - 1, 1) + s0.reorder(sparse_dense_block, [i0, j0, i1, j1, row_offset, i2, j, c]) + s0.reorder(consumer, [m0, n0, m1, n1]) + s0.compute_at(sparse_dense_block, consumer, n0) + + ret.append([s0.state_object, stage_id - 2]) + + return ret + + +###################################################################### +# Next, we set parameters for the auto-scheduler with the custom sketch plugged in. +# +# * :code:`num_measure_trials` is the number of measurement trials we can use during the search. +# We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a +# good value for the search to converge. You can do more trials according to your time budget. +# * In addition, we use :code:`RecordToFile` to dump measurement records into a file +# `sparse_dense.json`. +# The measurement records can be used to query the history best, resume the search, +# and do more analyses later. +# * see :any:`auto_scheduler.TuningOptions` for more parameters +# * Here, we need to create a :code:`auto_scheduler.SketchPolicy` object, and add the custom sketch +# rule as a `init_search_callbacks`. + +log_file = "sparse_dense.json" +tune_option = auto_scheduler.TuningOptions( + num_measure_trials=10, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=2, +) + +search_policy = auto_scheduler.SketchPolicy( + task, + program_cost_model=auto_scheduler.XGBModel(), + init_search_callbacks=[ + auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func, "SparseDense") + ], +) + +###################################################################### +# Run the search +# ^^^^^^^^^^^^^^ +# Now we get all inputs ready. +# We can kick off the search and let the auto-scheduler do its magic. 
+# After some measurement trials, we can load the best schedule from the log +# file and apply it. + +# Run auto-tuning (search) +# Notice: We do not run the tuning in our webpage server since it takes too long. +# Uncomment the following line to run it by yourself. +task.tune(tune_option, search_policy) + +# Apply the best schedule +sch, args = task.apply_best(log_file) + +###################################################################### +# We can lower the schedule to see the IR after auto-scheduling. +# The auto-scheduler correctly performs optimizations including multi-level tiling, +# layout transformation, parallelization, vectorization, unrolling, and operator fusion. + +print("Lowered TIR:") +print(tvm.lower(sch, args, simple_mode=True)) + +###################################################################### +# Check correctness and evaluate performance +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# We build the binary and check its correctness and performance. + +func = tvm.build(sch, args, target) + +ctx = tvm.cpu() + +X_tvm = tvm.nd.array(X_np, ctx=ctx) +W_data_tvm = tvm.nd.array(W_sp_np.data, ctx=ctx) +W_indices_tvm = tvm.nd.array(W_sp_np.indices, ctx=ctx) +W_indptr_tvm = tvm.nd.array(W_sp_np.indptr, ctx=ctx) +B_tvm = tvm.nd.array(B_np, ctx=ctx) +Y_tvm = tvm.nd.empty(Y_np.shape, ctx=ctx) + +func(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, B_tvm, Y_tvm) + +# Check results +tvm.testing.assert_allclose(Y_np, Y_tvm.asnumpy(), atol=1e-4, rtol=1e-4) + +# Evaluate execution time. +evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500) +print( + "Execution time of this operator: %.3f ms" + % ( + np.median(evaluator(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, B_tvm, Y_tvm).results) + * 1000 + ) +) + +###################################################################### +# .. note:: Tuning result example +# +# .. 
code-block:: c +# +# ---------------------------------------------------------------------- +# Lowered TIR: +# primfn(placeholder_5: handle, placeholder_6: handle, placeholder_7: handle, placeholder_8: handle, placeholder_9: handle, compute_1: handle) -> () +# attr = {"global_symbol": "main", "tir.noalias": True} +# buffers = {placeholder_2: Buffer(placeholder_10: Pointer(float32), float32, [9831, 16, 1], []), +# placeholder_4: Buffer(placeholder_11: Pointer(int32), int32, [33], []), +# placeholder_3: Buffer(placeholder_12: Pointer(float32), float32, [512, 512], []), +# compute: Buffer(compute_2: Pointer(float32), float32, [512, 512], []), +# placeholder_1: Buffer(placeholder_13: Pointer(float32), float32, [512, 512], []), +# placeholder: Buffer(placeholder_14: Pointer(int32), int32, [9831], [])} +# buffer_map = {placeholder_7: placeholder, placeholder_9: placeholder_1, placeholder_6: placeholder_2, compute_1: compute, placeholder_5: placeholder_3, placeholder_8: placeholder_4} { +# for (i0.outer.i1.outer.fused: int32, 0, 1024) "parallel" { +# attr [compute_3: Pointer(float32)] "storage_scope" = "global"; +# allocate(compute_3, float32, [256]) { +# for (nb_j.inner: int32, 0, 2) { +# for (i.inner.init: int32, 0, 8) { +# for (j.init: int32, 0, 16) { +# compute_3[(((i.inner.init*32) + (nb_j.inner*16)) + j.init)] = 0f32 +# } +# } +# for (elem_idx: int32, 0, ((int32*)placeholder_11[(((floormod(i0.outer.i1.outer.fused, 16)*2) + nb_j.inner) + 1)] - (int32*)placeholder_11[((floormod(i0.outer.i1.outer.fused, 16)*2) + nb_j.inner)])) { +# for (i.inner: int32, 0, 8) { +# for (j: int32, 0, 16) { +# compute_3[(((i.inner*32) + (nb_j.inner*16)) + j)] = ((float32*)compute_3[(((i.inner*32) + (nb_j.inner*16)) + j)] + ((float32*)placeholder_10[((((int32*)placeholder_11[((floormod(i0.outer.i1.outer.fused, 16)*2) + nb_j.inner)]*16) + (elem_idx*16)) + j)]*max((float32*)placeholder_12[(((floordiv(i0.outer.i1.outer.fused, 16)*4096) + (i.inner*512)) + (int32*)placeholder_14[((int32*)placeholder_11[((floormod(i0.outer.i1.outer.fused, 16)*2) + nb_j.inner)] + elem_idx)])], 0f32))) +# } +# } +# } +# } +# for (i0.inner: int32, 0, 8) { +# compute_2[ramp((((floordiv(i0.outer.i1.outer.fused, 16)*4096) + (i0.inner*512)) + (floormod(i0.outer.i1.outer.fused, 16)*32)), 1, 32)] = max(((float32x32*)compute_3[ramp((i0.inner*32), 1, 32)] + (float32x32*)placeholder_13[ramp((((floordiv(i0.outer.i1.outer.fused, 16)*4096) + (i0.inner*512)) + (floormod(i0.outer.i1.outer.fused, 16)*32)), 1, 32)]), broadcast(0f32, 32)) +# } +# } +# } +# }
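As a follow-up to `task_inputs_save_to_file=True`: only the buffer names are stored in the task record, so a later session can rebuild the task by passing the names alone, and the saved buffers are reloaded from the working directory when measuring resumes. A minimal sketch, assuming `sparse_dense`, `M`, `N`, `K`, `W_sp_np`, `prefix`, and `log_file` are defined as earlier in this tutorial:

.. code-block:: python

    import tvm
    from tvm import auto_scheduler

    # Passing a list of names (instead of a dict of NDArrays) makes the measurer
    # look the buffers up in the global table, falling back to the files written
    # by task_inputs_save_to_file=True in the current directory.
    resumed_task = auto_scheduler.SearchTask(
        func=sparse_dense,
        args=(M, N, K, W_sp_np.data.shape, W_sp_np.indices.shape,
              W_sp_np.indptr.shape, "float32"),
        target="llvm",
        task_inputs=[prefix + "W_data", prefix + "W_indices", prefix + "W_indptr"],
    )

    # Query the best record measured so far and rebuild the kernel from it.
    sch, args = resumed_task.apply_best(log_file)
    func = tvm.build(sch, args, "llvm")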