Skip to content

Commit

Permalink
add support for prevously uncovered cases
Browse files Browse the repository at this point in the history
  • Loading branch information
shingjan committed Nov 12, 2021
1 parent 86781e9 commit df75b06
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 6 deletions.
23 changes: 17 additions & 6 deletions python/tvm/script/tir/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,15 @@ class IterVar(Var): ...

class Buffer:
@overload
def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]]) -> PrimExpr: ...
def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]]) -> PrimExpr: ...
@overload
def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ...
def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ...
@overload
def __setitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]], value: PrimExpr) -> None: ...
def __setitem__(
self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]], value: PrimExpr
) -> None: ...
@overload
def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ...
def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ...
@property
def data(self: Buffer) -> Ptr: ...

Expand Down Expand Up @@ -124,12 +126,21 @@ def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
def store(
var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
) -> None: ...
def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ...

"""
Intrinsics - tvm builtin
"""

def tvm_thread_allreduce(
*freduceargs: Union[PrimExpr, builtins.bool, Ptr], dtype: str
) -> PrimExpr: ...

"""
Unary operator
"""

def exp(x: PrimExpr) -> PrimExpr: ...
def exp2(x: PrimExpr) -> PrimExpr: ...
def exp10(x: PrimExpr) -> PrimExpr: ...
def erf(x: PrimExpr) -> PrimExpr: ...
Expand Down Expand Up @@ -334,7 +345,7 @@ def for_range(
end: Union[PrimExpr, int] = None,
annotations: Optional[Mapping[str, Object]] = None,
) -> Iterable[IterVar]: ...
def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
def grid(*extents: Union[PrimExpr, int]) -> Iterable[Sequence[IterVar]]: ...

"""
ty - redefine types
Expand Down
95 changes: 95 additions & 0 deletions tests/python/unittest/test_tvmscript_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement
from tvm.ir.expr import PrimExpr
from tvm.script import tir as T

"""
Expand Down Expand Up @@ -81,6 +82,100 @@ def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
)


"""
This test case is added to test T.grid
"""


@T.prim_func
def loop_split(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [128, 128], dtype="float32")
B = T.match_buffer(b, [128], dtype="float32")
for i, ko in T.grid(128, 4):
for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
with T.block("B"):
vi = T.axis.S(128, i)
vk = T.axis.R(128, ko * 32 + ki)
T.reads([B[vi], A[vi, vk]])
T.writes([B[vi]])
with T.init():
B[vi] = T.float32(0)
B[vi] = B[vi] + A[vi, vk]


"""
This test case is added to test T.comm_reducer, T.reinterpret, T.tvm_thread_allreduce
"""


@T.prim_func
def lowered_loop_split(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [128, 128], dtype="float32")
B = T.match_buffer(b, [128], dtype="float32")
reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
for i in T.serial(0, 128):
for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
normal_reduce_temp0[0] = T.float32(0)
for ko in T.serial(0, 4):
with T.block("B_normal_reduction"):
vi = T.axis.S(128, i)
vk = T.axis.R(128, ko * 32 + ki)
T.reads([A[vi, vk], normal_reduce_temp0[0]])
T.writes([normal_reduce_temp0[0]])
normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
with T.block("B_cross_thread_reduction"):
T.reads([normal_reduce_temp0[0]])
T.writes([reduce_temp0[0]])
T.attr(
T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
)
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
normal_reduce_temp0[0],
True,
reduce_temp0.data,
ki,
dtype="handle",
)
)
with T.block("B_write_back"):
vi = T.axis.S(128, i)
T.reads([reduce_temp0[0]])
T.writes([B[vi]])
B[vi] = reduce_temp0[0]


"""
This test case is added to test T.Buffer with slice as argument
"""


@T.prim_func
def different_access_indices(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [128, 128, 128], dtype="float32")
B = T.match_buffer(b, [128, 128], dtype="float32")
for i, j in T.grid(128, 128):
for k in T.thread_binding(0, 128, thread="threadIdx.x"):
with T.block("B"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
T.reads([B[vi, vj], A[vi, vj, vk]])
T.writes(
[
B[
T.min(vj, vi) : T.min(vj, vi) + (T.max(vj, vi) + 1 - T.min(vj, vi)),
T.min(vi, vj) : T.min(vi, vj) + (T.max(vi, vj) + 1 - T.min(vi, vj)),
]
]
)
with T.init():
B[vj, vi] = T.float32(0)
B[vi, vj] = B[vi, vj] + A[vi, vj, vk]


# Not running any test as we only want to type-check here
if __name__ == "__main__":
pass

0 comments on commit df75b06

Please sign in to comment.