Skip to content

Commit

Permalink
[AutoScheduler] Bug fix for layout rewrite CI error in i386 (#6830)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcf94 authored Nov 4, 2020
1 parent 47d9415 commit b8761ed
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions tests/python/unittest/test_auto_scheduler_layout_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest

import tvm
import tvm.testing
from tvm import topi
from tvm import auto_scheduler, te

Expand All @@ -48,7 +49,6 @@ def test_apply_steps_with_layout_rewrite():
assert bufs[1].shape[1] == 512


@pytest.mark.skip("skip due to flaky")
@tvm.testing.requires_llvm
def test_correctness_layout_rewrite_rewrite_for_preTransformed():
N = 128
Expand Down Expand Up @@ -114,12 +114,11 @@ def test_correctness_layout_rewrite_rewrite_for_preTransformed():
func_ref(*args_ref)
ctx.sync()

tvm.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy(), rtol=1e-4)
tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), rtol=1e-4)
tvm.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy(), atol=1e-3, rtol=1e-3)
tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), atol=1e-3, rtol=1e-3)
del measure_ctx


@pytest.mark.skip("skip due to flaky")
@tvm.testing.requires_llvm
def test_correctness_layout_rewrite_insert_transform_stage():
N = 128
Expand Down Expand Up @@ -162,14 +161,13 @@ def test_correctness_layout_rewrite_insert_transform_stage():
func_ref(*args_ref)
ctx.sync()

tvm.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy(), rtol=1e-4)
tvm.testing.assert_allclose(args[1].asnumpy(), args_ref[1].asnumpy(), rtol=1e-4)
tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), rtol=1e-4)
tvm.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy(), atol=1e-3, rtol=1e-3)
tvm.testing.assert_allclose(args[1].asnumpy(), args_ref[1].asnumpy(), atol=1e-3, rtol=1e-3)
tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), atol=1e-3, rtol=1e-3)
del measure_ctx


if __name__ == "__main__":
test_apply_steps_with_layout_rewrite()
# Disable for now due to being flaky on i386
# test_correctness_layout_rewrite_rewrite_for_preTransformed()
# test_correctness_layout_rewrite_insert_transform_stage()
test_correctness_layout_rewrite_rewrite_for_preTransformed()
test_correctness_layout_rewrite_insert_transform_stage()

0 comments on commit b8761ed

Please sign in to comment.