Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Support T.launch_thread with i64 dtype #16916

Merged
merged 1 commit into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,9 +401,10 @@ LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent);
/*!
* \brief Bind a var to thread env.
* \param thread_tag The thread type tag.
* \param dtype The data type of the variable.
* \return The result variable which gets bound to the thread env.
*/
Var EnvThread(String thread_tag);
Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32));

/*!
* \brief Store data in a buffer.
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,21 +1241,24 @@ def launch_thread(
return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member


def env_thread(thread_tag: str) -> IterVar:
def env_thread(thread_tag: str, dtype: str = "int32") -> IterVar:
"""Bind a var to thread env

Parameters
----------
thread_tag : str
The thread type tag.

dtype : str
The data type of the thread env.

Returns
-------
res : IterVar
The result iteration variable gets bound to the thread env.

"""
return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.EnvThread(thread_tag, dtype) # type: ignore[attr-defined] # pylint: disable=no-member


def buffer_store(
Expand Down
10 changes: 5 additions & 5 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
}
ObjectPtr<LaunchThreadFrameNode> n = make_object<LaunchThreadFrameNode>();
if (!iter_var->dom.defined()) {
const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom = Range(0, extent);
const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom =
Range(tvm::tir::make_zero(extent.dtype()), extent);
} else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
<< iter_var->dom->extent << " vs " << extent;
Expand All @@ -444,7 +445,7 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
}

LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) {
return LaunchThread(EnvThread(thread_tag), extent);
return LaunchThread(EnvThread(thread_tag, extent.dtype()), extent);
}

RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
Expand Down Expand Up @@ -512,9 +513,8 @@ ElseFrame Else() {
return ElseFrame(n);
}

Var EnvThread(String thread_tag) {
IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex,
thread_tag);
Var EnvThread(String thread_tag, DataType dtype) {
IterVar iter_var(Range{nullptr}, Var("", dtype), tvm::tir::IterVarType::kThreadIndex, thread_tag);
Var var = iter_var->var;
if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
opt_frame.value()->env_threads.Set(var, iter_var);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,9 +969,9 @@ def expected(A: T.Buffer((32, 128), "float16")):
T.ptx_cp_async(
"float16",
A_shared.data,
T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8),
tx * T.int64(128) + cse_var_1 * T.int64(8),
A.data,
T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8),
tx * T.int64(128) + cse_var_1 * T.int64(8),
16,
)
T.ptx_commit_group()
Expand Down
15 changes: 15 additions & 0 deletions tests/python/tvmscript/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,5 +471,20 @@ def expected(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> No
tvm.ir.assert_structural_equal(func, expected)


def test_launch_thread_i64():
"""Test launching thread with int64"""

@T.prim_func
def func() -> None:
blockIdx_x = T.launch_thread("blockIdx.x", T.int64(1))
if blockIdx_x == T.int64(0):
T.evaluate(T.int64(0))
else:
T.evaluate(T.int64(1))

assert func.body.node.dom.min.dtype == "int64"
assert func.body.node.dom.extent.dtype == "int64"


if __name__ == "__main__":
tvm.testing.main()
Loading