Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
  • Loading branch information
zhiweij1 committed Oct 14, 2024
1 parent cde4bae commit b11e199
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions clang/runtime/dpct-rt/include/dpct/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,7 @@ static buffer_t get_buffer(const void *ptr) {
}

/// A wrapper class contains an accessor and an offset.
template <typename dataT,
template <typename PtrT,
sycl::access_mode accessMode = sycl::access_mode::read_write>
class access_wrapper {
sycl::accessor<byte_t, 1, accessMode> accessor;
Expand All @@ -931,11 +931,17 @@ class access_wrapper {
auto alloc = detail::mem_mgr::instance().translate_ptr(ptr);
offset = (byte_t *)ptr - alloc.alloc_ptr;
}
template <typename U = PtrT>
access_wrapper(
PtrT ptr, sycl::handler &cgh,
typename std::enable_if_t<!std::is_same_v<
std::remove_cv_t<std::remove_reference_t<U>>, void *>> * = 0)
: access_wrapper((const void *)ptr, cgh) {}

/// Get the device pointer.
///
/// \returns a device pointer with offset.
dataT get_raw_pointer() const { return (dataT)(&accessor[0] + offset); }
PtrT get_raw_pointer() const { return (PtrT)(&accessor[0] + offset); }
};

/// Get the accessor for memory pointed by \p ptr.
Expand All @@ -944,12 +950,17 @@ class access_wrapper {
/// If NULL is passed as an argument, an exception will be thrown.
/// \param cgh The command group handler.
/// \returns an accessor.
template <sycl::access_mode accessMode = sycl::access_mode::read_write>
static sycl::accessor<byte_t, 1, accessMode>
get_access(const void *ptr, sycl::handler &cgh) {
template <typename T,
sycl::access_mode accessMode = sycl::access_mode::read_write>
static auto get_access(const T *ptr, sycl::handler &cgh) {
if (ptr) {
auto alloc = detail::mem_mgr::instance().translate_ptr(ptr);
return alloc.buffer.get_access<accessMode>(cgh);
if constexpr (std::is_same_v<std::remove_reference_t<T>, void>)
return alloc.buffer.template get_access<accessMode>(cgh);
else
return alloc.buffer
.template reinterpret<T>(sycl::range<1>(alloc.size / sizeof(T)))
.template get_access<accessMode>(cgh);
} else {
throw std::runtime_error(
"NULL pointer argument in get_access function is invalid");
Expand Down

0 comments on commit b11e199

Please sign in to comment.