diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index a62aac3090dbf..3e916f9e82a5c 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -961,7 +961,7 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, if (step->IsInstance()) { auto compute_at_step = tvm::Downcast(step); if (compute_at_step->target_stage_id >= static_cast(stage_id)) { - dynamic_cast(step.CopyOnWrite())->target_stage_id++; + dynamic_cast(compute_at_step.CopyOnWrite())->target_stage_id++; } transform_steps->Set(i, std::move(compute_at_step)); } else { @@ -1101,7 +1101,7 @@ std::pair> ComputeDAG::ApplySteps( if (layout_rewrite != LayoutRewriteOption::NoRewrite && !transform_steps.empty()) { Array steps = transform_steps; const auto& dag = RewriteLayout(&steps, layout_rewrite); - return dag.ApplySteps(steps, stages, stage_to_axes, LayoutRewriteOption::NoRewrite); + return dag.ApplySteps(steps); } // Temporal object to be used if the input pointer is nullptr diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py index d51a0f55e5e66..356743861bef2 100644 --- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py +++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py @@ -107,8 +107,8 @@ def test_correctness_layout_rewrite_with_placeholder(): func_ref(*args_ref) ctx.sync() - np.testing.assert_allclose(np_args[0], np_args_ref[0]) - np.testing.assert_allclose(np_args[2], np_args_ref[2]) + np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy()) + np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy()) def test_correctness_layout_rewrite_with_pre_transpose(): @@ -133,14 +133,9 @@ def test_correctness_layout_rewrite_with_pre_transpose(): inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target) s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPreTranspose"]) - print(bufs) - print("<<<") - print(tvm.lower(s, bufs, simple_mode=True)) - exit(0) s_ref, bufs_ref = dag.apply_steps_from_state(inp.state) np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs] - np_args_ref = [np.array(x) for x in np_args] func = tvm.build(s, bufs, target=target) func_ref = tvm.build(s_ref, bufs_ref, target=target) @@ -149,17 +144,19 @@ def test_correctness_layout_rewrite_with_pre_transpose(): ctx_ref = tvm.cpu() args = [tvm.nd.array(x, ctx=ctx) for x in np_args] - args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref] + args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args] ctx.sync() func(*args) func_ref(*args_ref) ctx.sync() - np.testing.assert_allclose(np_args, np_args_ref) + np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy()) + np.testing.assert_allclose(args[1].asnumpy(), args_ref[1].asnumpy()) + np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy()) if __name__ == "__main__": test_apply_steps_with_layout_rewrite() - # test_correctness_layout_rewrite_with_placeholder() + test_correctness_layout_rewrite_with_placeholder() test_correctness_layout_rewrite_with_pre_transpose()