Skip to content

Commit

Permalink
Bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jcf94 committed Oct 24, 2020
1 parent c784fba commit 5042c6a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
if (step->IsInstance<ComputeAtStepNode>()) {
auto compute_at_step = tvm::Downcast<ComputeAtStep>(step);
if (compute_at_step->target_stage_id >= static_cast<int>(stage_id)) {
dynamic_cast<ComputeAtStepNode*>(step.CopyOnWrite())->target_stage_id++;
dynamic_cast<ComputeAtStepNode*>(compute_at_step.CopyOnWrite())->target_stage_id++;
}
transform_steps->Set(i, std::move(compute_at_step));
} else {
Expand Down Expand Up @@ -1101,7 +1101,7 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
if (layout_rewrite != LayoutRewriteOption::NoRewrite && !transform_steps.empty()) {
Array<Step> steps = transform_steps;
const auto& dag = RewriteLayout(&steps, layout_rewrite);
return dag.ApplySteps(steps, stages, stage_to_axes, LayoutRewriteOption::NoRewrite);
return dag.ApplySteps(steps);
}

// Temporal object to be used if the input pointer is nullptr
Expand Down
17 changes: 7 additions & 10 deletions tests/python/unittest/test_auto_scheduler_layout_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def test_correctness_layout_rewrite_with_placeholder():
func_ref(*args_ref)
ctx.sync()

np.testing.assert_allclose(np_args[0], np_args_ref[0])
np.testing.assert_allclose(np_args[2], np_args_ref[2])
np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy())
np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy())


def test_correctness_layout_rewrite_with_pre_transpose():
Expand All @@ -133,14 +133,9 @@ def test_correctness_layout_rewrite_with_pre_transpose():
inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
s, bufs = dag.apply_steps_from_state(inp.state,
layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPreTranspose"])
print(bufs)
print("<<<")
print(tvm.lower(s, bufs, simple_mode=True))
exit(0)

s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)
np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
np_args_ref = [np.array(x) for x in np_args]

func = tvm.build(s, bufs, target=target)
func_ref = tvm.build(s_ref, bufs_ref, target=target)
Expand All @@ -149,17 +144,19 @@ def test_correctness_layout_rewrite_with_pre_transpose():
ctx_ref = tvm.cpu()

args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref]
args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args]
ctx.sync()

func(*args)
func_ref(*args_ref)
ctx.sync()

np.testing.assert_allclose(np_args, np_args_ref)
np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy())
np.testing.assert_allclose(args[1].asnumpy(), args_ref[1].asnumpy())
np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy())


if __name__ == "__main__":
test_apply_steps_with_layout_rewrite()
# test_correctness_layout_rewrite_with_placeholder()
test_correctness_layout_rewrite_with_placeholder()
test_correctness_layout_rewrite_with_pre_transpose()

0 comments on commit 5042c6a

Please sign in to comment.