Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use nrt api to allocate meminfo object #1458

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,23 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
bool value_is_float,
int64_t value,
const DPCTLSyclQueueRef qref);
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(NRT_api_functions *nrt,
PyObject *ndarrobj,
void *data,
npy_intp nitems,
npy_intp itemsize,
DPCTLSyclQueueRef qref);
static NRT_MemInfo *DPEXRT_MemInfo_alloc(npy_intp size,
static NRT_MemInfo *DPEXRT_MemInfo_alloc(NRT_api_functions *nrt,
npy_intp size,
size_t usm_type,
const DPCTLSyclQueueRef qref);
static void usmndarray_meminfo_dtor(void *ptr, size_t size, void *info);
static PyObject *box_from_arystruct_parent(usmarystruct_t *arystruct,
int ndim,
PyArray_Descr *descr);

static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
static int DPEXRT_sycl_usm_ndarray_from_python(NRT_api_functions *nrt,
PyObject *obj,
usmarystruct_t *arystruct);
static PyObject *
DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
Expand Down Expand Up @@ -336,6 +339,11 @@ NRT_ExternalAllocator_new_for_usm(DPCTLSyclQueueRef qref, size_t usm_type)
static void usmndarray_meminfo_dtor(void *ptr, size_t size, void *info)
{
MemInfoDtorInfo *mi_dtor_info = NULL;
// Warning: we are destructing sycl memory. MI destructor is called
// separately by numba.
DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: Call to "
"usmndarray_meminfo_dtor at %s, line %d\n",
__FILE__, __LINE__));

// Sanity-check to make sure the mi_dtor_info is an actual pointer.
if (!(mi_dtor_info = (MemInfoDtorInfo *)info)) {
Expand Down Expand Up @@ -416,7 +424,8 @@ static MemInfoDtorInfo *MemInfoDtorInfo_new(NRT_MemInfo *mi, PyObject *owner)
* of the dpnp.ndarray was allocated.
* @return {return} A new NRT_MemInfo object
*/
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(NRT_api_functions *nrt,
PyObject *ndarrobj,
void *data,
npy_intp nitems,
npy_intp itemsize,
Expand All @@ -427,8 +436,9 @@ static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
MemInfoDtorInfo *midtor_info = NULL;
DPCTLSyclContextRef cref = NULL;

// Allocate a new NRT_MemInfo object
if (!(mi = (NRT_MemInfo *)malloc(sizeof(NRT_MemInfo)))) {
// Allocate a new NRT_MemInfo object. By passing 0 we are just allocating
// MemInfo and not the `data` that the MemInfo object manages.
if (!(mi = (NRT_MemInfo *)nrt->allocate(0))) {
DPEXRT_DEBUG(drt_debug_print(
"DPEXRT-ERROR: Could not allocate a new NRT_MemInfo "
"object at %s, line %d\n",
Expand Down Expand Up @@ -505,7 +515,8 @@ static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
* @return {return} A new NRT_MemInfo object, NULL if no NRT_MemInfo
* object could be created.
*/
static NRT_MemInfo *DPEXRT_MemInfo_alloc(npy_intp size,
static NRT_MemInfo *DPEXRT_MemInfo_alloc(NRT_api_functions *nrt,
npy_intp size,
size_t usm_type,
const DPCTLSyclQueueRef qref)
{
Expand All @@ -517,7 +528,7 @@ static NRT_MemInfo *DPEXRT_MemInfo_alloc(npy_intp size,
"DPEXRT-DEBUG: Inside DPEXRT_MemInfo_alloc %s, line %d\n", __FILE__,
__LINE__));
// Allocate a new NRT_MemInfo object
if (!(mi = (NRT_MemInfo *)malloc(sizeof(NRT_MemInfo)))) {
if (!(mi = (NRT_MemInfo *)nrt->allocate(0))) {
DPEXRT_DEBUG(drt_debug_print(
"DPEXRT-ERROR: Could not allocate a new NRT_MemInfo object.\n"));
goto error;
Expand Down Expand Up @@ -795,7 +806,8 @@ static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim)
* instance of a dpnp.ndarray
* @return {return} Error code representing success (0) or failure (-1).
*/
static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
static int DPEXRT_sycl_usm_ndarray_from_python(NRT_api_functions *nrt,
PyObject *obj,
usmarystruct_t *arystruct)
{
struct PyUSMArrayObject *arrayobj = NULL;
Expand Down Expand Up @@ -842,7 +854,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
}

if (!(arystruct->meminfo = NRT_MemInfo_new_from_usmndarray(
obj, data, nitems, itemsize, qref)))
nrt, obj, data, nitems, itemsize, qref)))
{
DPEXRT_DEBUG(drt_debug_print(
"DPEXRT-ERROR: NRT_MemInfo_new_from_usmndarray failed "
Expand Down
12 changes: 8 additions & 4 deletions numba_dpex/core/runtime/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,14 @@ def meminfo_alloc_unchecked(self, builder, size, usm_type, queue_ref):
mod = builder.module
u64 = llvmir.IntType(64)
fnty = llvmir.FunctionType(
cgutils.voidptr_t, [cgutils.intp_t, u64, cgutils.voidptr_t]
cgutils.voidptr_t,
[cgutils.voidptr_t, cgutils.intp_t, u64, cgutils.voidptr_t],
)
fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_alloc")
fn.return_value.add_attribute("noalias")
nrt_api = self._context.nrt.get_nrt_api(builder)

ret = builder.call(fn, [size, usm_type, queue_ref])
ret = builder.call(fn, [nrt_api, size, usm_type, queue_ref])

return ret

Expand Down Expand Up @@ -168,13 +170,15 @@ def arraystruct_from_python(self, pyapi, obj, ptr):

"""
fnty = llvmir.FunctionType(
llvmir.IntType(32), [pyapi.pyobj, pyapi.voidptr]
llvmir.IntType(32), [pyapi.voidptr, pyapi.pyobj, pyapi.voidptr]
)
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)
fn = pyapi._get_function(fnty, "DPEXRT_sycl_usm_ndarray_from_python")
fn.args[0].add_attribute("nocapture")
fn.args[1].add_attribute("nocapture")
fn.args[2].add_attribute("nocapture")

self.error = pyapi.builder.call(fn, (obj, ptr))
self.error = pyapi.builder.call(fn, (nrt_api, obj, ptr))

return self.error

Expand Down
Loading