[ROCM] Add Thrust support #7458

Merged: 10 commits, Feb 17, 2021
17 changes: 17 additions & 0 deletions cmake/modules/ROCM.cmake
@@ -48,6 +48,23 @@ if(USE_ROCM)
list(APPEND RUNTIME_SRCS ${ROCBLAS_CONTRIB_SRCS})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_ROCBLAS_LIBRARY})
endif(USE_ROCBLAS)

if(USE_THRUST)
message(STATUS "Build with rocThrust support")
# We need to override CXX to hipcc. This is required by rocthrust
if (${CMAKE_CXX_COMPILER} MATCHES "hipcc$")
message(STATUS "Using hipcc compiler to compile rocthrust code.")
else()
message(FATAL_ERROR "Set CXX=hipcc to compile rocthrust code.")
endif()

find_package(rocprim REQUIRED)
find_package(rocthrust REQUIRED)
set_source_files_properties(src/runtime/contrib/thrust/thrust.cu PROPERTIES LANGUAGE CXX)
list(APPEND RUNTIME_SRCS src/runtime/contrib/thrust/thrust.cu)
list(APPEND TVM_RUNTIME_LINKER_LIBS roc::rocthrust)
endif(USE_THRUST)

else(USE_ROCM)
list(APPEND COMPILER_SRCS src/target/opt/build_rocm_off.cc)
endif(USE_ROCM)
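Once the runtime is built with rocThrust, the kernels in src/runtime/contrib/thrust/thrust.cu are exposed through TVM's global function registry. The following is a minimal sketch (not part of this PR) for checking from Python that a build picked them up, assuming TVM was configured with USE_ROCM=ON, USE_THRUST=ON, and CXX=hipcc as required above; it uses only get_global_func with allow_missing=True, the same call the strategy code below relies on:

```python
# Quick availability check on a thrust-enabled build of TVM.
from tvm._ffi import get_global_func

for name in ["tvm.contrib.thrust.sort", "tvm.contrib.thrust.stable_sort_by_key"]:
    func = get_global_func(name, allow_missing=True)
    # get_global_func returns None when the packed function is not registered.
    print(name, "->", "available" if func is not None else "missing")
```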
92 changes: 92 additions & 0 deletions python/tvm/relay/op/strategy/rocm.py
@@ -18,6 +18,8 @@
# pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
from tvm import topi
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.te import SpecializedCondition
from tvm._ffi import get_global_func
from .generic import *
from .. import op as _op
from .cuda import judge_winograd, naive_schedule
@@ -219,3 +221,93 @@ def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
plevel=12,
)
return strategy


def can_use_thrust(target, func_name):
return (
target.kind.name == "rocm"
and "thrust" in target.libs
and get_global_func(func_name, allow_missing=True)
)


@argsort_strategy.register(["rocm"])
def argsort_strategy_cuda(attrs, inputs, out_type, target):
"""argsort rocm strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_argsort(topi.cuda.argsort),
wrap_topi_schedule(topi.cuda.schedule_argsort),
name="argsort.rocm",
)
if can_use_thrust(target, "tvm.contrib.thrust.sort"):
strategy.add_implementation(
wrap_compute_argsort(topi.cuda.argsort_thrust),
wrap_topi_schedule(topi.cuda.schedule_argsort),
name="argsort_thrust.rocm",
plevel=15,
)
return strategy


@scatter_strategy.register(["rocm"])
def scatter_cuda(attrs, inputs, out_type, target):
"""scatter rocm strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter),
wrap_topi_schedule(topi.cuda.schedule_scatter),
name="scatter.rocm",
plevel=10,
)

rank = len(inputs[0].shape)

with SpecializedCondition(rank == 1):
if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter_via_sort),
wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
name="scatter_via_sort.rocm",
plevel=9, # use the sequential version by default
)
return strategy


@sort_strategy.register(["rocm"])
def sort_strategy_cuda(attrs, inputs, out_type, target):
"""sort rocm strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sort(topi.cuda.sort),
wrap_topi_schedule(topi.cuda.schedule_sort),
name="sort.rocm",
)
if can_use_thrust(target, "tvm.contrib.thrust.sort"):
strategy.add_implementation(
wrap_compute_sort(topi.cuda.sort_thrust),
wrap_topi_schedule(topi.cuda.schedule_sort),
name="sort_thrust.cuda",
plevel=15,
)
return strategy


@topk_strategy.register(["rocm"])
def topk_strategy_cuda(attrs, inputs, out_type, target):
"""topk rocm strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_topk(topi.cuda.topk),
wrap_topi_schedule(topi.cuda.schedule_topk),
name="topk.rocm",
)

if can_use_thrust(target, "tvm.contrib.thrust.sort"):
strategy.add_implementation(
wrap_compute_topk(topi.cuda.topk_thrust),
wrap_topi_schedule(topi.cuda.schedule_topk),
name="topk_thrust.rocm",
plevel=15,
)
return strategy
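The Thrust variants above are selected only when the target string carries -libs=thrust and the corresponding packed function exists in the runtime. The sort, argsort, and topk variants register at plevel=15 against the default plevel=10, so they take priority whenever both conditions hold, while scatter_via_sort deliberately stays at plevel=9 so the sequential scatter remains the default. A small sketch of the two checks mirrored by can_use_thrust (illustrative only, assuming a thrust-enabled ROCm build):

```python
# Reproduce the two conditions can_use_thrust tests for a ROCm target.
import tvm
from tvm._ffi import get_global_func

target = tvm.target.Target("rocm -libs=thrust")
print(target.kind.name)         # "rocm"
print("thrust" in target.libs)  # True, contributed by -libs=thrust
print(get_global_func("tvm.contrib.thrust.sort", allow_missing=True) is not None)
```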
3 changes: 2 additions & 1 deletion python/tvm/topi/cuda/nms.py
@@ -610,7 +610,8 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape):
)

target = tvm.target.Target.current()
if target and target.kind.name == "cuda" and is_thrust_available():
# TODO(masahi): Check -libs=thrust option
if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available():
sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32")
else:
sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")
5 changes: 3 additions & 2 deletions python/tvm/topi/cuda/scan.py
@@ -221,7 +221,7 @@ def ir(data, data_ex_scan, reduction):
with ib.if_scope(scan_axis_size > 0):
reduction[tid] = binop(
data_ex_scan[tid * scan_axis_size + scan_axis_size - 1],
data[tid, scan_axis_size - 1],
data[tid * scan_axis_size + scan_axis_size - 1],
)
with ib.else_scope():
reduction[tid] = 0
@@ -352,7 +352,8 @@ def exclusive_scan(

def do_scan(data, output_dtype):
target = tvm.target.Target.current()
if target and target.kind.name == "cuda" and is_thrust_available():
# TODO(masahi): Check -libs=thrust option
if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available():
return scan_thrust(
data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
)
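The first scan.py hunk is an indexing fix: the reduction path now addresses data with the same flat, row-major offset already used for data_ex_scan on the preceding line, instead of a 2-D style index. For a row-major buffer, element (tid, scan_axis_size - 1) lives at offset tid * scan_axis_size + scan_axis_size - 1, as the small numpy sketch below illustrates (shapes chosen for illustration only):

```python
import numpy as np

num_rows, scan_axis_size = 4, 5
data_2d = np.arange(num_rows * scan_axis_size, dtype=np.int32).reshape(num_rows, scan_axis_size)
data_flat = data_2d.reshape(-1)  # row-major flattening

tid = 2
# Last element of row `tid`, addressed both ways; the offsets agree.
assert data_2d[tid, scan_axis_size - 1] == data_flat[tid * scan_axis_size + scan_axis_size - 1]
```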
1 change: 1 addition & 0 deletions src/runtime/contrib/thrust/thrust.cu
@@ -26,6 +26,7 @@
#include <thrust/sort.h>
#include <thrust/gather.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>

#include <tvm/runtime/registry.h>
#include <dlpack/dlpack.h>
127 changes: 72 additions & 55 deletions tests/python/contrib/test_thrust.py
@@ -33,61 +33,72 @@ def test_stable_sort_by_key():

keys_out, values_out = stable_sort_by_key_thrust(keys, values)

ctx = tvm.gpu(0)
target = "cuda"
s = te.create_schedule([keys_out.op, values_out.op])
f = tvm.build(s, [keys, values, keys_out, values_out], target)

keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32)
values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32)
keys_np_out = np.zeros(keys_np.shape, np.int32)
values_np_out = np.zeros(values_np.shape, np.int32)
keys_in = tvm.nd.array(keys_np, ctx)
values_in = tvm.nd.array(values_np, ctx)
keys_out = tvm.nd.array(keys_np_out, ctx)
values_out = tvm.nd.array(values_np_out, ctx)
f(keys_in, values_in, keys_out, values_out)

ref_keys_out = np.sort(keys_np)
ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)])
tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5)
tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
for target in ["cuda", "rocm"]:
if not tvm.testing.device_enabled(target):
print("Skip because %s is not enabled" % target)
continue

target += " -libs=thrust"
ctx = tvm.context(target, 0)
s = te.create_schedule([keys_out.op, values_out.op])
f = tvm.build(s, [keys, values, keys_out, values_out], target)

keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32)
values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32)
keys_np_out = np.zeros(keys_np.shape, np.int32)
values_np_out = np.zeros(values_np.shape, np.int32)
keys_in = tvm.nd.array(keys_np, ctx)
values_in = tvm.nd.array(values_np, ctx)
keys_out = tvm.nd.array(keys_np_out, ctx)
values_out = tvm.nd.array(values_np_out, ctx)
f(keys_in, values_in, keys_out, values_out)

ref_keys_out = np.sort(keys_np)
ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)])
tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5)
tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)


def test_exclusive_scan():
if not is_thrust_available():
print("skip because thrust is not enabled...")
return

for ishape in [(10,), (10, 10), (10, 10, 10)]:
values = te.placeholder(ishape, name="values", dtype="int32")
for target in ["cuda", "rocm"]:
if not tvm.testing.device_enabled(target):
print("Skip because %s is not enabled" % target)
continue

with tvm.target.Target("cuda"):
scan, reduction = exclusive_scan(values, return_reduction=True)
s = schedule_scan([scan, reduction])
target += " -libs=thrust"
for ishape in [(10,), (10, 10), (10, 10, 10)]:
values = te.placeholder(ishape, name="values", dtype="int32")

ctx = tvm.gpu(0)
f = tvm.build(s, [values, scan, reduction], "cuda")
with tvm.target.Target(target):
scan, reduction = exclusive_scan(values, return_reduction=True)
s = schedule_scan([scan, reduction])

values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
values_np_out = np.zeros(values_np.shape, np.int32)
ctx = tvm.context(target, 0)
f = tvm.build(s, [values, scan, reduction], target)

if len(ishape) == 1:
reduction_shape = ()
else:
reduction_shape = ishape[:-1]
values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
values_np_out = np.zeros(values_np.shape, np.int32)

reduction_np_out = np.zeros(reduction_shape, np.int32)
if len(ishape) == 1:
reduction_shape = ()
else:
reduction_shape = ishape[:-1]

values_in = tvm.nd.array(values_np, ctx)
values_out = tvm.nd.array(values_np_out, ctx)
reduction_out = tvm.nd.array(reduction_np_out, ctx)
f(values_in, values_out, reduction_out)
reduction_np_out = np.zeros(reduction_shape, np.int32)

ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np
tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
ref_reduction_out = np.sum(values_np, axis=-1)
tvm.testing.assert_allclose(reduction_out.asnumpy(), ref_reduction_out, rtol=1e-5)
values_in = tvm.nd.array(values_np, ctx)
values_out = tvm.nd.array(values_np_out, ctx)
reduction_out = tvm.nd.array(reduction_np_out, ctx)
f(values_in, values_out, reduction_out)

ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np
tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
ref_reduction_out = np.sum(values_np, axis=-1)
tvm.testing.assert_allclose(reduction_out.asnumpy(), ref_reduction_out, rtol=1e-5)


def test_inclusive_scan():
@@ -97,24 +108,30 @@ def test_inclusive_scan():

out_dtype = "int64"

for ishape in [(10,), (10, 10)]:
values = te.placeholder(ishape, name="values", dtype="int32")
for target in ["cuda", "rocm"]:
if not tvm.testing.device_enabled(target):
print("Skip because %s is not enabled" % target)
continue

with tvm.target.Target("cuda"):
scan = scan_thrust(values, out_dtype, exclusive=False)
s = tvm.te.create_schedule([scan.op])
target += " -libs=thrust"
for ishape in [(10,), (10, 10)]:
values = te.placeholder(ishape, name="values", dtype="int32")

ctx = tvm.gpu(0)
f = tvm.build(s, [values, scan], "cuda")
with tvm.target.Target(target):
scan = scan_thrust(values, out_dtype, exclusive=False)
s = tvm.te.create_schedule([scan.op])

values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
values_np_out = np.zeros(values_np.shape, out_dtype)
values_in = tvm.nd.array(values_np, ctx)
values_out = tvm.nd.array(values_np_out, ctx)
f(values_in, values_out)
ctx = tvm.context(target, 0)
f = tvm.build(s, [values, scan], target)

ref_values_out = np.cumsum(values_np, axis=-1, dtype=out_dtype)
tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
values_np = np.random.randint(0, 10, size=ishape).astype(np.int32)
values_np_out = np.zeros(values_np.shape, out_dtype)
values_in = tvm.nd.array(values_np, ctx)
values_out = tvm.nd.array(values_np_out, ctx)
f(values_in, values_out)

ref_values_out = np.cumsum(values_np, axis=-1, dtype=out_dtype)
tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)


if __name__ == "__main__":