diff --git a/sycl/source/detail/usm/usm_impl.cpp b/sycl/source/detail/usm/usm_impl.cpp index ad120734db683..d6acee67293cf 100644 --- a/sycl/source/detail/usm/usm_impl.cpp +++ b/sycl/source/detail/usm/usm_impl.cpp @@ -27,6 +27,8 @@ namespace usm { void *alignedAllocHost(size_t Alignment, size_t Size, const context &Ctxt, alloc Kind) { void *RetVal = nullptr; + if (Size == 0) + return nullptr; if (Ctxt.is_host()) { if (!Alignment) { // worst case default @@ -72,6 +74,8 @@ void *alignedAllocHost(size_t Alignment, size_t Size, const context &Ctxt, void *alignedAlloc(size_t Alignment, size_t Size, const context &Ctxt, const device &Dev, alloc Kind) { void *RetVal = nullptr; + if (Size == 0) + return nullptr; if (Ctxt.is_host()) { if (Kind == alloc::unknown) { RetVal = nullptr; @@ -126,6 +130,8 @@ void *alignedAlloc(size_t Alignment, size_t Size, const context &Ctxt, } void free(void *Ptr, const context &Ctxt) { + if (Ptr == nullptr) + return; if (Ctxt.is_host()) { // need to use alignedFree here for Windows detail::OSUtil::alignedFree(Ptr); diff --git a/sycl/test/regression/usm_malloc_shared.cpp b/sycl/test/regression/usm_malloc_shared.cpp new file mode 100644 index 0000000000000..c078d225791f8 --- /dev/null +++ b/sycl/test/regression/usm_malloc_shared.cpp @@ -0,0 +1,49 @@ +// RUN: %clangxx -fsycl %s -o %t.out +// RUN: env SYCL_DEVICE_TYPE=HOST %t.out +// RUN: %CPU_RUN_PLACEHOLDER %t.out + +// This test checks if users will successfully allocate 160, 0, and -16 bytes of +// shared memory, and also test user can call free() without worrying about +// nullptr or invalid memory descriptor returned from malloc. + +#include +#include +#include +using namespace cl::sycl; + +int main(int argc, char *argv[]) { + auto exception_handler = [](cl::sycl::exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } catch (cl::sycl::exception const &e) { + std::cout << "Caught asynchronous SYCL " + "exception:\n" + << e.what() << std::endl; + } + } + }; + + queue myQueue(default_selector{}, exception_handler); + std::cout << "Device: " << myQueue.get_device().get_info() + << std::endl; + + double *ia = (double *)malloc_shared(160, myQueue); + double *ja = (double *)malloc_shared(0, myQueue); + double *result = (double *)malloc_shared(-16, myQueue); + + assert(ia != nullptr); + assert(ja == nullptr); + assert(result == nullptr); + + std::cout << "ia : " << ia << " ja: " << ja << " result : " << result + << std::endl; + + // followings should not throw CL_INVALID_VALUE + cl::sycl::free(ia, myQueue); + cl::sycl::free(nullptr, myQueue); + cl::sycl::free(ja, myQueue); + cl::sycl::free(result, myQueue); + + return 0; +}