Skip to content

Commit

Permalink
Use nrt api to allocate meminfo object
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed May 14, 2024
1 parent 574daab commit 0364e76
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
32 changes: 22 additions & 10 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,23 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
bool value_is_float,
int64_t value,
const DPCTLSyclQueueRef qref);
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(NRT_api_functions *nrt,
PyObject *ndarrobj,
void *data,
npy_intp nitems,
npy_intp itemsize,
DPCTLSyclQueueRef qref);
static NRT_MemInfo *DPEXRT_MemInfo_alloc(npy_intp size,
static NRT_MemInfo *DPEXRT_MemInfo_alloc(NRT_api_functions *nrt,
npy_intp size,
size_t usm_type,
const DPCTLSyclQueueRef qref);
static void usmndarray_meminfo_dtor(void *ptr, size_t size, void *info);
static PyObject *box_from_arystruct_parent(usmarystruct_t *arystruct,
int ndim,
PyArray_Descr *descr);

static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
static int DPEXRT_sycl_usm_ndarray_from_python(NRT_api_functions *nrt,
PyObject *obj,
usmarystruct_t *arystruct);
static PyObject *
DPEXRT_sycl_usm_ndarray_to_python_acqref(usmarystruct_t *arystruct,
Expand Down Expand Up @@ -336,6 +339,11 @@ NRT_ExternalAllocator_new_for_usm(DPCTLSyclQueueRef qref, size_t usm_type)
static void usmndarray_meminfo_dtor(void *ptr, size_t size, void *info)
{
MemInfoDtorInfo *mi_dtor_info = NULL;
// Warning: we are destructing sycl memory. MI destructor is called
// separately by numba.
DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: Call to "
"usmndarray_meminfo_dtor at %s, line %d\n",
__FILE__, __LINE__));

// Sanity-check to make sure the mi_dtor_info is an actual pointer.
if (!(mi_dtor_info = (MemInfoDtorInfo *)info)) {
Expand Down Expand Up @@ -416,7 +424,8 @@ static MemInfoDtorInfo *MemInfoDtorInfo_new(NRT_MemInfo *mi, PyObject *owner)
* of the dpnp.ndarray was allocated.
* @return {return} A new NRT_MemInfo object
*/
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(NRT_api_functions *nrt,
PyObject *ndarrobj,
void *data,
npy_intp nitems,
npy_intp itemsize,
Expand All @@ -427,8 +436,9 @@ static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
MemInfoDtorInfo *midtor_info = NULL;
DPCTLSyclContextRef cref = NULL;

// Allocate a new NRT_MemInfo object
if (!(mi = (NRT_MemInfo *)malloc(sizeof(NRT_MemInfo)))) {
// Allocate a new NRT_MemInfo object. By passing 0 we are just allocating
// MemInfo and not the `data` that the MemInfo object manages.
if (!(mi = (NRT_MemInfo *)nrt->allocate(0))) {
DPEXRT_DEBUG(drt_debug_print(
"DPEXRT-ERROR: Could not allocate a new NRT_MemInfo "
"object at %s, line %d\n",
Expand Down Expand Up @@ -505,7 +515,8 @@ static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
* @return {return} A new NRT_MemInfo object, NULL if no NRT_MemInfo
* object could be created.
*/
static NRT_MemInfo *DPEXRT_MemInfo_alloc(npy_intp size,
static NRT_MemInfo *DPEXRT_MemInfo_alloc(NRT_api_functions *nrt,
npy_intp size,
size_t usm_type,
const DPCTLSyclQueueRef qref)
{
Expand All @@ -517,7 +528,7 @@ static NRT_MemInfo *DPEXRT_MemInfo_alloc(npy_intp size,
"DPEXRT-DEBUG: Inside DPEXRT_MemInfo_alloc %s, line %d\n", __FILE__,
__LINE__));
// Allocate a new NRT_MemInfo object
if (!(mi = (NRT_MemInfo *)malloc(sizeof(NRT_MemInfo)))) {
if (!(mi = (NRT_MemInfo *)nrt->allocate(0))) {
DPEXRT_DEBUG(drt_debug_print(
"DPEXRT-ERROR: Could not allocate a new NRT_MemInfo object.\n"));
goto error;
Expand Down Expand Up @@ -795,7 +806,8 @@ static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim)
* instance of a dpnp.ndarray
* @return {return} Error code representing success (0) or failure (-1).
*/
static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
static int DPEXRT_sycl_usm_ndarray_from_python(NRT_api_functions *nrt,
PyObject *obj,
usmarystruct_t *arystruct)
{
struct PyUSMArrayObject *arrayobj = NULL;
Expand Down Expand Up @@ -842,7 +854,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
}

if (!(arystruct->meminfo = NRT_MemInfo_new_from_usmndarray(
obj, data, nitems, itemsize, qref)))
nrt, obj, data, nitems, itemsize, qref)))
{
DPEXRT_DEBUG(drt_debug_print(
"DPEXRT-ERROR: NRT_MemInfo_new_from_usmndarray failed "
Expand Down
12 changes: 8 additions & 4 deletions numba_dpex/core/runtime/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,14 @@ def meminfo_alloc_unchecked(self, builder, size, usm_type, queue_ref):
mod = builder.module
u64 = llvmir.IntType(64)
fnty = llvmir.FunctionType(
cgutils.voidptr_t, [cgutils.intp_t, u64, cgutils.voidptr_t]
cgutils.voidptr_t,
[cgutils.voidptr_t, cgutils.intp_t, u64, cgutils.voidptr_t],
)
fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_alloc")
fn.return_value.add_attribute("noalias")
nrt_api = self._context.nrt.get_nrt_api(builder)

ret = builder.call(fn, [size, usm_type, queue_ref])
ret = builder.call(fn, [nrt_api, size, usm_type, queue_ref])

return ret

Expand Down Expand Up @@ -168,13 +170,15 @@ def arraystruct_from_python(self, pyapi, obj, ptr):
"""
fnty = llvmir.FunctionType(
llvmir.IntType(32), [pyapi.pyobj, pyapi.voidptr]
llvmir.IntType(32), [pyapi.voidptr, pyapi.pyobj, pyapi.voidptr]
)
nrt_api = self._context.nrt.get_nrt_api(pyapi.builder)
fn = pyapi._get_function(fnty, "DPEXRT_sycl_usm_ndarray_from_python")
fn.args[0].add_attribute("nocapture")
fn.args[1].add_attribute("nocapture")
fn.args[2].add_attribute("nocapture")

self.error = pyapi.builder.call(fn, (obj, ptr))
self.error = pyapi.builder.call(fn, (nrt_api, obj, ptr))

return self.error

Expand Down

0 comments on commit 0364e76

Please sign in to comment.