From e6605c012c511d323f79b8c4717159c20ebecb33 Mon Sep 17 00:00:00 2001 From: "akmkhale@ansatnuc04" Date: Mon, 20 Feb 2023 17:15:02 -0600 Subject: [PATCH] Implementation of dpnp.zeros() and dpnp.ones() interface. It utilzes the overloading construct from numba, also has unit-tests. --- numba_dpex/core/runtime/_dpexrt_python.c | 174 ++++++++++++++++-- numba_dpex/core/runtime/context.py | 53 +++++- numba_dpex/dpnp_iface/arrayobj.py | 162 +++++++++++++++- .../tests/dpjit_tests/dpnp/test_dpnp_ones.py | 55 ++++++ .../tests/dpjit_tests/dpnp/test_dpnp_zeros.py | 55 ++++++ 5 files changed, 476 insertions(+), 23 deletions(-) create mode 100644 numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_ones.py create mode 100644 numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_zeros.py diff --git a/numba_dpex/core/runtime/_dpexrt_python.c b/numba_dpex/core/runtime/_dpexrt_python.c index 0ebf6c98fb..fc13f2954e 100644 --- a/numba_dpex/core/runtime/_dpexrt_python.c +++ b/numba_dpex/core/runtime/_dpexrt_python.c @@ -424,6 +424,123 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device) return NULL; } +/** + * @brief Interface for the core.runtime.context.DpexRTContext.meminfo_alloc. + * This function takes an allocated memory as NRT_MemInfo and fills it with + * the value specified by `value`. + * + * @param mi An NRT_MemInfo object, should be found from memory + * allocation. + * @param itemsize The itemsize, the size of each item in the array. + * @param is_float Flag to specify if the data being float or not. + * @param value The value to be used to fill an array. + * @param device The device on which the memory was allocated. + * @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo + * object could be created. + */ +static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi, + size_t itemsize, + bool is_float, + uint8_t value, + const char *device) +{ + DPCTLSyclDeviceSelectorRef dselector = NULL; + DPCTLSyclDeviceRef dref = NULL; + DPCTLSyclQueueRef qref = NULL; + DPCTLSyclEventRef eref = NULL; + size_t count = 0, size = 0, exp = 0; + + size = mi->size; + while (itemsize >>= 1) + exp++; + count = (unsigned int)(size >> exp); + + NRT_Debug(nrt_debug_print( + "DPEXRT-DEBUG: mi->size = %u, itemsize = %u, count = %u, " + "value = %u, Inside DPEXRT_MemInfo_fill %s, line %d\n", + mi->size, itemsize << exp, count, value, __FILE__, __LINE__)); + + if (mi->data == NULL) { + NRT_Debug(nrt_debug_print("DPEXRT-DEBUG: mi->data is NULL, " + "Inside DPEXRT_MemInfo_fill %s, line %d\n", + __FILE__, __LINE__)); + goto error; + } + + if (!(dselector = DPCTLFilterSelector_Create(device))) { + NRT_Debug(nrt_debug_print( + "DPEXRT-ERROR: Could not create a sycl::device_selector from " + "filter string: %s at %s %d.\n", + device, __FILE__, __LINE__)); + goto error; + } + + if (!(dref = DPCTLDevice_CreateFromSelector(dselector))) + goto error; + + if (!(qref = DPCTLQueue_CreateForDevice(dref, NULL, 0))) + goto error; + + DPCTLDeviceSelector_Delete(dselector); + DPCTLDevice_Delete(dref); + + switch (exp) { + case 3: + { + uint64_t value_assign = (uint64_t)value; + if (is_float) { + double const_val = (double)value; + // To stop warning: dereferencing type-punned pointer + // will break strict-aliasing rules [-Wstrict-aliasing] + double *p = &const_val; + value_assign = *((uint64_t *)(p)); + } + if (!(eref = DPCTLQueue_Fill64(qref, mi->data, value_assign, count))) + goto error; + break; + } + case 2: + { + uint32_t value_assign = (uint32_t)value; + if (is_float) { + float const_val = (float)value; + // To stop warning: dereferencing type-punned pointer + // will break strict-aliasing rules [-Wstrict-aliasing] + float *p = &const_val; + value_assign = *((uint32_t *)(p)); + } + if (!(eref = DPCTLQueue_Fill32(qref, mi->data, value_assign, count))) + goto error; + break; + } + case 1: + if (!(eref = DPCTLQueue_Fill16(qref, mi->data, value, count))) + goto error; + break; + case 0: + if (!(eref = DPCTLQueue_Fill8(qref, mi->data, value, count))) + goto error; + break; + default: + goto error; + } + + DPCTLEvent_Wait(eref); + + DPCTLQueue_Delete(qref); + DPCTLEvent_Delete(eref); + + return mi; + +error: + DPCTLQueue_Delete(qref); + DPCTLEvent_Delete(eref); + DPCTLDeviceSelector_Delete(dselector); + DPCTLDevice_Delete(dref); + + return NULL; +} + /*----------------------------------------------------------------------------*/ /*--------- Helpers to get attributes out of a dpnp.ndarray PyObject ---------*/ /*----------------------------------------------------------------------------*/ @@ -487,12 +604,13 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj, arystruct_t *arystruct) { struct PyUSMArrayObject *arrayobj = NULL; - int i, ndim; + int i = 0, ndim = 0, exp = 0; npy_intp *shape = NULL, *strides = NULL; - npy_intp *p = NULL, nitems, itemsize; + npy_intp *p = NULL, nitems; void *data = NULL; DPCTLSyclQueueRef qref = NULL; PyGILState_STATE gstate; + npy_intp itemsize = 0; // Increment the ref count on obj to prevent CPython from garbage // collecting the array. @@ -546,20 +664,29 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj, p = arystruct->shape_and_strides; + // Calculate the exponent from the arystruct->itemsize as we know + // itemsize is a power of two + while (itemsize >>= 1) + exp++; + for (i = 0; i < ndim; ++i, ++p) *p = shape[i]; - // DPCTL returns a NULL pointer if the array is contiguous + // DPCTL returns a NULL pointer if the array is contiguous. dpctl stores + // strides as number of elements and Numba stores strides as bytes, for + // that reason we are multiplying stride by itemsize when unboxing the + // external array. + // FIXME: Stride computation should check order and adjust how strides are // calculated. Right now strides are assuming that order is C contigous. if (strides) { for (i = 0; i < ndim; ++i, ++p) { - *p = strides[i]; + *p = strides[i] << exp; } } else { for (i = 1; i < ndim; ++i, ++p) { - *p = shape[i]; + *p = shape[i] << exp; } *p = 1; } @@ -598,11 +725,12 @@ static PyObject *box_from_arystruct_parent(arystruct_t *arystruct, int ndim, PyArray_Descr *descr) { - int i; - npy_intp *p; + int i = 0, exp = 0; + npy_intp *p = NULL; npy_intp *shape = NULL, *strides = NULL; PyObject *array = arystruct->parent; struct PyUSMArrayObject *arrayobj = NULL; + npy_intp itemsize = 0; NRT_Debug(nrt_debug_print("DPEXRT-DEBUG: In try_to_return_parent.\n")); @@ -623,9 +751,16 @@ static PyObject *box_from_arystruct_parent(arystruct_t *arystruct, if (shape[i] != *p) return NULL; } - + // Calculate the exponent from the arystruct->itemsize as we know + // itemsize is a power of two + itemsize = arystruct->itemsize; + while (itemsize >>= 1) + exp++; + // dpctl stores strides as number of elements and Numba stores strides as + // bytes, for that reason we are multiplying stride by itemsize when + // unboxing the external array. if (strides) { - if (strides[i] != *p) + if (strides[i] << exp != *p) return NULL; } else { @@ -680,6 +815,8 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct, npy_intp *shape = NULL, *strides = NULL; int typenum = 0; int status = 0; + int exp = 0; + npy_intp itemsize = 0; NRT_Debug(nrt_debug_print( "DPEXRT-DEBUG: In DPEXRT_sycl_usm_ndarray_to_python_acqref.\n")); @@ -750,7 +887,20 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct, } shape = arystruct->shape_and_strides; - strides = shape + ndim; + + // Calculate the exponent from the arystruct->itemsize as we know + // itemsize is a power of two + itemsize = arystruct->itemsize; + while (itemsize >>= 1) + exp++; + + // Numba internally stores strides as bytes and not as elements. Divide + // the stride by itemsize to get number of elements. + for (size_t idx = ndim; idx < 2 * ((size_t)ndim); ++idx) + arystruct->shape_and_strides[idx] = + arystruct->shape_and_strides[idx] >> exp; + strides = (shape + ndim); + typenum = descr->type_num; usm_ndarr_obj = UsmNDArray_MakeFromPtr( ndim, shape, typenum, strides, (DPCTLSyclUSMRef)arystruct->data, @@ -845,6 +995,7 @@ static PyObject *build_c_helpers_dict(void) _declpointer("DPEXRT_sycl_usm_ndarray_to_python_acqref", &DPEXRT_sycl_usm_ndarray_to_python_acqref); _declpointer("DPEXRT_MemInfo_alloc", &DPEXRT_MemInfo_alloc); + _declpointer("DPEXRT_MemInfo_fill", &DPEXRT_MemInfo_fill); _declpointer("NRT_ExternalAllocator_new_for_usm", &NRT_ExternalAllocator_new_for_usm); @@ -895,7 +1046,8 @@ MOD_INIT(_dpexrt_python) PyLong_FromVoidPtr(&DPEXRT_sycl_usm_ndarray_to_python_acqref)); PyModule_AddObject(m, "DPEXRT_MemInfo_alloc", PyLong_FromVoidPtr(&DPEXRT_MemInfo_alloc)); - + PyModule_AddObject(m, "DPEXRT_MemInfo_fill", + PyLong_FromVoidPtr(&DPEXRT_MemInfo_fill)); PyModule_AddObject(m, "c_helpers", build_c_helpers_dict()); return MOD_SUCCESS_VAL(m); } diff --git a/numba_dpex/core/runtime/context.py b/numba_dpex/core/runtime/context.py index 605c35cc05..10fc0489f7 100644 --- a/numba_dpex/core/runtime/context.py +++ b/numba_dpex/core/runtime/context.py @@ -31,6 +31,17 @@ def wrap(self, builder, *args, **kwargs): @_check_null_result def meminfo_alloc(self, builder, size, usm_type, device): + """A wrapped caller for meminfo_alloc_unchecked() with null check.""" + return self.meminfo_alloc_unchecked(builder, size, usm_type, device) + + @_check_null_result + def meminfo_fill(self, builder, meminfo, itemsize, is_float, value, device): + """A wrapped caller for meminfo_fill_unchecked() with null check.""" + return self.meminfo_fill_unchecked( + builder, meminfo, itemsize, is_float, value, device + ) + + def meminfo_alloc_unchecked(self, builder, size, usm_type, device): """Allocate a new MemInfo with a data payload of `size` bytes. The result of the call is checked and if it is NULL, i.e. allocation @@ -49,26 +60,50 @@ def meminfo_alloc(self, builder, size, usm_type, device): Returns: A pointer to the MemInfo is returned. """ + mod = builder.module + u64 = ir.IntType(64) + fnty = ir.FunctionType( + cgutils.voidptr_t, [cgutils.intp_t, u64, cgutils.voidptr_t] + ) + fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_alloc") + fn.return_value.add_attribute("noalias") - return self.meminfo_alloc_unchecked(builder, size, usm_type, device) + ret = builder.call(fn, [size, usm_type, device]) - def meminfo_alloc_unchecked(self, builder, size, usm_type, device): - """ - Allocate a new MemInfo with a data payload of `size` bytes. + return ret - A pointer to the MemInfo is returned. + def meminfo_fill_unchecked( + self, builder, meminfo, itemsize, is_float, value, device + ): + """Fills an allocated `MemInfo` with the value specified. - Returns NULL to indicate error/failure to allocate. + The result of the call is checked and if it is `NULL`, i.e. the fill + operation failed, then a `MemoryError` is raised. If the fill operation + is succeeded then a pointer to the `MemInfo` is returned. + + Args: + builder (llvmlite.ir.builder.IRBuilder): LLVM IR builder + meminfo (llvmlite.ir.instructions.LoadInstr): LLVM uint64 value + specifying the size in bytes for the data payload. + itemsize (llvmlite.ir.values.Constant): An LLVM Constant value + specifying the size of the each data item allocated by the + usm allocator. + device (llvmlite.ir.values.FormattedConstant): An LLVM ArrayType + storing a const string for a DPC++ filter selector string. + + Returns: A pointer to the `MemInfo` is returned. """ mod = builder.module u64 = ir.IntType(64) + b = ir.IntType(1) fnty = ir.FunctionType( - cgutils.voidptr_t, [cgutils.intp_t, u64, cgutils.voidptr_t] + cgutils.voidptr_t, + [cgutils.voidptr_t, u64, b, cgutils.int8_t, cgutils.voidptr_t], ) - fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_alloc") + fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_fill") fn.return_value.add_attribute("noalias") - ret = builder.call(fn, [size, usm_type, device]) + ret = builder.call(fn, [meminfo, itemsize, is_float, value, device]) return ret diff --git a/numba_dpex/dpnp_iface/arrayobj.py b/numba_dpex/dpnp_iface/arrayobj.py index cdfe66e634..5aa019637b 100644 --- a/numba_dpex/dpnp_iface/arrayobj.py +++ b/numba_dpex/dpnp_iface/arrayobj.py @@ -7,6 +7,7 @@ from llvmlite.ir import Constant from numba import errors, types from numba.core import cgutils +from numba.core.types.scalars import Float from numba.core.typing import signature from numba.core.typing.npydecl import parse_dtype as ty_parse_dtype from numba.core.typing.npydecl import parse_shape @@ -218,14 +219,61 @@ def impl_dpnp_empty( sig = ty_retty(ty_shape, ty_dtype, ty_usm_type, ty_device, ty_retty_ref) - def codegen(cgctx, builder, sig, llargs): - arrtype = _parse_empty_args(cgctx, builder, sig, llargs) - ary = _empty_nd_impl(cgctx, builder, *arrtype) + def codegen(context, builder, sig, llargs): + arrtype = _parse_empty_args(context, builder, sig, llargs) + ary = _empty_nd_impl(context, builder, *arrtype) return ary._getvalue() return sig, codegen +def aryobj_fill(context, builder, sig, llargs, value): + arrtype = _parse_empty_args(context, builder, sig, llargs) + ary = _empty_nd_impl(context, builder, *arrtype) + itemsize = context.get_constant( + types.intp, get_itemsize(context, arrtype[0]) + ) + device = context.insert_const_string(builder.module, arrtype[0].device) + value = context.get_constant(types.int8, value) + if isinstance(arrtype[0].dtype, Float): + is_float = context.get_constant(types.boolean, 1) + else: + is_float = context.get_constant(types.boolean, 0) + dpexrtCtx = dpexrt.DpexRTContext(context) + dpexrtCtx.meminfo_fill( + builder, ary.meminfo, itemsize, is_float, value, device + ) + return ary._getvalue() + + +@intrinsic +def impl_dpnp_zeros( + tyctx, ty_shape, ty_dtype, ty_usm_type, ty_device, ty_retty_ref +): + ty_retty = ty_retty_ref.instance_type + + sig = ty_retty(ty_shape, ty_dtype, ty_usm_type, ty_device, ty_retty_ref) + + def codegen(context, builder, sig, llargs): + return aryobj_fill(context, builder, sig, llargs, 0) + + return sig, codegen + + +@intrinsic +def impl_dpnp_ones( + tyctx, ty_shape, ty_dtype, ty_usm_type, ty_device, ty_retty_ref +): + ty_retty = ty_retty_ref.instance_type + + sig = ty_retty(ty_shape, ty_dtype, ty_usm_type, ty_device, ty_retty_ref) + + def codegen(context, builder, sig, llargs): + return aryobj_fill(context, builder, sig, llargs, 1) + + return sig, codegen + + # ------------------------------------------------------------------------------ # Dpnp array constructor overloads @@ -303,3 +351,111 @@ def impl( f"Cannot parse input types to function dpnp.empty({shape}, {dtype})" ) raise errors.TypingError(msg) + + +@overload(dpnp.zeros, prefer_literal=True) +def ol_dpnp_zeros( + shape, dtype=None, usm_type=None, device=None, sycl_queue=None +): + if sycl_queue: + raise errors.TypingError( + "The sycl_queue keyword is not yet supported by dpnp.empty inside " + "a dpjit decorated function." + ) + + ndim = parse_shape(shape) + if not ndim: + raise errors.TypingError("Could not infer the rank of the ndarray") + + # If a dtype value was passed in, then try to convert it to the + # coresponding Numba type. If None was passed, the default, then pass None + # to the DpnpNdArray constructor. The default dtype will be derived based + # on the behavior defined in dpctl.tensor.usm_ndarray. + if not is_nonelike(dtype): + nb_dtype = ty_parse_dtype(dtype) + else: + nb_dtype = None + + if usm_type is not None: + usm_type = _parse_usm_type(usm_type) + else: + usm_type = "device" + + if device is not None: + device = _parse_device_filter_string(device) + else: + device = "unknown" + + if ndim is not None: + retty = DpnpNdArray( + dtype=nb_dtype, + ndim=ndim, + usm_type=usm_type, + device=device, + ) + + def impl( + shape, dtype=None, usm_type=None, device=None, sycl_queue=None + ): + return impl_dpnp_zeros(shape, dtype, usm_type, device, retty) + + return impl + else: + msg = ( + f"Cannot parse input types to function dpnp.empty({shape}, {dtype})" + ) + raise errors.TypingError(msg) + + +@overload(dpnp.ones, prefer_literal=True) +def ol_dpnp_ones( + shape, dtype=None, usm_type=None, device=None, sycl_queue=None +): + if sycl_queue: + raise errors.TypingError( + "The sycl_queue keyword is not yet supported by dpnp.empty inside " + "a dpjit decorated function." + ) + + ndim = parse_shape(shape) + if not ndim: + raise errors.TypingError("Could not infer the rank of the ndarray") + + # If a dtype value was passed in, then try to convert it to the + # coresponding Numba type. If None was passed, the default, then pass None + # to the DpnpNdArray constructor. The default dtype will be derived based + # on the behavior defined in dpctl.tensor.usm_ndarray. + if not is_nonelike(dtype): + nb_dtype = ty_parse_dtype(dtype) + else: + nb_dtype = None + + if usm_type is not None: + usm_type = _parse_usm_type(usm_type) + else: + usm_type = "device" + + if device is not None: + device = _parse_device_filter_string(device) + else: + device = "unknown" + + if ndim is not None: + retty = DpnpNdArray( + dtype=nb_dtype, + ndim=ndim, + usm_type=usm_type, + device=device, + ) + + def impl( + shape, dtype=None, usm_type=None, device=None, sycl_queue=None + ): + return impl_dpnp_ones(shape, dtype, usm_type, device, retty) + + return impl + else: + msg = ( + f"Cannot parse input types to function dpnp.empty({shape}, {dtype})" + ) + raise errors.TypingError(msg) diff --git a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_ones.py b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_ones.py new file mode 100644 index 0000000000..fd69038728 --- /dev/null +++ b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_ones.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for dpnp ndarray constructors.""" + +import dpctl +import dpctl.tensor as dpt +import dpnp +import numpy +import pytest + +from numba_dpex import dpjit + +shapes = [11, (3, 7)] +dtypes = [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64] +usm_types = ["device", "shared", "host"] +devices = ["cpu", "unknown"] + + +@pytest.mark.parametrize("shape", shapes) +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("usm_type", usm_types) +@pytest.mark.parametrize("device", devices) +def test_dpnp_ones(shape, dtype, usm_type, device): + @dpjit + def func1(shape): + c = dpnp.ones( + shape=shape, dtype=dtype, usm_type=usm_type, device=device + ) + return c + + a = numpy.ones(shape, dtype=dtype) + + try: + c = func1(shape) + except Exception: + pytest.fail("Calling dpnp.empty inside dpjit failed") + + if len(c.shape) == 1: + assert c.shape[0] == shape + else: + assert c.shape == shape + + assert c.dtype == dtype + assert c.usm_type == usm_type + if device != "unknown": + assert ( + c.sycl_device.filter_string + == dpctl.SyclDevice(device).filter_string + ) + else: + c.sycl_device.filter_string == dpctl.SyclDevice().filter_string + + assert numpy.array_equal(dpt.asnumpy(c._array_obj), a) diff --git a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_zeros.py b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_zeros.py new file mode 100644 index 0000000000..854cceb9bc --- /dev/null +++ b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_zeros.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for dpnp ndarray constructors.""" + +import dpctl +import dpctl.tensor as dpt +import dpnp +import numpy +import pytest + +from numba_dpex import dpjit + +shapes = [11, (3, 7)] +dtypes = [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64] +usm_types = ["device", "shared", "host"] +devices = ["cpu", "unknown"] + + +@pytest.mark.parametrize("shape", shapes) +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("usm_type", usm_types) +@pytest.mark.parametrize("device", devices) +def test_dpnp_zeros(shape, dtype, usm_type, device): + @dpjit + def func1(shape): + c = dpnp.zeros( + shape=shape, dtype=dtype, usm_type=usm_type, device=device + ) + return c + + a = numpy.zeros(shape, dtype=dtype) + + try: + c = func1(shape) + except Exception: + pytest.fail("Calling dpnp.empty inside dpjit failed") + + if len(c.shape) == 1: + assert c.shape[0] == shape + else: + assert c.shape == shape + + assert c.dtype == dtype + assert c.usm_type == usm_type + if device != "unknown": + assert ( + c.sycl_device.filter_string + == dpctl.SyclDevice(device).filter_string + ) + else: + c.sycl_device.filter_string == dpctl.SyclDevice().filter_string + + assert numpy.array_equal(dpt.asnumpy(c._array_obj), a)