From a63d373a37f7a38fda8c712696675df8ba3a59df Mon Sep 17 00:00:00 2001 From: shingjan Date: Fri, 12 Nov 2021 15:04:09 -0800 Subject: [PATCH 1/7] add support for prevously uncovered cases --- python/tvm/script/tir/__init__.pyi | 23 +++-- tests/python/unittest/test_tvmscript_type.py | 95 ++++++++++++++++++++ 2 files changed, 112 insertions(+), 6 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index fba026d414f6..c1953f6a9e83 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -69,13 +69,15 @@ class IterVar(Var): ... class Buffer: @overload - def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]]) -> PrimExpr: ... + def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]]) -> PrimExpr: ... @overload - def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ... + def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ... @overload - def __setitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]], value: PrimExpr) -> None: ... + def __setitem__( + self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]], value: PrimExpr + ) -> None: ... @overload - def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ... + def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ... @property def data(self: Buffer) -> Ptr: ... @@ -124,12 +126,21 @@ def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ... def store( var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True ) -> None: ... -def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ... +def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ... + +""" +Intrinsics - tvm builtin +""" + +def tvm_thread_allreduce( + *freduceargs: Union[PrimExpr, builtins.bool, Ptr], dtype: str +) -> PrimExpr: ... """ Unary operator """ +def exp(x: PrimExpr) -> PrimExpr: ... def exp2(x: PrimExpr) -> PrimExpr: ... def exp10(x: PrimExpr) -> PrimExpr: ... def erf(x: PrimExpr) -> PrimExpr: ... @@ -334,7 +345,7 @@ def for_range( end: Union[PrimExpr, int] = None, annotations: Optional[Mapping[str, Object]] = None, ) -> Iterable[IterVar]: ... -def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ... +def grid(*extents: Union[PrimExpr, int]) -> Iterable[Sequence[IterVar]]: ... """ ty - redefine types diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 44ea04b5ed36..8a40f17575d4 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement +from tvm.ir.expr import PrimExpr from tvm.script import tir as T """ @@ -81,6 +82,100 @@ def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: ) +""" +This test case is added to test T.grid +""" + + +@T.prim_func +def loop_split(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i, ko in T.grid(128, 4): + for ki in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("B"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, ko * 32 + ki) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] + + +""" +This test case is added to test T.comm_reducer, T.reinterpret, T.tvm_thread_allreduce +""" + + +@T.prim_func +def lowered_loop_split(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i in T.serial(0, 128): + for ki in T.thread_binding(0, 32, thread="threadIdx.x"): + normal_reduce_temp0[0] = T.float32(0) + for ko in T.serial(0, 4): + with T.block("B_normal_reduction"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, ko * 32 + ki) + T.reads([A[vi, vk], normal_reduce_temp0[0]]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk] + with T.block("B_cross_thread_reduction"): + T.reads([normal_reduce_temp0[0]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + normal_reduce_temp0[0], + True, + reduce_temp0.data, + ki, + dtype="handle", + ) + ) + with T.block("B_write_back"): + vi = T.axis.S(128, i) + T.reads([reduce_temp0[0]]) + T.writes([B[vi]]) + B[vi] = reduce_temp0[0] + + +""" +This test case is added to test T.Buffer with slice as argument +""" + + +@T.prim_func +def different_access_indices(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + for i, j in T.grid(128, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads([B[vi, vj], A[vi, vj, vk]]) + T.writes( + [ + B[ + T.min(vj, vi) : T.min(vj, vi) + (T.max(vj, vi) + 1 - T.min(vj, vi)), + T.min(vi, vj) : T.min(vi, vj) + (T.max(vi, vj) + 1 - T.min(vi, vj)), + ] + ] + ) + with T.init(): + B[vj, vi] = T.float32(0) + B[vi, vj] = B[vi, vj] + A[vi, vj, vk] + + # Not running any test as we only want to type-check here if __name__ == "__main__": pass From 5a2d54fe7c1cccd0908ebade2d9294604d309a53 Mon Sep 17 00:00:00 2001 From: shingjan Date: Fri, 12 Nov 2021 15:06:56 -0800 Subject: [PATCH 2/7] remove PrimExpr import --- tests/python/unittest/test_tvmscript_type.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 8a40f17575d4..336ae3107822 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement -from tvm.ir.expr import PrimExpr from tvm.script import tir as T """ From 96740bcd41210eb6c37936f4c06a89b4490ddb2a Mon Sep 17 00:00:00 2001 From: shingjan Date: Wed, 17 Nov 2021 17:25:04 -0800 Subject: [PATCH 3/7] add exp test and mypy ignore --- tests/python/unittest/test_tvmscript_type.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 336ae3107822..7d1cd14c906f 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -149,7 +149,7 @@ def lowered_loop_split(a: T.handle, b: T.handle) -> None: """ -This test case is added to test T.Buffer with slice as argument +This test case is added to test T.Buffer with slice as argument and T.exp """ @@ -165,13 +165,13 @@ def different_access_indices(a: T.handle, b: T.handle) -> None: T.writes( [ B[ - T.min(vj, vi) : T.min(vj, vi) + (T.max(vj, vi) + 1 - T.min(vj, vi)), - T.min(vi, vj) : T.min(vi, vj) + (T.max(vi, vj) + 1 - T.min(vi, vj)), + T.min(vj, vi) : T.min(vj, vi) + (T.max(vj, vi) + 1 - T.min(vj, vi)), # type: ignore[misc] + T.min(vi, vj) : T.min(vi, vj) + (T.max(vi, vj) + 1 - T.min(vi, vj)), # type: ignore[misc] ] ] ) with T.init(): - B[vj, vi] = T.float32(0) + B[vj, vi] = T.exp(B[vi, vj]) B[vi, vj] = B[vi, vj] + A[vi, vj, vk] From 4cce43ad2c961afab24edd3fac94d553d0e228c5 Mon Sep 17 00:00:00 2001 From: shingjan Date: Wed, 17 Nov 2021 19:39:37 -0800 Subject: [PATCH 4/7] disable ling too long --- tests/python/unittest/test_tvmscript_type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 7d1cd14c906f..c21eadc82a2a 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement +# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement,line-too-long from tvm.script import tir as T """ From 763b85cfd4456ce96ab1c7dc1a01decf680963a0 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 18 Nov 2021 13:30:44 -0800 Subject: [PATCH 5/7] resolve long line --- tests/python/unittest/test_tvmscript_type.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index c21eadc82a2a..b3f7aa8c5a4a 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement,line-too-long +# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement from tvm.script import tir as T """ @@ -165,8 +165,10 @@ def different_access_indices(a: T.handle, b: T.handle) -> None: T.writes( [ B[ - T.min(vj, vi) : T.min(vj, vi) + (T.max(vj, vi) + 1 - T.min(vj, vi)), # type: ignore[misc] - T.min(vi, vj) : T.min(vi, vj) + (T.max(vi, vj) + 1 - T.min(vi, vj)), # type: ignore[misc] + T.min(vj, vi) : T.min(vj, vi) # type: ignore[misc] + + (T.max(vj, vi) + 1 - T.min(vj, vi)), # type: ignore[misc] + T.min(vi, vj) : T.min(vi, vj) # type: ignore[misc] + + (T.max(vi, vj) + 1 - T.min(vi, vj)), # type: ignore[misc] ] ] ) From 4d678863572e791d23ed54ba7080c35cde541c40 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 18 Nov 2021 13:31:50 -0800 Subject: [PATCH 6/7] nit --- tests/python/unittest/test_tvmscript_type.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index b3f7aa8c5a4a..2e9308dbf58f 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -166,9 +166,9 @@ def different_access_indices(a: T.handle, b: T.handle) -> None: [ B[ T.min(vj, vi) : T.min(vj, vi) # type: ignore[misc] - + (T.max(vj, vi) + 1 - T.min(vj, vi)), # type: ignore[misc] + + (T.max(vj, vi) + 1 - T.min(vj, vi)), T.min(vi, vj) : T.min(vi, vj) # type: ignore[misc] - + (T.max(vi, vj) + 1 - T.min(vi, vj)), # type: ignore[misc] + + (T.max(vi, vj) + 1 - T.min(vi, vj)), ] ] ) From 5a8b0fff0bcbdbf015dd72fadc132e6566651d45 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 18 Nov 2021 17:01:18 -0800 Subject: [PATCH 7/7] add dtype to unary ops --- python/tvm/script/tir/__init__.pyi | 51 +++++++++++--------- tests/python/unittest/test_tvmscript_type.py | 2 +- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index c1953f6a9e83..ad0a2507c709 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -138,32 +138,35 @@ def tvm_thread_allreduce( """ Unary operator +Note that any intrinsics not registered in script.tir.intrin +should add "dtype" as an argument. This is different from their +definition but intentional. """ -def exp(x: PrimExpr) -> PrimExpr: ... -def exp2(x: PrimExpr) -> PrimExpr: ... -def exp10(x: PrimExpr) -> PrimExpr: ... -def erf(x: PrimExpr) -> PrimExpr: ... -def tanh(x: PrimExpr) -> PrimExpr: ... -def sigmoid(x: PrimExpr) -> PrimExpr: ... -def log(x: PrimExpr) -> PrimExpr: ... -def log2(x: PrimExpr) -> PrimExpr: ... -def log10(x: PrimExpr) -> PrimExpr: ... -def log1p(x: PrimExpr) -> PrimExpr: ... -def tan(x: PrimExpr) -> PrimExpr: ... -def cos(x: PrimExpr) -> PrimExpr: ... -def cosh(x: PrimExpr) -> PrimExpr: ... -def acos(x: PrimExpr) -> PrimExpr: ... -def acosh(x: PrimExpr) -> PrimExpr: ... -def sin(x: PrimExpr) -> PrimExpr: ... -def sinh(x: PrimExpr) -> PrimExpr: ... -def asin(x: PrimExpr) -> PrimExpr: ... -def asinh(x: PrimExpr) -> PrimExpr: ... -def atan(x: PrimExpr) -> PrimExpr: ... -def atanh(x: PrimExpr) -> PrimExpr: ... -def atan2(x: PrimExpr) -> PrimExpr: ... -def sqrt(x: PrimExpr) -> PrimExpr: ... -def rsqrt(x: PrimExpr) -> PrimExpr: ... +def exp(x: PrimExpr, dtype: str) -> PrimExpr: ... +def exp2(x: PrimExpr, dtype: str) -> PrimExpr: ... +def exp10(x: PrimExpr, dtype: str) -> PrimExpr: ... +def erf(x: PrimExpr, dtype: str) -> PrimExpr: ... +def tanh(x: PrimExpr, dtype: str) -> PrimExpr: ... +def sigmoid(x: PrimExpr, dtype: str) -> PrimExpr: ... +def log(x: PrimExpr, dtype: str) -> PrimExpr: ... +def log2(x: PrimExpr, dtype: str) -> PrimExpr: ... +def log10(x: PrimExpr, dtype: str) -> PrimExpr: ... +def log1p(x: PrimExpr, dtype: str) -> PrimExpr: ... +def tan(x: PrimExpr, dtype: str) -> PrimExpr: ... +def cos(x: PrimExpr, dtype: str) -> PrimExpr: ... +def cosh(x: PrimExpr, dtype: str) -> PrimExpr: ... +def acos(x: PrimExpr, dtype: str) -> PrimExpr: ... +def acosh(x: PrimExpr, dtype: str) -> PrimExpr: ... +def sin(x: PrimExpr, dtype: str) -> PrimExpr: ... +def sinh(x: PrimExpr, dtype: str) -> PrimExpr: ... +def asin(x: PrimExpr, dtype: str) -> PrimExpr: ... +def asinh(x: PrimExpr, dtype: str) -> PrimExpr: ... +def atan(x: PrimExpr, dtype: str) -> PrimExpr: ... +def atanh(x: PrimExpr, dtype: str) -> PrimExpr: ... +def atan2(x: PrimExpr, dtype: str) -> PrimExpr: ... +def sqrt(x: PrimExpr, dtype: str) -> PrimExpr: ... +def rsqrt(x: PrimExpr, dtype: str) -> PrimExpr: ... """ special_stmt - Buffers diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 2e9308dbf58f..12954e31e5ec 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -173,7 +173,7 @@ def different_access_indices(a: T.handle, b: T.handle) -> None: ] ) with T.init(): - B[vj, vi] = T.exp(B[vi, vj]) + B[vj, vi] = T.exp(B[vj, vi], dtype="float32") B[vi, vj] = B[vi, vj] + A[vi, vj, vk]