diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index fba026d414f6..ad0a2507c709 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -69,13 +69,15 @@ class IterVar(Var): ... class Buffer: @overload - def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]]) -> PrimExpr: ... + def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]]) -> PrimExpr: ... @overload - def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ... + def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ... @overload - def __setitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]], value: PrimExpr) -> None: ... + def __setitem__( + self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]], value: PrimExpr + ) -> None: ... @overload - def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ... + def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ... @property def data(self: Buffer) -> Ptr: ... @@ -124,35 +126,47 @@ def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ... def store( var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True ) -> None: ... -def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ... +def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ... + +""" +Intrinsics - tvm builtin +""" + +def tvm_thread_allreduce( + *freduceargs: Union[PrimExpr, builtins.bool, Ptr], dtype: str +) -> PrimExpr: ... """ Unary operator +Note that any intrinsics not registered in script.tir.intrin +should add "dtype" as an argument. This is different from their +definition but intentional. """ -def exp2(x: PrimExpr) -> PrimExpr: ... -def exp10(x: PrimExpr) -> PrimExpr: ... -def erf(x: PrimExpr) -> PrimExpr: ... -def tanh(x: PrimExpr) -> PrimExpr: ... -def sigmoid(x: PrimExpr) -> PrimExpr: ... -def log(x: PrimExpr) -> PrimExpr: ... -def log2(x: PrimExpr) -> PrimExpr: ... -def log10(x: PrimExpr) -> PrimExpr: ... -def log1p(x: PrimExpr) -> PrimExpr: ... -def tan(x: PrimExpr) -> PrimExpr: ... -def cos(x: PrimExpr) -> PrimExpr: ... -def cosh(x: PrimExpr) -> PrimExpr: ... -def acos(x: PrimExpr) -> PrimExpr: ... -def acosh(x: PrimExpr) -> PrimExpr: ... -def sin(x: PrimExpr) -> PrimExpr: ... -def sinh(x: PrimExpr) -> PrimExpr: ... -def asin(x: PrimExpr) -> PrimExpr: ... -def asinh(x: PrimExpr) -> PrimExpr: ... -def atan(x: PrimExpr) -> PrimExpr: ... -def atanh(x: PrimExpr) -> PrimExpr: ... -def atan2(x: PrimExpr) -> PrimExpr: ... -def sqrt(x: PrimExpr) -> PrimExpr: ... -def rsqrt(x: PrimExpr) -> PrimExpr: ... +def exp(x: PrimExpr, dtype: str) -> PrimExpr: ... +def exp2(x: PrimExpr, dtype: str) -> PrimExpr: ... +def exp10(x: PrimExpr, dtype: str) -> PrimExpr: ... +def erf(x: PrimExpr, dtype: str) -> PrimExpr: ... +def tanh(x: PrimExpr, dtype: str) -> PrimExpr: ... +def sigmoid(x: PrimExpr, dtype: str) -> PrimExpr: ... +def log(x: PrimExpr, dtype: str) -> PrimExpr: ... +def log2(x: PrimExpr, dtype: str) -> PrimExpr: ... +def log10(x: PrimExpr, dtype: str) -> PrimExpr: ... +def log1p(x: PrimExpr, dtype: str) -> PrimExpr: ... +def tan(x: PrimExpr, dtype: str) -> PrimExpr: ... +def cos(x: PrimExpr, dtype: str) -> PrimExpr: ... +def cosh(x: PrimExpr, dtype: str) -> PrimExpr: ... +def acos(x: PrimExpr, dtype: str) -> PrimExpr: ... +def acosh(x: PrimExpr, dtype: str) -> PrimExpr: ... +def sin(x: PrimExpr, dtype: str) -> PrimExpr: ... +def sinh(x: PrimExpr, dtype: str) -> PrimExpr: ... +def asin(x: PrimExpr, dtype: str) -> PrimExpr: ... +def asinh(x: PrimExpr, dtype: str) -> PrimExpr: ... +def atan(x: PrimExpr, dtype: str) -> PrimExpr: ... +def atanh(x: PrimExpr, dtype: str) -> PrimExpr: ... +def atan2(x: PrimExpr, dtype: str) -> PrimExpr: ... +def sqrt(x: PrimExpr, dtype: str) -> PrimExpr: ... +def rsqrt(x: PrimExpr, dtype: str) -> PrimExpr: ... """ special_stmt - Buffers @@ -334,7 +348,7 @@ def for_range( end: Union[PrimExpr, int] = None, annotations: Optional[Mapping[str, Object]] = None, ) -> Iterable[IterVar]: ... -def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ... +def grid(*extents: Union[PrimExpr, int]) -> Iterable[Sequence[IterVar]]: ... """ ty - redefine types diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py index 44ea04b5ed36..12954e31e5ec 100644 --- a/tests/python/unittest/test_tvmscript_type.py +++ b/tests/python/unittest/test_tvmscript_type.py @@ -81,6 +81,102 @@ def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: ) +""" +This test case is added to test T.grid +""" + + +@T.prim_func +def loop_split(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i, ko in T.grid(128, 4): + for ki in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("B"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, ko * 32 + ki) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] + + +""" +This test case is added to test T.comm_reducer, T.reinterpret, T.tvm_thread_allreduce +""" + + +@T.prim_func +def lowered_loop_split(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i in T.serial(0, 128): + for ki in T.thread_binding(0, 32, thread="threadIdx.x"): + normal_reduce_temp0[0] = T.float32(0) + for ko in T.serial(0, 4): + with T.block("B_normal_reduction"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, ko * 32 + ki) + T.reads([A[vi, vk], normal_reduce_temp0[0]]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk] + with T.block("B_cross_thread_reduction"): + T.reads([normal_reduce_temp0[0]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + normal_reduce_temp0[0], + True, + reduce_temp0.data, + ki, + dtype="handle", + ) + ) + with T.block("B_write_back"): + vi = T.axis.S(128, i) + T.reads([reduce_temp0[0]]) + T.writes([B[vi]]) + B[vi] = reduce_temp0[0] + + +""" +This test case is added to test T.Buffer with slice as argument and T.exp +""" + + +@T.prim_func +def different_access_indices(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + for i, j in T.grid(128, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads([B[vi, vj], A[vi, vj, vk]]) + T.writes( + [ + B[ + T.min(vj, vi) : T.min(vj, vi) # type: ignore[misc] + + (T.max(vj, vi) + 1 - T.min(vj, vi)), + T.min(vi, vj) : T.min(vi, vj) # type: ignore[misc] + + (T.max(vi, vj) + 1 - T.min(vi, vj)), + ] + ] + ) + with T.init(): + B[vj, vi] = T.exp(B[vj, vi], dtype="float32") + B[vi, vj] = B[vi, vj] + A[vi, vj, vk] + + # Not running any test as we only want to type-check here if __name__ == "__main__": pass