diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h
index 882793877ed6..8c30e673b304 100644
--- a/include/tvm/topi/nn/pooling.h
+++ b/include/tvm/topi/nn/pooling.h
@@ -103,8 +103,8 @@ inline Tensor pool_impl(const Tensor& x, const Array<PrimExpr>& kernel_size,
   auto out_width =
       analyzer.Simplify(indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1);
 
-  auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
-  auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));
+  auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh");
+  auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");
 
   Array<PrimExpr> out_shape = x->shape;
   for (size_t i = 0; i < out_shape.size(); ++i) {
@@ -220,8 +220,8 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
   auto out_width =
       analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);
 
-  auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
-  auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));
+  auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh");
+  auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");
 
   Array<PrimExpr> data_shape = x->shape;
   for (size_t i = 0; i < data_shape.size(); ++i) {
@@ -245,8 +245,9 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
    ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);
 
    auto windowh =
-        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
-    auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
+        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
+    auto windoww =
+        tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
 
    auto argmax = MakeArgmaxReducer();
    auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
@@ -293,8 +294,9 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
        "T_pool_grad", "pool_grad_max");
  } else if (pool_type == kAvgPool) {
    auto windowh =
-        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
-    auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
+        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
+    auto windoww =
+        tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
    return tvm::te::compute(
        data_shape,
        [&](const Array<Var>& inds) {
@@ -696,7 +698,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
      pad_tail[i] += stride[i] - 1;
    }
 
-    daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i])));
+    daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i]), "rv" + std::to_string(i)));
 
    pad_before.Set(ii, pad_head[i]);
    pad_after.Set(ii, pad_tail[i]);
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index 090e6daf9859..27a30127ba65 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -658,6 +658,22 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
  int cur_type_code_;
 };
 
+void CheckComputeValidity(const te::Schedule& sch) {
+  // Check the validity of a compute definition:
+  // The name of each iterator should be unique.
+  for (auto stage : sch->stages) {
+    if (stage->op->IsInstance<te::ComputeOpNode>()) {
+      std::unordered_set<std::string> names;
+      for (const auto& x : stage->leaf_iter_vars) {
+        ICHECK(!names.count(x->var->name_hint))
+            << "Find duplicated iterator names in the compute definition: " << x->var->name_hint
+            << ". Please use different names for different iterators.";
+        names.insert(x->var->name_hint);
+      }
+    }
+  }
+}
+
 ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
  auto node = make_object<ComputeDAGNode>();
  node->tensors = std::move(tensors);
@@ -674,6 +690,9 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
    node->ops.push_back(stage->op);
  }
 
+  // Make sure it is a valid compute definition
+  CheckComputeValidity(sch);
+
  node->flop_ct = FlopEstimator().EstimateFlop(node->ops);
  node->init_state = State(node->ops);
  data_ = std::move(node);
@@ -682,6 +701,9 @@
 ComputeDAG::ComputeDAG(const te::Schedule& sch) {
  auto node = make_object<ComputeDAGNode>();
 
+  // Make sure it is a valid compute definition
+  CheckComputeValidity(sch);
+
  // Initialize ops. Here we enforce the order of ops and stages are consistent
  for (auto stage : sch->stages) {
    node->ops.push_back(stage->op);
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py
index 5b7add9733de..87814f28ad72 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -137,6 +137,17 @@ def softmax_abcd_auto_scheduler_test(a, b, c, d):
    return [A, B]
 
 
+@auto_scheduler.register_workload
+def invalid_compute_definition():
+    A = te.placeholder((10, 10), name="A")
+    # The names of the following two iterators are the same.
+    # This is invalid.
+    r1 = te.reduce_axis((0, 2), name="r1")
+    r2 = te.reduce_axis((0, 2), name="r1")
+    B = te.compute((10,), lambda i: te.sum(A[i][r1 + r2], axis=[r1, r2]), name="B")
+    return [A, B]
+
+
 @auto_scheduler.register_workload
 def conv2d_winograd_nhwc_auto_scheduler_test(
    N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index caf3c9d888b6..bde3b786d370 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -25,6 +25,7 @@
 
 from test_auto_scheduler_common import (
    get_tiled_matmul,
+    invalid_compute_definition,
    matmul_auto_scheduler_test,
    parallel_matmul_auto_scheduler_test,
 )
@@ -137,8 +138,20 @@ def test_stage_order():
    assert task.hardware_params.cache_line_bytes == task2.hardware_params.cache_line_bytes
 
 
+def test_invalid_compute_dag():
+    failed = False
+    try:
+        A, B = invalid_compute_definition()
+        dag = auto_scheduler.ComputeDAG([A, B])
+    except tvm.TVMError as e:
+        failed = True
+
+    assert failed
+
+
 if __name__ == "__main__":
    test_apply_steps()
    test_infer_bound()
    test_estimate_flop()
    test_stage_order()
+    test_invalid_compute_dag()
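Note (not part of the patch): a minimal sketch of how the new check surfaces to users. It mirrors the invalid_compute_definition workload above; giving the second reduce axis a distinct name (here "r2") satisfies CheckComputeValidity, whereas reusing "r1" for both axes makes ComputeDAG construction raise tvm.TVMError as exercised in test_invalid_compute_dag.

    # Illustrative sketch only; mirrors the test workload in this diff.
    from tvm import te, auto_scheduler

    A = te.placeholder((10, 10), name="A")
    r1 = te.reduce_axis((0, 2), name="r1")
    r2 = te.reduce_axis((0, 2), name="r2")  # unique name, so the validity check passes
    B = te.compute((10,), lambda i: te.sum(A[i][r1 + r2], axis=[r1, r2]), name="B")
    dag = auto_scheduler.ComputeDAG([A, B])  # would raise tvm.TVMError if both axes were named "r1"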