Skip to content

Commit

Permalink
update rewrite_compute_body
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Nov 29, 2020
1 parent c741f97 commit ebabcb4
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 24 deletions.
4 changes: 2 additions & 2 deletions python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

# Shortcut
from .auto_schedule import TuningOptions, HardwareParams, create_task, auto_schedule
from .compute_dag import ComputeDAG, rewrite_compute_body
from .compute_dag import ComputeDAG
from .cost_model import RandomModel, XGBModel
from .dispatcher import DispatchContext, ApplyHistoryBest
from .measure import (
Expand All @@ -44,7 +44,7 @@
LocalRPCMeasureContext,
)
from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
from .relay_integration import extract_tasks, remove_index_check
from .relay_integration import extract_tasks, remove_index_check, rewrite_compute_body
from .search_task import SearchTask
from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
from .task_scheduler import TaskScheduler
Expand Down
17 changes: 2 additions & 15 deletions python/tvm/auto_scheduler/compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def rewrite_layout_from_state(self, state):
Returns
-------
updated_state : StateObject
updated_dag : ComputeDAG
The compute dag with rewritten layout.
"""
state_obj = state if isinstance(state, StateObject) else state.state_object
return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state_obj)
Expand Down Expand Up @@ -226,17 +227,3 @@ def __setstate__(self, state):
self.compute = LoadJSON(state["compute"]) # pylint: disable=assignment-from-no-return
self.sche = LoadJSON(state["sche"]) # pylint: disable=assignment-from-no-return
self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, self.compute, self.sche)


def rewrite_compute_body(compute, placeholder, new_layout):
"""Rewrite the body of a ComputeOp according to a new layout of a placeholder"""
body = []
for b in compute.op.body:
body.append(_ffi_api.RewriteIndexForNewLayout(placeholder.op, new_layout, b))
op_node = tvm.te._ffi_api.ComputeOp(
compute.op.name, compute.op.tag, compute.op.attrs, compute.op.axis, body
)

num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
29 changes: 26 additions & 3 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from tvm.runtime import convert_to_object
from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
from tvm.tir import expr as _expr
from . import _ffi_api
from .compute_dag import ComputeDAG
from .dispatcher import DispatchContext
from .search_task import SearchTask
Expand Down Expand Up @@ -166,13 +167,15 @@ def add_workload_key(self, workload_key, ccache_key):


@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
def _enter_layout_rewrite():
def enter_layout_rewrite():
"""Enter layout rewrite tracing environment"""
env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE)
env.__enter__()


@tvm._ffi.register_func("auto_scheduler.exit_layout_rewrite")
def _exit_layout_rewrite():
def exit_layout_rewrite():
"""Exit layout rewrite tracing environment"""
env = TracingEnvironment.current
env.__exit__(None, None, None)

Expand Down Expand Up @@ -317,5 +320,25 @@ def remove_index_check(tensor):
tensor: Tensor
The tensor to remove index check.
"""
# monkey patch the indexing function
# Monkey patch the indexing function
tensor.__call__ = tensor_no_check_call.__get__(tensor, Tensor)


def rewrite_compute_body(compute_tensor, new_layout):
"""Rewrite the body of a ComputeOp according to a new layout of a placeholder"""
op = compute_tensor.op

# Get layout free placeholders
layout_free_placeholders = op.attrs["layout_free_placeholders"]
assert len(layout_free_placeholders) == 1
placeholder_op = layout_free_placeholders[0].op

# Rewrite the index expression in body
body = []
for b in op.body:
body.append(_ffi_api.RewriteIndexForNewLayout(placeholder_op, new_layout, b))
op_node = tvm.te._ffi_api.ComputeOp(op.name, op.tag, op.attrs, op.axis, body)

num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
6 changes: 2 additions & 4 deletions python/tvm/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def conv2d_nhwc(
dilation_h, dilation_w = dilation

if auto_scheduler_rewritten_layout:
# infer shape for the rewritten layout
# Infer shape for the rewritten layout
if len(Filter.shape) >= 10:
# For cpu tile structure SSRSRS
base = len(Filter.shape) - 10
Expand Down Expand Up @@ -432,9 +432,7 @@ def conv2d_nhwc(
)

if auto_scheduler_rewritten_layout:
Output = auto_scheduler.rewrite_compute_body(
Output, Filter, auto_scheduler_rewritten_layout
)
Output = auto_scheduler.rewrite_compute_body(Output, auto_scheduler_rewritten_layout)

return Output

Expand Down

0 comments on commit ebabcb4

Please sign in to comment.