diff --git a/cmake/modules/ROCM.cmake b/cmake/modules/ROCM.cmake
index ec348f8b57f6..b908df2f869b 100644
--- a/cmake/modules/ROCM.cmake
+++ b/cmake/modules/ROCM.cmake
@@ -48,6 +48,23 @@ if(USE_ROCM)
     list(APPEND RUNTIME_SRCS ${ROCBLAS_CONTRIB_SRCS})
     list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_ROCBLAS_LIBRARY})
   endif(USE_ROCBLAS)
+
+  if(USE_THRUST)
+    message(STATUS "Build with rocThrust support")
+    # We need to override CXX to hipcc. This is required by rocthrust
+    if (${CMAKE_CXX_COMPILER} MATCHES "hipcc$")
+      message(STATUS "Using hipcc compiler to compile rocthrust code.")
+    else()
+      message(FATAL_ERROR "Set CXX=hipcc to compile rocthrust code.")
+    endif()
+
+    find_package(rocprim REQUIRED)
+    find_package(rocthrust REQUIRED)
+    set_source_files_properties(src/runtime/contrib/thrust/thrust.cu PROPERTIES LANGUAGE CXX)
+    list(APPEND RUNTIME_SRCS src/runtime/contrib/thrust/thrust.cu)
+    list(APPEND TVM_RUNTIME_LINKER_LIBS roc::rocthrust)
+  endif(USE_THRUST)
+
 else(USE_ROCM)
   list(APPEND COMPILER_SRCS src/target/opt/build_rocm_off.cc)
 endif(USE_ROCM)
diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index c52da541a8ab..934f38625fd3 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -18,6 +18,8 @@
 # pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
 from tvm import topi
 from tvm.auto_scheduler import is_auto_scheduler_enabled
+from tvm.te import SpecializedCondition
+from tvm._ffi import get_global_func
 from .generic import *
 from .. import op as _op
 from .cuda import judge_winograd, naive_schedule
@@ -219,3 +221,93 @@ def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
         plevel=12,
     )
     return strategy
+
+
+def can_use_thrust(target, func_name):
+    return (
+        target.kind.name == "rocm"
+        and "thrust" in target.libs
+        and get_global_func(func_name, allow_missing=True)
+    )
+
+
+@argsort_strategy.register(["rocm"])
+def argsort_strategy_cuda(attrs, inputs, out_type, target):
+    """argsort rocm strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_argsort(topi.cuda.argsort),
+        wrap_topi_schedule(topi.cuda.schedule_argsort),
+        name="argsort.rocm",
+    )
+    if can_use_thrust(target, "tvm.contrib.thrust.sort"):
+        strategy.add_implementation(
+            wrap_compute_argsort(topi.cuda.argsort_thrust),
+            wrap_topi_schedule(topi.cuda.schedule_argsort),
+            name="argsort_thrust.rocm",
+            plevel=15,
+        )
+    return strategy
+
+
+@scatter_strategy.register(["rocm"])
+def scatter_cuda(attrs, inputs, out_type, target):
+    """scatter rocm strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter(topi.cuda.scatter),
+        wrap_topi_schedule(topi.cuda.schedule_scatter),
+        name="scatter.rocm",
+        plevel=10,
+    )
+
+    rank = len(inputs[0].shape)
+
+    with SpecializedCondition(rank == 1):
+        if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
+            strategy.add_implementation(
+                wrap_compute_scatter(topi.cuda.scatter_via_sort),
+                wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
+                name="scatter_via_sort.rocm",
+                plevel=9,  # use the sequential version by default
+            )
+    return strategy
+
+
+@sort_strategy.register(["rocm"])
+def sort_strategy_cuda(attrs, inputs, out_type, target):
+    """sort rocm strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_sort(topi.cuda.sort),
+        wrap_topi_schedule(topi.cuda.schedule_sort),
+        name="sort.rocm",
+    )
+    if can_use_thrust(target, "tvm.contrib.thrust.sort"):
+        strategy.add_implementation(
+            wrap_compute_sort(topi.cuda.sort_thrust),
+            wrap_topi_schedule(topi.cuda.schedule_sort),
+            name="sort_thrust.rocm",
+            plevel=15,
+        )
+    return strategy
+
+
+@topk_strategy.register(["rocm"])
+def topk_strategy_cuda(attrs, inputs, out_type, target):
+    """topk rocm strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_topk(topi.cuda.topk),
+        wrap_topi_schedule(topi.cuda.schedule_topk),
+        name="topk.rocm",
+    )
+
+    if can_use_thrust(target, "tvm.contrib.thrust.sort"):
+        strategy.add_implementation(
+            wrap_compute_topk(topi.cuda.topk_thrust),
+            wrap_topi_schedule(topi.cuda.schedule_topk),
+            name="topk_thrust.rocm",
+            plevel=15,
+        )
+    return strategy
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 2d6e1e464ef8..98cb6750408a 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -610,7 +610,8 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape):
     )
 
     target = tvm.target.Target.current()
-    if target and target.kind.name == "cuda" and is_thrust_available():
+    # TODO(masahi): Check -libs=thrust option
+    if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available():
         sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32")
     else:
         sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")
diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index 0bdab100b429..65d23365dc15 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -221,7 +221,7 @@ def ir(data, data_ex_scan, reduction):
             with ib.if_scope(scan_axis_size > 0):
                 reduction[tid] = binop(
                     data_ex_scan[tid * scan_axis_size + scan_axis_size - 1],
-                    data[tid, scan_axis_size - 1],
+                    data[tid * scan_axis_size + scan_axis_size - 1],
                 )
             with ib.else_scope():
                 reduction[tid] = 0
@@ -352,7 +352,8 @@ def exclusive_scan(
 
     def do_scan(data, output_dtype):
         target = tvm.target.Target.current()
-        if target and target.kind.name == "cuda" and is_thrust_available():
+        # TODO(masahi): Check -libs=thrust option
+        if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available():
             return scan_thrust(
                 data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
             )
diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index 7295d4c47c3f..df83b57847a0 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -26,6 +26,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
diff --git a/tests/python/contrib/test_thrust.py b/tests/python/contrib/test_thrust.py
index c5b6a29d57d5..521c20de6cbd 100644
--- a/tests/python/contrib/test_thrust.py
+++ b/tests/python/contrib/test_thrust.py
@@ -33,25 +33,30 @@ def test_stable_sort_by_key():
 
     keys_out, values_out = stable_sort_by_key_thrust(keys, values)
 
-    ctx = tvm.gpu(0)
-    target = "cuda"
-    s = te.create_schedule([keys_out.op, values_out.op])
-    f = tvm.build(s, [keys, values, keys_out, values_out], target)
-
-    keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32)
-    values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32)
-    keys_np_out = np.zeros(keys_np.shape, np.int32)
-    values_np_out = np.zeros(values_np.shape, np.int32)
-    keys_in = tvm.nd.array(keys_np, ctx)
-    values_in = tvm.nd.array(values_np, ctx)
-    keys_out = tvm.nd.array(keys_np_out, ctx)
-    values_out = tvm.nd.array(values_np_out, ctx)
-    f(keys_in, values_in, keys_out, values_out)
-
-    ref_keys_out = np.sort(keys_np)
-    ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)])
-    tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5)
-    tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
+    for target in ["cuda", "rocm"]:
+        if not tvm.testing.device_enabled(target):
+            print("Skip because %s is not enabled" % target)
+            continue
+
+        target += " -libs=thrust"
+        ctx = tvm.context(target, 0)
+        s = te.create_schedule([keys_out.op, values_out.op])
+        f = tvm.build(s, [keys, values, keys_out, values_out], target)
+
+        keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32)
+        values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32)
+        keys_np_out = np.zeros(keys_np.shape, np.int32)
+        values_np_out = np.zeros(values_np.shape, np.int32)
+        keys_in = tvm.nd.array(keys_np, ctx)
+        values_in = tvm.nd.array(values_np, ctx)
+        keys_out = tvm.nd.array(keys_np_out, ctx)
+        values_out = tvm.nd.array(values_np_out, ctx)
+        f(keys_in, values_in, keys_out, values_out)
+
+        ref_keys_out = np.sort(keys_np)
+        ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)])
+        tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5)
+        tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
 
 
 def test_exclusive_scan():
@@ -59,35 +64,41 @@ def test_exclusive_scan():
         print("skip because thrust is not enabled...")
         return
 
-    for ishape in [(10,), (10, 10), (10, 10, 10)]:
-        values = te.placeholder(ishape, name="values", dtype="int32")
+    for target in ["cuda", "rocm"]:
+        if not tvm.testing.device_enabled(target):
+            print("Skip because %s is not enabled" % target)
+            continue
 
-        with tvm.target.Target("cuda"):
-            scan, reduction = exclusive_scan(values, return_reduction=True)
-            s = schedule_scan([scan, reduction])
+        target += " -libs=thrust"
+        for ishape in [(10,), (10, 10), (10, 10, 10)]:
+            values = te.placeholder(ishape, name="values", dtype="int32")
 
-        ctx = tvm.gpu(0)
-        f = tvm.build(s, [values, scan, reduction], "cuda")
+            with tvm.target.Target(target):
+                scan, reduction = exclusive_scan(values, return_reduction=True)
+                s = schedule_scan([scan, reduction])
 
-        values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
-        values_np_out = np.zeros(values_np.shape, np.int32)
+            ctx = tvm.context(target, 0)
+            f = tvm.build(s, [values, scan, reduction], target)
 
-        if len(ishape) == 1:
-            reduction_shape = ()
-        else:
-            reduction_shape = ishape[:-1]
+            values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
+            values_np_out = np.zeros(values_np.shape, np.int32)
 
-        reduction_np_out = np.zeros(reduction_shape, np.int32)
+            if len(ishape) == 1:
+                reduction_shape = ()
+            else:
+                reduction_shape = ishape[:-1]
 
-        values_in = tvm.nd.array(values_np, ctx)
-        values_out = tvm.nd.array(values_np_out, ctx)
-        reduction_out = tvm.nd.array(reduction_np_out, ctx)
-        f(values_in, values_out, reduction_out)
+            reduction_np_out = np.zeros(reduction_shape, np.int32)
 
-        ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np
-        tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
-        ref_reduction_out = np.sum(values_np, axis=-1)
-        tvm.testing.assert_allclose(reduction_out.asnumpy(), ref_reduction_out, rtol=1e-5)
+            values_in = tvm.nd.array(values_np, ctx)
+            values_out = tvm.nd.array(values_np_out, ctx)
+            reduction_out = tvm.nd.array(reduction_np_out, ctx)
+            f(values_in, values_out, reduction_out)
+
+            ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np
+            tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
+            ref_reduction_out = np.sum(values_np, axis=-1)
+            tvm.testing.assert_allclose(reduction_out.asnumpy(), ref_reduction_out, rtol=1e-5)
 
 
 def test_inclusive_scan():
@@ -97,24 +108,30 @@ def test_inclusive_scan():
 
     out_dtype = "int64"
 
-    for ishape in [(10,), (10, 10)]:
-        values = te.placeholder(ishape, name="values", dtype="int32")
+    for target in ["cuda", "rocm"]:
+        if not tvm.testing.device_enabled(target):
+            print("Skip because %s is not enabled" % target)
+            continue
 
-        with tvm.target.Target("cuda"):
-            scan = scan_thrust(values, out_dtype, exclusive=False)
-            s = tvm.te.create_schedule([scan.op])
+        target += " -libs=thrust"
+        for ishape in [(10,), (10, 10)]:
+            values = te.placeholder(ishape, name="values", dtype="int32")
 
-        ctx = tvm.gpu(0)
-        f = tvm.build(s, [values, scan], "cuda")
+            with tvm.target.Target(target):
+                scan = scan_thrust(values, out_dtype, exclusive=False)
+                s = tvm.te.create_schedule([scan.op])
 
-        values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
-        values_np_out = np.zeros(values_np.shape, out_dtype)
-        values_in = tvm.nd.array(values_np, ctx)
-        values_out = tvm.nd.array(values_np_out, ctx)
-        f(values_in, values_out)
+            ctx = tvm.context(target, 0)
+            f = tvm.build(s, [values, scan], target)
 
-        ref_values_out = np.cumsum(values_np, axis=-1, dtype=out_dtype)
-        tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
+            values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
+            values_np_out = np.zeros(values_np.shape, out_dtype)
+            values_in = tvm.nd.array(values_np, ctx)
+            values_out = tvm.nd.array(values_np_out, ctx)
+            f(values_in, values_out)
+
+            ref_values_out = np.cumsum(values_np, axis=-1, dtype=out_dtype)
+            tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
 
 
 if __name__ == "__main__":
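Not part of the patch: a minimal usage sketch of the new ROCm path, assuming TVM was configured with USE_ROCM=ON, USE_THRUST=ON and CXX=hipcc as required by the ROCM.cmake hunk above. It replays the stable_sort_by_key_thrust pattern from test_stable_sort_by_key on the rocm target; the "-libs=thrust" flag and the build/run calls are taken from this diff, while the import path and sample data are assumptions.

# Hypothetical usage sketch (not part of the patch): run the rocThrust-backed
# stable sort-by-key on a ROCm device, mirroring test_stable_sort_by_key above.
# Assumes a TVM build with USE_ROCM=ON, USE_THRUST=ON and CXX=hipcc.
import numpy as np
import tvm
from tvm import te
from tvm.topi.cuda.sort import stable_sort_by_key_thrust

size = 6
keys = te.placeholder((size,), name="keys", dtype="int32")
values = te.placeholder((size,), name="values", dtype="int32")
keys_out, values_out = stable_sort_by_key_thrust(keys, values)

target = "rocm -libs=thrust"  # -libs=thrust routes sort/scan ops to rocThrust
ctx = tvm.context(target, 0)
s = te.create_schedule([keys_out.op, values_out.op])
f = tvm.build(s, [keys, values, keys_out, values_out], target)

keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32)
values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32)
keys_out_nd = tvm.nd.array(np.zeros(size, np.int32), ctx)
values_out_nd = tvm.nd.array(np.zeros(size, np.int32), ctx)
f(tvm.nd.array(keys_np, ctx), tvm.nd.array(values_np, ctx), keys_out_nd, values_out_nd)

print(keys_out_nd.asnumpy())    # keys sorted ascending
print(values_out_nd.asnumpy())  # values permuted along with the keys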