Skip to content

Commit

Permalink
remove all hacks
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Nov 29, 2020
1 parent 67cec0d commit c741f97
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 22 deletions.
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
LocalRPCMeasureContext,
)
from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
from .relay_integration import extract_tasks
from .relay_integration import extract_tasks, remove_index_check
from .search_task import SearchTask
from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
from .task_scheduler import TaskScheduler
Expand Down
41 changes: 38 additions & 3 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@

import tvm
from tvm import autotvm, te, transform
from tvm.te.tensor import ComputeOp, PlaceholderOp
from tvm.runtime import convert_to_object
from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
from tvm.tir import expr as _expr
from .compute_dag import ComputeDAG
from .dispatcher import DispatchContext
from .search_task import SearchTask
Expand Down Expand Up @@ -164,13 +166,13 @@ def add_workload_key(self, workload_key, ccache_key):


@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
def _enter_layout_rewrite():
    """Enter layout-rewrite tracing mode; invoked from the C++ side via FFI."""
    TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE).__enter__()


@tvm._ffi.register_func("auto_scheduler.exit_layout_rewrite")
def _exit_layout_rewrite():
    """Leave layout-rewrite tracing mode; invoked from the C++ side via FFI."""
    TracingEnvironment.current.__exit__(None, None, None)

Expand Down Expand Up @@ -284,3 +286,36 @@ def auto_schedule_topi(outs, has_complex_op):
raise ValueError("Invalid tracing mode: " + env.tracing_mode)

return schedule


def tensor_no_check_call(self, *indices):
    """Index a tensor without the arity safety check.

    Identical to ``tvm.te.Tensor.__call__`` except that the check on the
    number of indices is omitted, so a temporarily malformed load can be
    created and fixed up later (see ``remove_index_check``).
    """
    converted = convert_to_object(indices)
    index_exprs = []
    for idx in converted:
        if isinstance(idx, _expr.PrimExpr):
            index_exprs.append(idx)
        elif isinstance(idx, _expr.IterVar):
            # IterVars contribute their underlying loop variable.
            index_exprs.append(idx.var)
        else:
            raise ValueError("The indices must be expression")

    return _expr.ProducerLoad(self, index_exprs)


def remove_index_check(tensor):
    """Remove the safety check in the indexing function for a tensor.
    This is done by monkey patching its indexing function.
    After removing the check, we are allowed to create a
    temporary wrong IR and fix it later in other places.

    Note that Python looks up the call syntax ``tensor(...)`` on
    ``type(tensor)``, not on the instance, so this instance-level patch
    only takes effect for explicit ``tensor.__call__(...)`` invocations.

    Parameters
    ----------
    tensor: Tensor
        The tensor to remove index check.
    """
    # monkey patch the indexing function: bind ``tensor_no_check_call`` as a
    # method on this instance only; other Tensor objects keep the checked
    # ``Tensor.__call__``.
    tensor.__call__ = tensor_no_check_call.__get__(tensor, Tensor)
7 changes: 3 additions & 4 deletions python/tvm/te/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __getitem__(self, indices):

def asobject(self):
    """Convert slice to object by indexing the underlying tensor.

    The explicit ``self.tensor.__call__(...)`` form is used instead of
    ``self.tensor(...)`` so that an instance-level override of
    ``__call__`` (installed by ``auto_scheduler.remove_index_check``)
    is honored; the call syntax would look ``__call__`` up on the type
    and bypass the per-instance patch.
    """
    return self.tensor.__call__(*self.indices)

@property
def dtype(self):
Expand All @@ -59,9 +59,8 @@ class Tensor(DataProducer, _expr.ExprOp):

def __call__(self, *indices):
ndim = self.ndim
# TODO(merrymercy): tmp hack for layout rewrite
# if len(indices) != ndim:
# raise ValueError("Need to provide %d index in tensor slice" % ndim)
if len(indices) != ndim:
raise ValueError("Need to provide %d index in tensor slice" % ndim)
indices = convert_to_object(indices)
args = []
for x in indices:
Expand Down
15 changes: 2 additions & 13 deletions python/tvm/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,25 +390,14 @@ def conv2d_nhwc(
num_filter = Filter.shape[5 + base] * Filter.shape[9 + base]
for i in range(base + 2):
num_filter *= Filter.shape[i]
elif len(Filter.shape) == 6:
# For cpu tile structure SRS
num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5]
kernel_h = Filter.shape[2]
kernel_w = Filter.shape[3]
channel = Filter.shape[4]
elif len(Filter.shape) == 5:
# For cpu tile structure SRS
num_filter = Filter.shape[0] * Filter.shape[4]
kernel_h = Filter.shape[1]
kernel_w = Filter.shape[2]
channel = Filter.shape[3]
elif len(Filter.shape) == 4:
num_filter, kernel_h, kernel_w, channel = Filter.shape
else:
raise ValueError(
"Don't know how to infer the layout for filter shape: %s. "
"You can add a new branch for it to fix this." % str(Filter)
"Please add a new branch to handle this case." % str(Filter)
)
auto_scheduler.remove_index_check(Filter)
else:
kernel_h, kernel_w, channel, num_filter = Filter.shape

Expand Down
1 change: 0 additions & 1 deletion src/relay/analysis/type_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return Type(nullptr);
}

tt1 = tt2; // TODO(merrymercy): tmp hack for layout rewrite in auto-scheduler.
tvm::Array<IndexExpr> shape;
if (tt1->shape.size() != tt2->shape.size()) {
this->solver_->diag_ctx_.Emit(Diagnostic::Error(this->span)
Expand Down
7 changes: 7 additions & 0 deletions src/relay/op/nn/convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);

// If the layout is rewritten by auto-scheduler,
// we just forcibly apply the layout provided by auto-scheduler and
// skip the normal inference logic.
if (param->auto_scheduler_rewritten_layout.size() > 0) {
return false;
}

const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
if (!trans_in_layout.defined()) {
reporter->GetDiagCtx().Emit(
Expand Down

0 comments on commit c741f97

Please sign in to comment.