diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/execution_sycl_defs.h b/include/oneapi/dpl/pstl/hetero/dpcpp/execution_sycl_defs.h index efe92a2b226..9d8a66506c9 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/execution_sycl_defs.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/execution_sycl_defs.h @@ -21,6 +21,8 @@ #include "sycl_defs.h" +#include +#include #include namespace oneapi @@ -42,21 +44,66 @@ struct DefaultKernelName; template class device_policy { + // Needed for the copy constructor that rebinds the kernel name + template + friend class device_policy; + + template + static auto + lock_and_forward(T&& t, std::mutex& mtx) + { + ::std::scoped_lock lock{mtx}; + return std::forward(t); + } + public: using kernel_name = KernelName; device_policy() = default; + explicit device_policy(sycl::queue q_) : q(q_) {} + explicit device_policy(sycl::device d_) { q.emplace(d_); } + template - device_policy(const device_policy& other) : q(other.queue()) + device_policy(const device_policy& other) : q(device_policy::lock_and_forward(other.q, other.mtx)) { } - explicit device_policy(sycl::queue q_) : q(q_) {} - explicit device_policy(sycl::device d_) : q(d_) {} - operator sycl::queue() const { return q; } + + device_policy(const device_policy& other) : q(device_policy::lock_and_forward(other.q, other.mtx)) {} + + device_policy(device_policy&& other) : q(device_policy::lock_and_forward(::std::move(other.q), other.mtx)) {} + + device_policy& + operator=(const device_policy& other) + { + if (this != &other) + { + ::std::scoped_lock lock{mtx, other.mtx}; + q = other.q; + } + return *this; + } + + device_policy& + operator=(device_policy&& other) + { + if (this != &other) + { + ::std::scoped_lock lock{mtx, other.mtx}; + q = ::std::move(other.q); + } + return *this; + } + + operator sycl::queue() const { return queue(); } sycl::queue queue() const { - return q; + ::std::scoped_lock lock{mtx}; + if (!q) + { + q.emplace(); + } + return *q; } // For internal use only @@ -77,8 +124,9 @@ class device_policy return ::std::true_type{}; } - private: - sycl::queue q; + protected: + mutable ::std::mutex mtx; + mutable ::std::optional q; }; #if _ONEDPL_FPGA_DEVICE @@ -91,21 +139,31 @@ class fpga_policy : public device_policy public: static constexpr unsigned int unroll_factor = factor; - fpga_policy() - : base(sycl::queue( -# if _ONEDPL_FPGA_EMU - __dpl_sycl::__fpga_emulator_selector() -# else - __dpl_sycl::__fpga_selector() -# endif // _ONEDPL_FPGA_EMU - )) + fpga_policy() = default; + template + fpga_policy(const fpga_policy& other) : base(other.queue()) { } - - template - fpga_policy(const fpga_policy& other) : base(other.queue()){}; explicit fpga_policy(sycl::queue q) : base(q) {} explicit fpga_policy(sycl::device d) : base(d) {} + + operator sycl::queue() const { return queue(); } + sycl::queue + queue() const + { + ::std::scoped_lock lock{this->mtx}; + if (!this->q) + { + this->q.emplace( +# if _ONEDPL_FPGA_EMU + __dpl_sycl::__fpga_emulator_selector() +# else + __dpl_sycl::__fpga_selector() +# endif // _ONEDPL_FPGA_EMU + ); + } + return *this->q; + } }; #endif // _ONEDPL_FPGA_DEVICE