Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Add structural error printing for TensorIR #9306

Merged
merged 16 commits into from
Oct 23, 2021

Conversation

shingjan
Copy link
Contributor

@shingjan shingjan commented Oct 18, 2021

This PR intends to improve the error rendering by annotating regions of interest in TIR like synr following PR #8121.
Previously when there is an error:

The IR is:
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(a: tir.handle, b: tir.handle) -> None:
        A = tir.match_buffer(a, [128, 128, 128, 128], dtype="float32")
        B = tir.match_buffer(b, [128, 128, 128, 128], dtype="float32")
        # body
        # with tir.block("root")
        for i, j, k, l in tir.grid(128, 128, 128, 8):
            with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.bind(vl, l * 16)
                tir.reads([A[vi, vj, vk, vl]])
                tir.writes([B[vi, vj, vk, vl]])
                B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * tir.float32(2)
    
Regions of interest:
tir.Block#0
block B(iter_var(vi, range(min=0, ext=128)), iter_var(vj, range(min=0, ext=128)), iter_var(vk, range(min=0, ext=128)), iter_var(vl, range(min=0, ext=128))) {
  reads([A[vi, vj, vk, vl]])
  writes([B[vi, vj, vk, vl]])
  B[vi, vj, vk, vl] = (A[vi, vj, vk, vl]*2f)
}

With this PR, the error will be rendered like below:

The IR with diagnostic is:
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(a: tir.handle, b: tir.handle) -> None:
        A = tir.match_buffer(a, [128, 128, 128, 128], dtype="float32")
        B = tir.match_buffer(b, [128, 128, 128, 128], dtype="float32")
        # body
        # with tir.block("root")
        for i, j, k, l in tir.grid(128, 128, 128, 8):
            tir.Block#0
            with tir.block("B"):
            ^^^^^^^^^^^^^^^^^^^^
                vi, vj, vk = tir.axis.remap("SSS", [i, j, k])
                vl = tir.axis.spatial(128, l * 16)
                tir.reads([A[vi, vj, vk, vl]])
                tir.writes([B[vi, vj, vk, vl]])
                B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * tir.float32(2)

cc: @vinx13 @junrushao1994

Copy link
Member

@vinx13 vinx13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. I left some comments

src/tir/schedule/error.cc Outdated Show resolved Hide resolved
src/tir/schedule/error.cc Outdated Show resolved Hide resolved
@junrushao
Copy link
Member

Hey thanks for the PR! It's a pretty nice POC, and I would like to deliberate on the format :-)

  • Shall we update the syntax in the description to reflect the latest change on mainline? Also please add a few unittests
  • If the error location is at a particular loop, let's avoid fusing it with other loops with tir.grid
  • Also we should consider printing the names of the error locations under the annotations, so that the error message could be clearly referring to these blocks/loops

@junrushao
Copy link
Member

junrushao commented Oct 18, 2021

Hmm I just went through the code, but figured that the snippet in the description isn't the actual output of the PR (Sorry I was wrong)? Shall we update the description with a real-world example?

We have plenty of these errors in tests/python/unittest/test_tir_schedule_*.py, guarded by:

with pytest.raises(tvm.tir.ScheduleError, ...):`
  ...

src/printer/tvmscript_printer.cc Outdated Show resolved Hide resolved
src/printer/tvmscript_printer.cc Outdated Show resolved Hide resolved
Copy link
Member

@vinx13 vinx13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left some comments

@Hzfengsy
Copy link
Member

Looks good to me. One thing I'm not sure is that will it compatible with the new block syntax.

src/tir/schedule/error.cc Outdated Show resolved Hide resolved
src/printer/tvmscript_printer.cc Outdated Show resolved Hide resolved
@tqchen
Copy link
Member

tqchen commented Oct 19, 2021

@shingjan you may need to update the test cases and code now that the new block syntax is introduced

@shingjan shingjan force-pushed the tir_structural_error branch 2 times, most recently from 5191293 to 6e0e9c4 Compare October 21, 2021 23:00
@shingjan shingjan requested a review from vinx13 October 21, 2021 23:02
@shingjan
Copy link
Contributor Author

Printing nested loop is fixed as well as some comments addressed. Should be good for another good before we merge this in. @vinx13 @junrushao1994

@vinx13 vinx13 self-assigned this Oct 21, 2021
@shingjan shingjan requested a review from vinx13 October 22, 2021 00:25
@shingjan shingjan requested a review from vinx13 October 22, 2021 00:43
src/printer/tvmscript_printer.cc Outdated Show resolved Hide resolved
src/printer/tvmscript_printer.cc Outdated Show resolved Hide resolved
src/printer/tvmscript_printer.cc Outdated Show resolved Hide resolved
@shingjan shingjan requested a review from vinx13 October 22, 2021 01:02
Copy link
Member

@vinx13 vinx13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

otherwise LGTM

tests/python/unittest/test_tvmscript_error_report.py Outdated Show resolved Hide resolved
src/tir/schedule/error.cc Outdated Show resolved Hide resolved
src/printer/tvmscript_printer.cc Outdated Show resolved Hide resolved
Copy link
Member

@vinx13 vinx13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, @Hzfengsy would you like to take a second look?

Copy link
Member

@Hzfengsy Hzfengsy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Hzfengsy Hzfengsy merged commit e9a66a1 into apache:main Oct 23, 2021
@shingjan shingjan deleted the tir_structural_error branch October 26, 2021 00:37
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 7, 2022
* add structural error printing

* remove old code

* address comments

* address comments

* add test

* fix test case

* fix nested loop

* rm print

* change simple loop cond

* address comments

* fix test

* address comments

* remove msg

* add override

* address comments

* address comments
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 13, 2022
* add structural error printing

* remove old code

* address comments

* address comments

* add test

* fix test case

* fix nested loop

* rm print

* change simple loop cond

* address comments

* fix test

* address comments

* remove msg

* add override

* address comments

* address comments
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants