diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index f048aed43bf5..4aadfb0c083b 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -1352,14 +1352,12 @@ class CUDAStream { cudaStream_t stream_; public: - CUDAStream() { - dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); - } - ~CUDAStream() { - dh::safe_cuda(cudaStreamDestroy(stream_)); - } + CUDAStream() { dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); } + ~CUDAStream() { dh::safe_cuda(cudaStreamDestroy(stream_)); } + + [[nodiscard]] CUDAStreamView View() const { return CUDAStreamView{stream_}; } + [[nodiscard]] cudaStream_t Handle() const { return stream_; } - CUDAStreamView View() const { return CUDAStreamView{stream_}; } void Sync() { this->View().Sync(); } }; diff --git a/src/data/array_interface.cu b/src/data/array_interface.cu index b1a80251ecc4..28d8945c2ac3 100644 --- a/src/data/array_interface.cu +++ b/src/data/array_interface.cu @@ -1,11 +1,15 @@ -/*! - * Copyright 2021 by Contributors +/** + * Copyright 2021-2023, XGBoost Contributors */ +#include // for int64_t + #include "../common/common.h" +#include "../common/device_helpers.cuh" // for DefaultStream, CUDAEvent #include "array_interface.h" +#include "xgboost/logging.h" namespace xgboost { -void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) { +void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) { switch (stream) { case 0: /** @@ -22,8 +26,11 @@ void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) { break; case 2: // default per-thread stream - default: - dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast(stream))); + default: { + dh::CUDAEvent e; + e.Record(dh::CUDAStreamView{reinterpret_cast(stream)}); + dh::DefaultStream().Wait(e); + } } } diff --git a/tests/cpp/data/test_array_interface.cu b/tests/cpp/data/test_array_interface.cu index c8e07852534b..00b996fb9ffb 100644 --- a/tests/cpp/data/test_array_interface.cu +++ b/tests/cpp/data/test_array_interface.cu @@ -1,5 +1,5 @@ -/*! - * Copyright 2021 by Contributors +/** + * Copyright 2021-2023, XGBoost Contributors */ #include #include @@ -22,22 +22,19 @@ TEST(ArrayInterface, Stream) { HostDeviceVector storage; auto arr_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); - cudaStream_t stream; - cudaStreamCreate(&stream); + dh::CUDAStream stream; - auto j_arr =Json::Load(StringView{arr_str}); - j_arr["stream"] = Integer(reinterpret_cast(stream)); + auto j_arr = Json::Load(StringView{arr_str}); + j_arr["stream"] = Integer(reinterpret_cast(stream.Handle())); Json::Dump(j_arr, &arr_str); dh::caching_device_vector out(1, 0); - uint64_t dur = 1e9; - dh::LaunchKernel{1, 1, 0, stream}(SleepForTest, out.data().get(), dur); + std::uint64_t dur = 1e9; + dh::LaunchKernel{1, 1, 0, stream.View()}(SleepForTest, out.data().get(), dur); ArrayInterface<2> arr(arr_str); auto t = out[0]; CHECK_GE(t, dur); - - cudaStreamDestroy(stream); } TEST(ArrayInterface, Ptr) {