Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript][Fix] Add type hints for more uncovered cases #9505

Merged
merged 7 commits into from
Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 27 additions & 24 deletions python/tvm/script/tir/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -138,32 +138,35 @@ def tvm_thread_allreduce(

"""
Unary operator
Note that any intrinsics not registered in script.tir.intrin
should add "dtype" as an argument. This is different from their
definition but intentional.
"""

def exp(x: PrimExpr) -> PrimExpr: ...
def exp2(x: PrimExpr) -> PrimExpr: ...
def exp10(x: PrimExpr) -> PrimExpr: ...
def erf(x: PrimExpr) -> PrimExpr: ...
def tanh(x: PrimExpr) -> PrimExpr: ...
def sigmoid(x: PrimExpr) -> PrimExpr: ...
def log(x: PrimExpr) -> PrimExpr: ...
def log2(x: PrimExpr) -> PrimExpr: ...
def log10(x: PrimExpr) -> PrimExpr: ...
def log1p(x: PrimExpr) -> PrimExpr: ...
def tan(x: PrimExpr) -> PrimExpr: ...
def cos(x: PrimExpr) -> PrimExpr: ...
def cosh(x: PrimExpr) -> PrimExpr: ...
def acos(x: PrimExpr) -> PrimExpr: ...
def acosh(x: PrimExpr) -> PrimExpr: ...
def sin(x: PrimExpr) -> PrimExpr: ...
def sinh(x: PrimExpr) -> PrimExpr: ...
def asin(x: PrimExpr) -> PrimExpr: ...
def asinh(x: PrimExpr) -> PrimExpr: ...
def atan(x: PrimExpr) -> PrimExpr: ...
def atanh(x: PrimExpr) -> PrimExpr: ...
def atan2(x: PrimExpr) -> PrimExpr: ...
def sqrt(x: PrimExpr) -> PrimExpr: ...
def rsqrt(x: PrimExpr) -> PrimExpr: ...
def exp(x: PrimExpr, dtype: str) -> PrimExpr: ...
def exp2(x: PrimExpr, dtype: str) -> PrimExpr: ...
def exp10(x: PrimExpr, dtype: str) -> PrimExpr: ...
def erf(x: PrimExpr, dtype: str) -> PrimExpr: ...
def tanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
def sigmoid(x: PrimExpr, dtype: str) -> PrimExpr: ...
def log(x: PrimExpr, dtype: str) -> PrimExpr: ...
def log2(x: PrimExpr, dtype: str) -> PrimExpr: ...
def log10(x: PrimExpr, dtype: str) -> PrimExpr: ...
def log1p(x: PrimExpr, dtype: str) -> PrimExpr: ...
def tan(x: PrimExpr, dtype: str) -> PrimExpr: ...
def cos(x: PrimExpr, dtype: str) -> PrimExpr: ...
def cosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
def acos(x: PrimExpr, dtype: str) -> PrimExpr: ...
def acosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
def sin(x: PrimExpr, dtype: str) -> PrimExpr: ...
def sinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
def asin(x: PrimExpr, dtype: str) -> PrimExpr: ...
def asinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
def atan(x: PrimExpr, dtype: str) -> PrimExpr: ...
def atanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
def atan2(x: PrimExpr, dtype: str) -> PrimExpr: ...
def sqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...
def rsqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...

"""
special_stmt - Buffers
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tvmscript_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def different_access_indices(a: T.handle, b: T.handle) -> None:
]
)
with T.init():
B[vj, vi] = T.exp(B[vi, vj])
B[vj, vi] = T.exp(B[vj, vi], dtype="float32")
shingjan marked this conversation as resolved.
Show resolved Hide resolved
B[vi, vj] = B[vi, vj] + A[vi, vj, vk]


Expand Down