Skip to content

Commit

Permalink
[TVMScript] Add for loop syntax sugar (apache#9620)
Browse files Browse the repository at this point in the history
* add for loop syntax sugar

* remove prints

* better doc

* finish thread binding

* fix CI

* fix CI

* address comments

* update sstub

* fix CI

* remove failed test

* update stub

* address comments

* add decorator
  • Loading branch information
Yuanjing Shi authored and ylc committed Jan 7, 2022
1 parent c0d7139 commit 037e94d
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 45 deletions.
40 changes: 38 additions & 2 deletions python/tvm/script/tir/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -322,36 +322,72 @@ def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr:
"""
Scope handler - Loops
"""

@overload
def serial(
begin: Union[PrimExpr, int],
end: Union[PrimExpr, int],
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
@overload
def serial(
end: Union[PrimExpr, int],
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
@overload
def parallel(
begin: Union[PrimExpr, int],
end: Union[PrimExpr, int],
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
@overload
def parallel(
end: Union[PrimExpr, int],
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
@overload
def vectorized(
begin: Union[PrimExpr, int],
end: Union[PrimExpr, int],
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
@overload
def vectorized(
end: Union[PrimExpr, int],
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
@overload
def unroll(
begin: Union[PrimExpr, int],
end: Union[PrimExpr, int],
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
@overload
def unroll(
end: Union[PrimExpr, int],
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
@overload
def thread_binding(
begin: Union[PrimExpr, int],
end: Union[PrimExpr, int],
thread: str,
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
@overload
def thread_binding(
end: Union[PrimExpr, int],
thread: str,
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
@overload
def for_range(
begin: Union[PrimExpr, int],
end: Union[PrimExpr, int] = None,
end: Union[PrimExpr, int],
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
@overload
def for_range(
end: Union[PrimExpr, int],
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
def grid(*extents: Union[PrimExpr, int]) -> Iterable[Sequence[IterVar]]: ...
Expand Down
33 changes: 27 additions & 6 deletions python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,9 +500,12 @@ class Serial(ForScopeHandler):
def __init__(self):
def serial(
begin: PrimExpr,
end: PrimExpr,
end: PrimExpr = None,
annotations: Optional[Mapping[str, Object]] = None,
):
if end is None:
end = begin
begin = 0
self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations)

super().__init__(serial)
Expand All @@ -515,9 +518,12 @@ class Parallel(ForScopeHandler):
def __init__(self):
def parallel(
begin: PrimExpr,
end: PrimExpr,
end: PrimExpr = None,
annotations: Optional[Mapping[str, Object]] = None,
):
if end is None:
end = begin
begin = 0
self.create_loop_info(begin, end, ForKind.PARALLEL, annotations=annotations)

super().__init__(parallel)
Expand All @@ -530,9 +536,12 @@ class Vectorized(ForScopeHandler):
def __init__(self):
def vectorized(
begin: PrimExpr,
end: PrimExpr,
end: PrimExpr = None,
annotations: Optional[Mapping[str, Object]] = None,
):
if end is None:
end = begin
begin = 0
self.create_loop_info(begin, end, ForKind.VECTORIZED, annotations=annotations)

super().__init__(vectorized)
Expand All @@ -545,9 +554,12 @@ class Unroll(ForScopeHandler):
def __init__(self):
def unroll(
begin: PrimExpr,
end: PrimExpr,
end: PrimExpr = None,
annotations: Optional[Mapping[str, Object]] = None,
):
if end is None:
end = begin
begin = 0
self.create_loop_info(begin, end, ForKind.UNROLLED, annotations=annotations)

super().__init__(unroll)
Expand All @@ -560,10 +572,19 @@ class ThreadBinding(ForScopeHandler):
def __init__(self):
def thread_binding(
begin: PrimExpr,
end: PrimExpr,
thread: str,
end: PrimExpr = None,
thread: str = None,
annotations: Optional[Mapping[str, Object]] = None,
):
if thread is None:
if isinstance(end, str): # handle case like thread_binding(128, "threadIdx.x")
thread = end
end = None
else:
raise ValueError("Thread cannot be None for thread_binding")
if end is None:
end = begin
begin = 0
thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread)
self.create_loop_info(
begin,
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@
from .popen_pool import call_py_ffi, call_cpp_py_ffi, fast_summation, slow_summation
from .popen_pool import timeout_job

from .tir import check_error

from . import auto_scheduler
from . import autotvm
48 changes: 48 additions & 0 deletions python/tvm/testing/tir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, import-outside-toplevel, unused-variable
"""Common utility functions in TVM tir"""
import inspect
import tvm
from tvm.ir.diagnostics import override_renderer


def check_error(func, rel_lineno):
"""check if TIR script throws error"""
# Override the default renderer to accumulate errors
errors = []

def render(e):
for d in e.diagnostics:
errors.append(d)

override_renderer(render)
# The diagnostic context throws an exception when it gets an error
try:
source_code = inspect.getsource(func)
source_code = "@T.prim_func\n" + source_code
from tvm.script import from_source

# to avoid cyclic import
from_source(source_code)
except tvm.error.DiagnosticError as e:
pass
assert len(errors) == 1, errors
for d in errors:
assert (
d.span.line - 1 == rel_lineno
), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}"
38 changes: 1 addition & 37 deletions tests/python/unittest/test_tvmscript_error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import sys
import tvm
from tvm import tir
from tvm.testing import check_error
from tvm.script import tir as T
from tvm.ir.diagnostics import override_renderer
import inspect
Expand All @@ -32,20 +33,6 @@ def test_buffer_bind():
check_error(buffer_bind_missing_args, 2)


def range_missing_args(a: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")

T.attr(A, "realize_scope", "")
T.realize(A[0:16, 0:16], "")
for i in T.serial(16): # error
for j in T.serial(0, 16):
A[i, j] = 0.0


def test_range_missing_args():
check_error(range_missing_args, 6)


def undefined_buffer(a: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")

Expand Down Expand Up @@ -509,29 +496,6 @@ def test_implicit_root_has_attrs():
check_error(implicit_root_has_axes, 2)


def check_error(func, rel_lineno):
# Override the default renderer to accumulate errors
errors = []

def render(e):
for d in e.diagnostics:
errors.append(d)

override_renderer(render)
# The diagnostic context throws an exception when it gets an error
try:
source_code = inspect.getsource(func)
source_code = "@T.prim_func\n" + source_code
tvm.script.from_source(source_code)
except tvm.error.DiagnosticError as e:
pass
assert len(errors) == 1, errors
for d in errors:
assert (
d.span.line - 1 == rel_lineno
), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}"


@T.prim_func
def elementwise_not_affine(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128, 128, 128))
Expand Down
39 changes: 39 additions & 0 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest
from tvm.ir import assert_structural_equal
from tvm.script import tir as T
from tvm.testing import check_error


@T.prim_func
Expand Down Expand Up @@ -62,5 +63,43 @@ def test_reads_writes_syntax_sugar():
assert_structural_equal(transformed_matmul_no_syntax_sugar, transformed_matmul_syntax_sugar)


@T.prim_func
def loop_no_syntax_sugar(a: T.handle) -> None:
A = T.match_buffer(a, (128, 128, 128, 128))
for i in T.serial(0, 128):
for j in T.parallel(0, 128):
for k in T.vectorized(0, 128):
for x in T.unroll(0, 128):
for y in T.thread_binding(0, 128, thread="threadIdx.x"):
for z in T.thread_binding(0, 128, thread="threadIdx.x"):
A[i, j, k, x] = A[i, j, k, x] * 2.0


@T.prim_func
def loop_syntax_sugar(a: T.handle) -> None:
A = T.match_buffer(a, (128, 128, 128, 128))
for i in T.serial(128):
for j in T.parallel(128):
for k in T.vectorized(128):
for x in T.unroll(128):
for y in T.thread_binding(128, "threadIdx.x"):
for z in T.thread_binding(128, thread="threadIdx.x"):
A[i, j, k, x] = A[i, j, k, x] * 2.0


def loop_syntax_sugar_fail(a: T.handle) -> None:
A = T.match_buffer(a, (128,))
for i in T.thread_binding(128, 128):
A[i] = A[i] * 2.0


def test_loop_syntax_sugar():
assert_structural_equal(loop_no_syntax_sugar, loop_syntax_sugar)


def test_syntax_sugar_fail():
check_error(loop_syntax_sugar_fail, 3)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 037e94d

Please sign in to comment.