diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index c4ba44f67359..5b44f79ad70a 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -401,9 +401,10 @@ LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent); /*! * \brief Bind a var to thread env. * \param thread_tag The thread type tag. + * \param dtype The data type of the variable. * \return The result variable which gets bound to the thread env. */ -Var EnvThread(String thread_tag); +Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32)); /*! * \brief Store data in a buffer. diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 127d2a4356b1..c04ac780c9e6 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1241,7 +1241,7 @@ def launch_thread( return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member -def env_thread(thread_tag: str) -> IterVar: +def env_thread(thread_tag: str, dtype: str = "int32") -> IterVar: """Bind a var to thread env Parameters @@ -1249,13 +1249,16 @@ def env_thread(thread_tag: str) -> IterVar: thread_tag : str The thread type tag. + dtype : str + The data type of the thread env. + Returns ------- res : IterVar The result iteration variable gets bound to the thread env. """ - return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.EnvThread(thread_tag, dtype) # type: ignore[attr-defined] # pylint: disable=no-member def buffer_store( diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index ccb5a8b57b5b..3ce5c15e6cd0 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -432,7 +432,8 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { } ObjectPtr n = make_object(); if (!iter_var->dom.defined()) { - const_cast(iter_var.get())->dom = Range(0, extent); + const_cast(iter_var.get())->dom = + Range(tvm::tir::make_zero(extent.dtype()), extent); } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) { LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. " << iter_var->dom->extent << " vs " << extent; @@ -444,7 +445,7 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { } LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) { - return LaunchThread(EnvThread(thread_tag), extent); + return LaunchThread(EnvThread(thread_tag, extent.dtype()), extent); } RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, @@ -512,9 +513,8 @@ ElseFrame Else() { return ElseFrame(n); } -Var EnvThread(String thread_tag) { - IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex, - thread_tag); +Var EnvThread(String thread_tag, DataType dtype) { + IterVar iter_var(Range{nullptr}, Var("", dtype), tvm::tir::IterVarType::kThreadIndex, thread_tag); Var var = iter_var->var; if (Optional opt_frame = IRBuilder::Current()->FindFrame()) { opt_frame.value()->env_threads.Set(var, iter_var); diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py index 4c94dc04ccb6..c160e4a31dc3 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py @@ -969,9 +969,9 @@ def expected(A: T.Buffer((32, 128), "float16")): T.ptx_cp_async( "float16", A_shared.data, - T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8), + tx * T.int64(128) + cse_var_1 * T.int64(8), A.data, - T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8), + tx * T.int64(128) + cse_var_1 * T.int64(8), 16, ) T.ptx_commit_group() diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index 530746a6fcb6..25a904a157da 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -471,5 +471,20 @@ def expected(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> No tvm.ir.assert_structural_equal(func, expected) +def test_launch_thread_i64(): + """Test launching thread with int64""" + + @T.prim_func + def func() -> None: + blockIdx_x = T.launch_thread("blockIdx.x", T.int64(1)) + if blockIdx_x == T.int64(0): + T.evaluate(T.int64(0)) + else: + T.evaluate(T.int64(1)) + + assert func.body.node.dom.min.dtype == "int64" + assert func.body.node.dom.extent.dtype == "int64" + + if __name__ == "__main__": tvm.testing.main()