Skip to content

Commit

Permalink
Merge pull request #1860 from PietroGhg/pietro/fill
Browse files Browse the repository at this point in the history
[NATIVECPU] Fix pointer arithmetic in USMfill
  • Loading branch information
omarahmed1111 authored Aug 8, 2024
2 parents ab9baf5 + 8fb6824 commit 83f7ad9
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions source/adapters/native_cpu/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,8 +511,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(

UR_ASSERT(ptr, UR_RESULT_ERROR_INVALID_NULL_POINTER);
UR_ASSERT(pPattern, UR_RESULT_ERROR_INVALID_NULL_POINTER);
UR_ASSERT(size % patternSize == 0 || patternSize > size,
UR_RESULT_ERROR_INVALID_SIZE);
UR_ASSERT(patternSize != 0, UR_RESULT_ERROR_INVALID_SIZE)
UR_ASSERT(size != 0, UR_RESULT_ERROR_INVALID_SIZE)
UR_ASSERT(patternSize < size, UR_RESULT_ERROR_INVALID_SIZE)
UR_ASSERT(size % patternSize == 0, UR_RESULT_ERROR_INVALID_SIZE)
// TODO: add check for allocation size once the query is supported

switch (patternSize) {
case 1:
Expand All @@ -522,33 +525,34 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
const auto pattern = *static_cast<const uint16_t *>(pPattern);
auto *start = reinterpret_cast<uint16_t *>(ptr);
auto *end =
reinterpret_cast<uint16_t *>(reinterpret_cast<uint16_t *>(ptr) + size);
reinterpret_cast<uint16_t *>(reinterpret_cast<uint8_t *>(ptr) + size);
std::fill(start, end, pattern);
break;
}
case 4: {
const auto pattern = *static_cast<const uint32_t *>(pPattern);
auto *start = reinterpret_cast<uint32_t *>(ptr);
auto *end =
reinterpret_cast<uint32_t *>(reinterpret_cast<uint32_t *>(ptr) + size);
reinterpret_cast<uint32_t *>(reinterpret_cast<uint8_t *>(ptr) + size);
std::fill(start, end, pattern);
break;
}
case 8: {
const auto pattern = *static_cast<const uint64_t *>(pPattern);
auto *start = reinterpret_cast<uint64_t *>(ptr);
auto *end =
reinterpret_cast<uint64_t *>(reinterpret_cast<uint64_t *>(ptr) + size);
reinterpret_cast<uint64_t *>(reinterpret_cast<uint8_t *>(ptr) + size);
std::fill(start, end, pattern);
break;
}
default:
for (unsigned int step{0}; step < size; ++step) {
auto *dest = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(ptr) +
step * patternSize);
default: {
for (unsigned int step{0}; step < size; step += patternSize) {
auto *dest =
reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(ptr) + step);
memcpy(dest, pPattern, patternSize);
}
}
}
return UR_RESULT_SUCCESS;
}

Expand Down Expand Up @@ -583,7 +587,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
std::ignore = phEventWaitList;
std::ignore = phEvent;

DIE_NO_IMPLEMENTATION;
// TODO: properly implement USM prefetch
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
Expand All @@ -595,7 +600,8 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
std::ignore = advice;
std::ignore = phEvent;

DIE_NO_IMPLEMENTATION;
// TODO: properly implement USM advise
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D(
Expand Down

0 comments on commit 83f7ad9

Please sign in to comment.