diff --git a/clang/runtime/dpct-rt/include/dpct/memory.hpp b/clang/runtime/dpct-rt/include/dpct/memory.hpp index ce8aa699cc81..be5950d61d72 100644 --- a/clang/runtime/dpct-rt/include/dpct/memory.hpp +++ b/clang/runtime/dpct-rt/include/dpct/memory.hpp @@ -915,7 +915,7 @@ static buffer_t get_buffer(const void *ptr) { } /// A wrapper class contains an accessor and an offset. -template class access_wrapper { sycl::accessor accessor; @@ -931,11 +931,17 @@ class access_wrapper { auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); offset = (byte_t *)ptr - alloc.alloc_ptr; } + template + access_wrapper( + PtrT ptr, sycl::handler &cgh, + typename std::enable_if_t>, void *>> * = 0) + : access_wrapper((const void *)ptr, cgh) {} /// Get the device pointer. /// /// \returns a device pointer with offset. - dataT get_raw_pointer() const { return (dataT)(&accessor[0] + offset); } + PtrT get_raw_pointer() const { return (PtrT)(&accessor[0] + offset); } }; /// Get the accessor for memory pointed by \p ptr. @@ -944,12 +950,17 @@ class access_wrapper { /// If NULL is passed as an argument, an exception will be thrown. /// \param cgh The command group handler. /// \returns an accessor. -template -static sycl::accessor -get_access(const void *ptr, sycl::handler &cgh) { +template +static auto get_access(const T *ptr, sycl::handler &cgh) { if (ptr) { auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); - return alloc.buffer.get_access(cgh); + if constexpr (std::is_same_v, void>) + return alloc.buffer.template get_access(cgh); + else + return alloc.buffer + .template reinterpret(sycl::range<1>(alloc.size / sizeof(T))) + .template get_access(cgh); } else { throw std::runtime_error( "NULL pointer argument in get_access function is invalid");