diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index 4fbeb306d74a8..5bf2335ec7cfd 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -32,7 +32,7 @@ # Shortcut from .auto_schedule import TuningOptions, HardwareParams, create_task, auto_schedule -from .compute_dag import ComputeDAG, rewrite_compute_body +from .compute_dag import ComputeDAG from .cost_model import RandomModel, XGBModel from .dispatcher import DispatchContext, ApplyHistoryBest from .measure import ( @@ -44,7 +44,7 @@ LocalRPCMeasureContext, ) from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records -from .relay_integration import extract_tasks, remove_index_check +from .relay_integration import extract_tasks, remove_index_check, rewrite_compute_body from .search_task import SearchTask from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates from .task_scheduler import TaskScheduler diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 806b37d697794..c1a195f3c8fe9 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -173,7 +173,8 @@ def rewrite_layout_from_state(self, state): Returns ------- - updated_state : StateObject + updated_dag : ComputeDAG + The compute dag with rewritten layout. """ state_obj = state if isinstance(state, StateObject) else state.state_object return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state_obj) @@ -226,17 +227,3 @@ def __setstate__(self, state): self.compute = LoadJSON(state["compute"]) # pylint: disable=assignment-from-no-return self.sche = LoadJSON(state["sche"]) # pylint: disable=assignment-from-no-return self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, self.compute, self.sche) - - -def rewrite_compute_body(compute, placeholder, new_layout): - """Rewrite the body of a ComputeOp according to a new layout of a placeholder""" - body = [] - for b in compute.op.body: - body.append(_ffi_api.RewriteIndexForNewLayout(placeholder.op, new_layout, b)) - op_node = tvm.te._ffi_api.ComputeOp( - compute.op.name, compute.op.tag, compute.op.attrs, compute.op.axis, body - ) - - num = op_node.num_outputs - outputs = tuple(op_node.output(i) for i in range(num)) - return outputs[0] if num == 1 else outputs diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index f2822efc3e75d..9cd7b43d8067b 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -31,6 +31,7 @@ from tvm.runtime import convert_to_object from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor from tvm.tir import expr as _expr +from . import _ffi_api from .compute_dag import ComputeDAG from .dispatcher import DispatchContext from .search_task import SearchTask @@ -166,13 +167,15 @@ def add_workload_key(self, workload_key, ccache_key): @tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite") -def _enter_layout_rewrite(): +def enter_layout_rewrite(): + """Enter layout rewrite tracing environment""" env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE) env.__enter__() @tvm._ffi.register_func("auto_scheduler.exit_layout_rewrite") -def _exit_layout_rewrite(): +def exit_layout_rewrite(): + """Exit layout rewrite tracing environment""" env = TracingEnvironment.current env.__exit__(None, None, None) @@ -317,5 +320,25 @@ def remove_index_check(tensor): tensor: Tensor The tensor to remove index check. """ - # monkey patch the indexing function + # Monkey patch the indexing function tensor.__call__ = tensor_no_check_call.__get__(tensor, Tensor) + + +def rewrite_compute_body(compute_tensor, new_layout): + """Rewrite the body of a ComputeOp according to a new layout of a placeholder""" + op = compute_tensor.op + + # Get layout free placeholders + layout_free_placeholders = op.attrs["layout_free_placeholders"] + assert len(layout_free_placeholders) == 1 + placeholder_op = layout_free_placeholders[0].op + + # Rewrite the index expression in body + body = [] + for b in op.body: + body.append(_ffi_api.RewriteIndexForNewLayout(placeholder_op, new_layout, b)) + op_node = tvm.te._ffi_api.ComputeOp(op.name, op.tag, op.attrs, op.axis, body) + + num = op_node.num_outputs + outputs = tuple(op_node.output(i) for i in range(num)) + return outputs[0] if num == 1 else outputs diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 69d5c62359cbd..492b62b3e21d7 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -380,7 +380,7 @@ def conv2d_nhwc( dilation_h, dilation_w = dilation if auto_scheduler_rewritten_layout: - # infer shape for the rewritten layout + # Infer shape for the rewritten layout if len(Filter.shape) >= 10: # For cpu tile structure SSRSRS base = len(Filter.shape) - 10 @@ -432,9 +432,7 @@ def conv2d_nhwc( ) if auto_scheduler_rewritten_layout: - Output = auto_scheduler.rewrite_compute_body( - Output, Filter, auto_scheduler_rewritten_layout - ) + Output = auto_scheduler.rewrite_compute_body(Output, auto_scheduler_rewritten_layout) return Output