Skip to content

Commit

Permalink
[AutoScheduler] Check duplicated names in the compute dag (apache#6973)
Browse files Browse the repository at this point in the history
* [AutoScheduler] check duplicated names in the compute dag

* fix lint

* fix pooling

* fix pooling
  • Loading branch information
merrymercy authored and Trevor Morris committed Dec 4, 2020
1 parent 1914e04 commit 23de2d3
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 9 deletions.
20 changes: 11 additions & 9 deletions include/tvm/topi/nn/pooling.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ inline Tensor pool_impl(const Tensor& x, const Array<PrimExpr>& kernel_size,
auto out_width =
analyzer.Simplify(indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1);

auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));
auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh");
auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");

Array<PrimExpr> out_shape = x->shape;
for (size_t i = 0; i < out_shape.size(); ++i) {
Expand Down Expand Up @@ -220,8 +220,8 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
auto out_width =
analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);

auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));
auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh");
auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");

Array<PrimExpr> data_shape = x->shape;
for (size_t i = 0; i < data_shape.size(); ++i) {
Expand All @@ -245,8 +245,9 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);

auto windowh =
tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
auto windoww =
tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");

auto argmax = MakeArgmaxReducer();
auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
Expand Down Expand Up @@ -293,8 +294,9 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
"T_pool_grad", "pool_grad_max");
} else if (pool_type == kAvgPool) {
auto windowh =
tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
auto windoww =
tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
return tvm::te::compute(
data_shape,
[&](const Array<Var>& inds) {
Expand Down Expand Up @@ -696,7 +698,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
pad_tail[i] += stride[i] - 1;
}

daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i])));
daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i]), "rv" + std::to_string(i)));

pad_before.Set(ii, pad_head[i]);
pad_after.Set(ii, pad_tail[i]);
Expand Down
22 changes: 22 additions & 0 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,22 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
int cur_type_code_;
};

void CheckComputeValidity(const te::Schedule& sch) {
// Check the validity of a compute definition:
// The name of each iterator should be unique.
for (auto stage : sch->stages) {
if (stage->op->IsInstance<te::ComputeOpNode>()) {
std::unordered_set<std::string> names;
for (const auto& x : stage->leaf_iter_vars) {
ICHECK(!names.count(x->var->name_hint))
<< "Find duplicated iterator names in the compute definition: " << x->var->name_hint
<< ". Please use different names for different iterators.";
names.insert(x->var->name_hint);
}
}
}
}

ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
auto node = make_object<ComputeDAGNode>();
node->tensors = std::move(tensors);
Expand All @@ -674,6 +690,9 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
node->ops.push_back(stage->op);
}

// Make sure it is a valid compute definition
CheckComputeValidity(sch);

node->flop_ct = FlopEstimator().EstimateFlop(node->ops);
node->init_state = State(node->ops);
data_ = std::move(node);
Expand All @@ -682,6 +701,9 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
ComputeDAG::ComputeDAG(const te::Schedule& sch) {
auto node = make_object<ComputeDAGNode>();

// Make sure it is a valid compute definition
CheckComputeValidity(sch);

// Initialize ops. Here we enforce the order of ops and stages are consistent
for (auto stage : sch->stages) {
node->ops.push_back(stage->op);
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_auto_scheduler_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,17 @@ def softmax_abcd_auto_scheduler_test(a, b, c, d):
return [A, B]


@auto_scheduler.register_workload
def invalid_compute_definition():
A = te.placeholder((10, 10), name="A")
# The names of the following two iterators are the same.
# This is invalid.
r1 = te.reduce_axis((0, 2), name="r1")
r2 = te.reduce_axis((0, 2), name="r1")
B = te.compute((10,), lambda i: te.sum(A[i][r1 + r2], axis=[r1, r2]), name="B")
return [A, B]


@auto_scheduler.register_workload
def conv2d_winograd_nhwc_auto_scheduler_test(
N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1
Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_auto_scheduler_compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from test_auto_scheduler_common import (
get_tiled_matmul,
invalid_compute_definition,
matmul_auto_scheduler_test,
parallel_matmul_auto_scheduler_test,
)
Expand Down Expand Up @@ -137,8 +138,20 @@ def test_stage_order():
assert task.hardware_params.cache_line_bytes == task2.hardware_params.cache_line_bytes


def test_invalid_compute_dag():
failed = False
try:
A, B = invalid_compute_definition()
dag = auto_scheduler.ComputeDAG([A, B])
except tvm.TVMError as e:
failed = True

assert failed


if __name__ == "__main__":
test_apply_steps()
test_infer_bound()
test_estimate_flop()
test_stage_order()
test_invalid_compute_dag()

0 comments on commit 23de2d3

Please sign in to comment.