[ROCM] Add Thrust support (apache#7458)
* enable rocm thrust, confirmed to work on sort and scan

* add rocm argsort strategy

* Abort if CXX is not hipcc

* add more strategy

* add missing import

* fix lint

* show supported data type in err msg

* try remove rocthrust

* add missing include for rocthrust

* more minor change

Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
2 people authored and Lokiiiiii committed Mar 1, 2021
1 parent 64afc40 commit de84e5f
Showing 6 changed files with 187 additions and 58 deletions.
17 changes: 17 additions & 0 deletions cmake/modules/ROCM.cmake
@@ -48,6 +48,23 @@ if(USE_ROCM)
    list(APPEND RUNTIME_SRCS ${ROCBLAS_CONTRIB_SRCS})
    list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_ROCBLAS_LIBRARY})
  endif(USE_ROCBLAS)

  if(USE_THRUST)
    message(STATUS "Build with rocThrust support")
    # We need to override CXX to hipcc. This is required by rocthrust
    if (${CMAKE_CXX_COMPILER} MATCHES "hipcc$")
      message(STATUS "Using hipcc compiler to compile rocthrust code.")
    else()
      message(FATAL_ERROR "Set CXX=hipcc to compile rocthrust code.")
    endif()

    find_package(rocprim REQUIRED)
    find_package(rocthrust REQUIRED)
    set_source_files_properties(src/runtime/contrib/thrust/thrust.cu PROPERTIES LANGUAGE CXX)
    list(APPEND RUNTIME_SRCS src/runtime/contrib/thrust/thrust.cu)
    list(APPEND TVM_RUNTIME_LINKER_LIBS roc::rocthrust)
  endif(USE_THRUST)

else(USE_ROCM)
  list(APPEND COMPILER_SRCS src/target/opt/build_rocm_off.cc)
endif(USE_ROCM)
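
With USE_ROCM and USE_THRUST enabled and CXX=hipcc, the Thrust kernels in thrust.cu are compiled into the runtime and exposed as PackedFuncs. A minimal sketch, assuming such a build, of checking from Python that they were registered; the two names below are the ones the strategy code in this commit looks up:

# Sanity check after building TVM with USE_ROCM=ON, USE_THRUST=ON and CXX=hipcc:
# probe for the Thrust-backed PackedFuncs registered by thrust.cu.
from tvm._ffi import get_global_func

for name in ["tvm.contrib.thrust.sort", "tvm.contrib.thrust.stable_sort_by_key"]:
    func = get_global_func(name, allow_missing=True)
    print(name, "->", "registered" if func else "missing")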
92 changes: 92 additions & 0 deletions python/tvm/relay/op/strategy/rocm.py
@@ -18,6 +18,8 @@
# pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
from tvm import topi
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.te import SpecializedCondition
from tvm._ffi import get_global_func
from .generic import *
from .. import op as _op
from .cuda import judge_winograd, naive_schedule
@@ -219,3 +221,93 @@ def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
            plevel=12,
        )
    return strategy


def can_use_thrust(target, func_name):
    return (
        target.kind.name == "rocm"
        and "thrust" in target.libs
        and get_global_func(func_name, allow_missing=True)
    )


@argsort_strategy.register(["rocm"])
def argsort_strategy_cuda(attrs, inputs, out_type, target):
    """argsort rocm strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_argsort(topi.cuda.argsort),
        wrap_topi_schedule(topi.cuda.schedule_argsort),
        name="argsort.rocm",
    )
    if can_use_thrust(target, "tvm.contrib.thrust.sort"):
        strategy.add_implementation(
            wrap_compute_argsort(topi.cuda.argsort_thrust),
            wrap_topi_schedule(topi.cuda.schedule_argsort),
            name="argsort_thrust.rocm",
            plevel=15,
        )
    return strategy


@scatter_strategy.register(["rocm"])
def scatter_cuda(attrs, inputs, out_type, target):
    """scatter rocm strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_scatter(topi.cuda.scatter),
        wrap_topi_schedule(topi.cuda.schedule_scatter),
        name="scatter.rocm",
        plevel=10,
    )

    rank = len(inputs[0].shape)

    with SpecializedCondition(rank == 1):
        if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
            strategy.add_implementation(
                wrap_compute_scatter(topi.cuda.scatter_via_sort),
                wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
                name="scatter_via_sort.rocm",
                plevel=9,  # use the sequential version by default
            )
    return strategy


@sort_strategy.register(["rocm"])
def sort_strategy_cuda(attrs, inputs, out_type, target):
    """sort rocm strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_sort(topi.cuda.sort),
        wrap_topi_schedule(topi.cuda.schedule_sort),
        name="sort.rocm",
    )
    if can_use_thrust(target, "tvm.contrib.thrust.sort"):
        strategy.add_implementation(
            wrap_compute_sort(topi.cuda.sort_thrust),
            wrap_topi_schedule(topi.cuda.schedule_sort),
            name="sort_thrust.cuda",
            plevel=15,
        )
    return strategy


@topk_strategy.register(["rocm"])
def topk_strategy_cuda(attrs, inputs, out_type, target):
    """topk rocm strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_topk(topi.cuda.topk),
        wrap_topi_schedule(topi.cuda.schedule_topk),
        name="topk.rocm",
    )

    if can_use_thrust(target, "tvm.contrib.thrust.sort"):
        strategy.add_implementation(
            wrap_compute_topk(topi.cuda.topk_thrust),
            wrap_topi_schedule(topi.cuda.schedule_topk),
            name="topk_thrust.rocm",
            plevel=15,
        )
    return strategy
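
The Thrust-backed implementations above only take priority when the target explicitly requests the library. A rough illustration, assuming a thrust-enabled ROCm build, of the three conditions can_use_thrust checks:

# The *_thrust implementations are added only when the target is rocm,
# "thrust" appears in -libs, and the PackedFunc was compiled into the runtime.
import tvm
from tvm._ffi import get_global_func

target = tvm.target.Target("rocm -libs=thrust")
print(target.kind.name)                 # "rocm"
print("thrust" in target.libs)          # True, because of -libs=thrust
print(bool(get_global_func("tvm.contrib.thrust.sort", allow_missing=True)))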
3 changes: 2 additions & 1 deletion python/tvm/topi/cuda/nms.py
@@ -610,7 +610,8 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape):
    )

    target = tvm.target.Target.current()
    if target and target.kind.name == "cuda" and is_thrust_available():
    # TODO(masahi): Check -libs=thrust option
    if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available():
        sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32")
    else:
        sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")
5 changes: 3 additions & 2 deletions python/tvm/topi/cuda/scan.py
@@ -221,7 +221,7 @@ def ir(data, data_ex_scan, reduction):
                with ib.if_scope(scan_axis_size > 0):
                    reduction[tid] = binop(
                        data_ex_scan[tid * scan_axis_size + scan_axis_size - 1],
                        data[tid, scan_axis_size - 1],
                        data[tid * scan_axis_size + scan_axis_size - 1],
                    )
                with ib.else_scope():
                    reduction[tid] = 0
@@ -352,7 +352,8 @@ def exclusive_scan(

    def do_scan(data, output_dtype):
        target = tvm.target.Target.current()
        if target and target.kind.name == "cuda" and is_thrust_available():
        # TODO(masahi): Check -libs=thrust option
        if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available():
            return scan_thrust(
                data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
            )
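
Both TODO(masahi) notes above point at the same gap: is_thrust_available() only tells us the Thrust PackedFunc exists in the runtime, not that the user asked for it via -libs=thrust. A hypothetical helper, not part of this commit, sketching the extra check the TODO refers to:

# Hypothetical helper illustrating the check the TODO above refers to:
# gate the Thrust path on the target actually listing -libs=thrust,
# not only on the PackedFunc being present in the runtime.
def target_has_thrust(target):
    return target is not None and "thrust" in target.libs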
1 change: 1 addition & 0 deletions src/runtime/contrib/thrust/thrust.cu
@@ -26,6 +26,7 @@
#include <thrust/sort.h>
#include <thrust/gather.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>

#include <tvm/runtime/registry.h>
#include <dlpack/dlpack.h>
127 changes: 72 additions & 55 deletions tests/python/contrib/test_thrust.py
@@ -33,61 +33,72 @@ def test_stable_sort_by_key():

    keys_out, values_out = stable_sort_by_key_thrust(keys, values)

    ctx = tvm.gpu(0)
    target = "cuda"
    s = te.create_schedule([keys_out.op, values_out.op])
    f = tvm.build(s, [keys, values, keys_out, values_out], target)

    keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32)
    values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32)
    keys_np_out = np.zeros(keys_np.shape, np.int32)
    values_np_out = np.zeros(values_np.shape, np.int32)
    keys_in = tvm.nd.array(keys_np, ctx)
    values_in = tvm.nd.array(values_np, ctx)
    keys_out = tvm.nd.array(keys_np_out, ctx)
    values_out = tvm.nd.array(values_np_out, ctx)
    f(keys_in, values_in, keys_out, values_out)

    ref_keys_out = np.sort(keys_np)
    ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)])
    tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5)
    tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
    for target in ["cuda", "rocm"]:
        if not tvm.testing.device_enabled(target):
            print("Skip because %s is not enabled" % target)
            continue

        target += " -libs=thrust"
        ctx = tvm.context(target, 0)
        s = te.create_schedule([keys_out.op, values_out.op])
        f = tvm.build(s, [keys, values, keys_out, values_out], target)

        keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32)
        values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32)
        keys_np_out = np.zeros(keys_np.shape, np.int32)
        values_np_out = np.zeros(values_np.shape, np.int32)
        keys_in = tvm.nd.array(keys_np, ctx)
        values_in = tvm.nd.array(values_np, ctx)
        keys_out = tvm.nd.array(keys_np_out, ctx)
        values_out = tvm.nd.array(values_np_out, ctx)
        f(keys_in, values_in, keys_out, values_out)

        ref_keys_out = np.sort(keys_np)
        ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)])
        tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5)
        tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)


def test_exclusive_scan():
    if not is_thrust_available():
        print("skip because thrust is not enabled...")
        return

    for ishape in [(10,), (10, 10), (10, 10, 10)]:
        values = te.placeholder(ishape, name="values", dtype="int32")
    for target in ["cuda", "rocm"]:
        if not tvm.testing.device_enabled(target):
            print("Skip because %s is not enabled" % target)
            continue

        with tvm.target.Target("cuda"):
            scan, reduction = exclusive_scan(values, return_reduction=True)
            s = schedule_scan([scan, reduction])
        target += " -libs=thrust"
        for ishape in [(10,), (10, 10), (10, 10, 10)]:
            values = te.placeholder(ishape, name="values", dtype="int32")

        ctx = tvm.gpu(0)
        f = tvm.build(s, [values, scan, reduction], "cuda")
            with tvm.target.Target(target):
                scan, reduction = exclusive_scan(values, return_reduction=True)
                s = schedule_scan([scan, reduction])

        values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
        values_np_out = np.zeros(values_np.shape, np.int32)
            ctx = tvm.context(target, 0)
            f = tvm.build(s, [values, scan, reduction], target)

        if len(ishape) == 1:
            reduction_shape = ()
        else:
            reduction_shape = ishape[:-1]
            values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
            values_np_out = np.zeros(values_np.shape, np.int32)

        reduction_np_out = np.zeros(reduction_shape, np.int32)
            if len(ishape) == 1:
                reduction_shape = ()
            else:
                reduction_shape = ishape[:-1]

        values_in = tvm.nd.array(values_np, ctx)
        values_out = tvm.nd.array(values_np_out, ctx)
        reduction_out = tvm.nd.array(reduction_np_out, ctx)
        f(values_in, values_out, reduction_out)
            reduction_np_out = np.zeros(reduction_shape, np.int32)

        ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np
        tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
        ref_reduction_out = np.sum(values_np, axis=-1)
        tvm.testing.assert_allclose(reduction_out.asnumpy(), ref_reduction_out, rtol=1e-5)
            values_in = tvm.nd.array(values_np, ctx)
            values_out = tvm.nd.array(values_np_out, ctx)
            reduction_out = tvm.nd.array(reduction_np_out, ctx)
            f(values_in, values_out, reduction_out)

            ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np
            tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
            ref_reduction_out = np.sum(values_np, axis=-1)
            tvm.testing.assert_allclose(reduction_out.asnumpy(), ref_reduction_out, rtol=1e-5)


def test_inclusive_scan():
@@ -97,24 +108,30 @@ def test_inclusive_scan():

    out_dtype = "int64"

    for ishape in [(10,), (10, 10)]:
        values = te.placeholder(ishape, name="values", dtype="int32")
    for target in ["cuda", "rocm"]:
        if not tvm.testing.device_enabled(target):
            print("Skip because %s is not enabled" % target)
            continue

        with tvm.target.Target("cuda"):
            scan = scan_thrust(values, out_dtype, exclusive=False)
            s = tvm.te.create_schedule([scan.op])
        target += " -libs=thrust"
        for ishape in [(10,), (10, 10)]:
            values = te.placeholder(ishape, name="values", dtype="int32")

        ctx = tvm.gpu(0)
        f = tvm.build(s, [values, scan], "cuda")
            with tvm.target.Target(target):
                scan = scan_thrust(values, out_dtype, exclusive=False)
                s = tvm.te.create_schedule([scan.op])

        values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
        values_np_out = np.zeros(values_np.shape, out_dtype)
        values_in = tvm.nd.array(values_np, ctx)
        values_out = tvm.nd.array(values_np_out, ctx)
        f(values_in, values_out)
            ctx = tvm.context(target, 0)
            f = tvm.build(s, [values, scan], target)

        ref_values_out = np.cumsum(values_np, axis=-1, dtype=out_dtype)
        tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
            values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
            values_np_out = np.zeros(values_np.shape, out_dtype)
            values_in = tvm.nd.array(values_np, ctx)
            values_out = tvm.nd.array(values_np_out, ctx)
            f(values_in, values_out)

            ref_values_out = np.cumsum(values_np, axis=-1, dtype=out_dtype)
            tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)


if __name__ == "__main__":
