diff --git a/numba_dpex/__init__.py b/numba_dpex/__init__.py index 661aeb218b..073f57a1ff 100644 --- a/numba_dpex/__init__.py +++ b/numba_dpex/__init__.py @@ -5,33 +5,130 @@ """ The numba-dpex extension module adds data-parallel offload support to Numba. """ +import glob +import logging +import os +import platform as plt -import numba_dpex.core.dpjit_dispatcher -import numba_dpex.core.offload_dispatcher +import dpctl +import llvmlite.binding as ll +import numba +from numba.core import ir_utils +from numba.np import arrayobj +from numba.np.ufunc import array_exprs +from numba.np.ufunc.decorators import Vectorize + +from numba_dpex._patches import _empty_nd_impl, _is_ufunc, _mk_alloc +from numba_dpex.vectorizers import Vectorize as DpexVectorize + +# Monkey patches +array_exprs._is_ufunc = _is_ufunc +ir_utils.mk_alloc = _mk_alloc +arrayobj._empty_nd_impl = _empty_nd_impl + + +def load_dpctl_sycl_interface(): + """Permanently loads the ``DPCTLSyclInterface`` library provided by dpctl. + The ``DPCTLSyclInterface`` library provides C wrappers over SYCL functions + that are directly invoked from the LLVM modules generated by numba_dpex. + We load the library once at the time of initialization using llvmlite's + load_library_permanently function. + Raises: + ImportError: If the ``DPCTLSyclInterface`` library could not be loaded. + """ + + platform = plt.system() + if platform == "Windows": + paths = glob.glob( + os.path.join( + os.path.dirname(dpctl.__file__), "*DPCTLSyclInterface.dll" + ) + ) + else: + paths = glob.glob( + os.path.join( + os.path.dirname(dpctl.__file__), "*DPCTLSyclInterface.so.0" + ) + ) + + if len(paths) == 1: + ll.load_library_permanently(paths[0]) + else: + raise ImportError + + Vectorize.target_registry.ondemand["dpex"] = lambda: DpexVectorize + + +numba_version = tuple(map(int, numba.__version__.split(".")[:3])) +if numba_version < (0, 56, 4): + logging.warning( + "numba_dpex needs numba 0.56.4, using " + f"numba={numba_version} may cause unexpected behavior" + ) + + +dpctl_version = tuple(map(int, dpctl.__version__.split(".")[:2])) +if dpctl_version < (0, 14): + logging.warning( + "numba_dpex needs dpctl 0.14 or greater, using " + f"dpctl={dpctl_version} may cause unexpected behavior" + ) + + +import numba_dpex.core.dpjit_dispatcher # noqa E402 +import numba_dpex.core.offload_dispatcher # noqa E402 # Initialize the _dpexrt_python extension -import numba_dpex.core.runtime -import numba_dpex.core.targets.dpjit_target +import numba_dpex.core.runtime # noqa E402 +import numba_dpex.core.targets.dpjit_target # noqa E402 # Re-export types itself -import numba_dpex.core.types as types -from numba_dpex.core.kernel_interface.indexers import NdRange, Range +import numba_dpex.core.types as types # noqa E402 +from numba_dpex import config # noqa E402 +from numba_dpex.core.kernel_interface.indexers import ( # noqa E402 + NdRange, + Range, +) # Re-export all type names -from numba_dpex.core.types import * -from numba_dpex.retarget import offload_to_sycl_device - -from . import config +from numba_dpex.core.types import * # noqa E402 +from numba_dpex.retarget import offload_to_sycl_device # noqa E402 if config.HAS_NON_HOST_DEVICE: - from .device_init import * + # Re export + from .core.targets import dpjit_target, kernel_target + from .decorators import dpjit, func, kernel + + # We are importing dpnp stub module to make Numba recognize the + # module when we rename Numpy functions. + from .dpnp_iface.stubs import dpnp + from .ocl.stubs import ( + GLOBAL_MEM_FENCE, + LOCAL_MEM_FENCE, + atomic, + barrier, + get_global_id, + get_global_size, + get_group_id, + get_local_id, + get_local_size, + get_num_groups, + get_work_dim, + local, + mem_fence, + private, + sub_group_barrier, + ) + + DEFAULT_LOCAL_SIZE = [] + load_dpctl_sycl_interface() + del load_dpctl_sycl_interface else: raise ImportError("No non-host SYCL device found to execute kernels.") - -from ._version import get_versions +from numba_dpex._version import get_versions # noqa E402 __version__ = get_versions()["version"] del get_versions -__all__ = ["offload_to_sycl_device"] + types.__all__ + ["Range", "NdRange"] +__all__ = types.__all__ + ["offload_to_sycl_device"] + ["Range", "NdRange"] diff --git a/numba_dpex/_patches.py b/numba_dpex/_patches.py new file mode 100644 index 0000000000..0c74df084e --- /dev/null +++ b/numba_dpex/_patches.py @@ -0,0 +1,329 @@ +# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import numpy +from llvmlite import ir as llvmir +from llvmlite.ir import Constant +from numba.core import cgutils +from numba.core import config as numba_config +from numba.core import ir, types +from numba.core.ir_utils import ( + convert_size_to_var, + get_np_ufunc_typ, + mk_unique_var, +) +from numba.core.typing import signature +from numba.extending import intrinsic, overload_classmethod +from numba.np.arrayobj import ( + _call_allocator, + get_itemsize, + make_array, + populate_array, +) +from numba.np.ufunc.dufunc import DUFunc + +from numba_dpex.core.runtime import context as dpexrt +from numba_dpex.core.types import DpnpNdArray + +# Numpy array constructors + + +def _is_ufunc(func): + return isinstance(func, (numpy.ufunc, DUFunc)) or hasattr( + func, "is_dpnp_ufunc" + ) + + +def _mk_alloc( + typingctx, typemap, calltypes, lhs, size_var, dtype, scope, loc, lhs_typ +): + """generate an array allocation with np.empty() and return list of nodes. + size_var can be an int variable or tuple of int variables. + lhs_typ is the type of the array being allocated. + """ + out = [] + ndims = 1 + size_typ = types.intp + if isinstance(size_var, tuple): + if len(size_var) == 1: + size_var = size_var[0] + size_var = convert_size_to_var(size_var, typemap, scope, loc, out) + else: + # tuple_var = build_tuple([size_var...]) + ndims = len(size_var) + tuple_var = ir.Var(scope, mk_unique_var("$tuple_var"), loc) + if typemap: + typemap[tuple_var.name] = types.containers.UniTuple( + types.intp, ndims + ) + # constant sizes need to be assigned to vars + new_sizes = [ + convert_size_to_var(s, typemap, scope, loc, out) + for s in size_var + ] + tuple_call = ir.Expr.build_tuple(new_sizes, loc) + tuple_assign = ir.Assign(tuple_call, tuple_var, loc) + out.append(tuple_assign) + size_var = tuple_var + size_typ = types.containers.UniTuple(types.intp, ndims) + + if hasattr(lhs_typ, "__allocate__"): + return lhs_typ.__allocate__( + typingctx, + typemap, + calltypes, + lhs, + size_var, + dtype, + scope, + loc, + lhs_typ, + size_typ, + out, + ) + + # g_np_var = Global(numpy) + g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc) + if typemap: + typemap[g_np_var.name] = types.misc.Module(numpy) + g_np = ir.Global("np", numpy, loc) + g_np_assign = ir.Assign(g_np, g_np_var, loc) + # attr call: empty_attr = getattr(g_np_var, empty) + empty_attr_call = ir.Expr.getattr(g_np_var, "empty", loc) + attr_var = ir.Var(scope, mk_unique_var("$empty_attr_attr"), loc) + if typemap: + typemap[attr_var.name] = get_np_ufunc_typ(numpy.empty) + attr_assign = ir.Assign(empty_attr_call, attr_var, loc) + # Assume str(dtype) returns a valid type + dtype_str = str(dtype) + # alloc call: lhs = empty_attr(size_var, typ_var) + typ_var = ir.Var(scope, mk_unique_var("$np_typ_var"), loc) + if typemap: + typemap[typ_var.name] = types.functions.NumberClass(dtype) + # If dtype is a datetime/timedelta with a unit, + # then it won't return a valid type and instead can be created + # with a string. i.e. "datetime64[ns]") + if ( + isinstance(dtype, (types.NPDatetime, types.NPTimedelta)) + and dtype.unit != "" + ): + typename_const = ir.Const(dtype_str, loc) + typ_var_assign = ir.Assign(typename_const, typ_var, loc) + else: + if dtype_str == "bool": + # empty doesn't like 'bool' sometimes (e.g. kmeans example) + dtype_str = "bool_" + np_typ_getattr = ir.Expr.getattr(g_np_var, dtype_str, loc) + typ_var_assign = ir.Assign(np_typ_getattr, typ_var, loc) + alloc_call = ir.Expr.call(attr_var, [size_var, typ_var], (), loc) + + if calltypes: + cac = typemap[attr_var.name].get_call_type( + typingctx, [size_typ, types.functions.NumberClass(dtype)], {} + ) + # By default, all calls to "empty" are typed as returning a standard + # NumPy ndarray. If we are allocating a ndarray subclass here then + # just change the return type to be that of the subclass. + cac._return_type = ( + lhs_typ.copy(layout="C") if lhs_typ.layout == "F" else lhs_typ + ) + calltypes[alloc_call] = cac + if lhs_typ.layout == "F": + empty_c_typ = lhs_typ.copy(layout="C") + empty_c_var = ir.Var(scope, mk_unique_var("$empty_c_var"), loc) + if typemap: + typemap[empty_c_var.name] = lhs_typ.copy(layout="C") + empty_c_assign = ir.Assign(alloc_call, empty_c_var, loc) + + # attr call: asfortranarray = getattr(g_np_var, asfortranarray) + asfortranarray_attr_call = ir.Expr.getattr( + g_np_var, "asfortranarray", loc + ) + afa_attr_var = ir.Var( + scope, mk_unique_var("$asfortran_array_attr"), loc + ) + if typemap: + typemap[afa_attr_var.name] = get_np_ufunc_typ(numpy.asfortranarray) + afa_attr_assign = ir.Assign(asfortranarray_attr_call, afa_attr_var, loc) + # call asfortranarray + asfortranarray_call = ir.Expr.call(afa_attr_var, [empty_c_var], (), loc) + if calltypes: + calltypes[asfortranarray_call] = typemap[ + afa_attr_var.name + ].get_call_type(typingctx, [empty_c_typ], {}) + + asfortranarray_assign = ir.Assign(asfortranarray_call, lhs, loc) + + out.extend( + [ + g_np_assign, + attr_assign, + typ_var_assign, + empty_c_assign, + afa_attr_assign, + asfortranarray_assign, + ] + ) + else: + alloc_assign = ir.Assign(alloc_call, lhs, loc) + out.extend([g_np_assign, attr_assign, typ_var_assign, alloc_assign]) + + return out + + +def _empty_nd_impl(context, builder, arrtype, shapes): + """Utility function used for allocating a new array during LLVM code + generation (lowering). Given a target context, builder, array + type, and a tuple or list of lowered dimension sizes, returns a + LLVM value pointing at a Numba runtime allocated array. + """ + + arycls = make_array(arrtype) + ary = arycls(context, builder) + + datatype = context.get_data_type(arrtype.dtype) + itemsize = context.get_constant(types.intp, get_itemsize(context, arrtype)) + + # compute array length + arrlen = context.get_constant(types.intp, 1) + overflow = Constant(llvmir.IntType(1), 0) + for s in shapes: + arrlen_mult = builder.smul_with_overflow(arrlen, s) + arrlen = builder.extract_value(arrlen_mult, 0) + overflow = builder.or_(overflow, builder.extract_value(arrlen_mult, 1)) + + if arrtype.ndim == 0: + strides = () + elif arrtype.layout == "C": + strides = [itemsize] + for dimension_size in reversed(shapes[1:]): + strides.append(builder.mul(strides[-1], dimension_size)) + strides = tuple(reversed(strides)) + elif arrtype.layout == "F": + strides = [itemsize] + for dimension_size in shapes[:-1]: + strides.append(builder.mul(strides[-1], dimension_size)) + strides = tuple(strides) + else: + raise NotImplementedError( + "Don't know how to allocate array with layout '{0}'.".format( + arrtype.layout + ) + ) + + # Check overflow, numpy also does this after checking order + allocsize_mult = builder.smul_with_overflow(arrlen, itemsize) + allocsize = builder.extract_value(allocsize_mult, 0) + overflow = builder.or_(overflow, builder.extract_value(allocsize_mult, 1)) + + with builder.if_then(overflow, likely=False): + # Raise same error as numpy, see: + # https://github.com/numpy/numpy/blob/2a488fe76a0f732dc418d03b452caace161673da/numpy/core/src/multiarray/ctors.c#L1095-L1101 # noqa: E501 + context.call_conv.return_user_exc( + builder, + ValueError, + ( + "array is too big; `arr.size * arr.dtype.itemsize` is larger than" + " the maximum possible size.", + ), + ) + + if isinstance(arrtype, DpnpNdArray): + usm_ty = arrtype.usm_type + usm_ty_val = 0 + if usm_ty == "device": + usm_ty_val = 1 + elif usm_ty == "shared": + usm_ty_val = 2 + elif usm_ty == "host": + usm_ty_val = 3 + usm_type = context.get_constant(types.uint64, usm_ty_val) + device = context.insert_const_string(builder.module, arrtype.device) + + args = ( + context.get_dummy_value(), + allocsize, + usm_type, + device, + ) + mip = types.MemInfoPointer(types.voidptr) + arytypeclass = types.TypeRef(type(arrtype)) + sig = signature( + mip, + arytypeclass, + types.intp, + types.uint64, + types.voidptr, + ) + from numba_dpex.decorators import dpjit + + numba_config.DISABLE_PERFORMANCE_WARNINGS = 0 + op = dpjit(_call_usm_allocator) + fnop = context.typing_context.resolve_value_type(op) + # The _call_usm_allocator function will be compiled and added to registry + # when the get_call_type function is invoked. + fnop.get_call_type(context.typing_context, sig.args, {}) + numba_config.DISABLE_PERFORMANCE_WARNINGS = 1 + eqfn = context.get_function(fnop, sig) + meminfo = eqfn(builder, args) + else: + dtype = arrtype.dtype + align_val = context.get_preferred_array_alignment(dtype) + align = context.get_constant(types.uint32, align_val) + args = (context.get_dummy_value(), allocsize, align) + + mip = types.MemInfoPointer(types.voidptr) + arytypeclass = types.TypeRef(type(arrtype)) + argtypes = signature(mip, arytypeclass, types.intp, types.uint32) + + meminfo = context.compile_internal( + builder, _call_allocator, argtypes, args + ) + + data = context.nrt.meminfo_data(builder, meminfo) + + intp_t = context.get_value_type(types.intp) + shape_array = cgutils.pack_array(builder, shapes, ty=intp_t) + strides_array = cgutils.pack_array(builder, strides, ty=intp_t) + + populate_array( + ary, + data=builder.bitcast(data, datatype.as_pointer()), + shape=shape_array, + strides=strides_array, + itemsize=itemsize, + meminfo=meminfo, + ) + + return ary + + +@overload_classmethod(DpnpNdArray, "_usm_allocate") +def _ol_array_allocate(cls, allocsize, usm_type, device): + """Implements an allocator for dpnp.ndarrays.""" + + def impl(cls, allocsize, usm_type, device): + return intrin_usm_alloc(allocsize, usm_type, device) + + return impl + + +def _call_usm_allocator(arrtype, size, usm_type, device): + """Trampoline to call the intrinsic used for allocation""" + return arrtype._usm_allocate(size, usm_type, device) + + +@intrinsic +def intrin_usm_alloc(typingctx, allocsize, usm_type, device): + """Intrinsic to call into the allocator for Array""" + + def codegen(context, builder, signature, args): + [allocsize, usm_type, device] = args + dpexrtCtx = dpexrt.DpexRTContext(context) + meminfo = dpexrtCtx.meminfo_alloc(builder, allocsize, usm_type, device) + return meminfo + + mip = types.MemInfoPointer(types.voidptr) # return untyped pointer + sig = signature(mip, allocsize, usm_type, device) + return sig, codegen diff --git a/numba_dpex/config.py b/numba_dpex/config.py index 4d10ec1554..7b66bf3ba8 100644 --- a/numba_dpex/config.py +++ b/numba_dpex/config.py @@ -5,45 +5,10 @@ import logging import os +import dpctl from numba.core import config -def _ensure_dpctl(): - """ - Make sure dpctl has supported versions. - """ - from numba_dpex.dpctl_support import dpctl_version - - if dpctl_version < (0, 14): - logging.warning( - "numba_dpex needs dpctl 0.14 or greater, using " - f"dpctl={dpctl_version} may cause unexpected behavior" - ) - - -def _dpctl_has_non_host_device(): - """ - Ensure dpctl can create a default sycl device - """ - import dpctl - - try: - dpctl.select_default_device() - return True - except Exception: - msg = "dpctl could not find any non-host SYCL device on the system. " - msg += "A non-host SYCL device is required to use numba_dpex." - logging.exception(msg) - return False - - -_ensure_dpctl() - -# Set this config flag based on if dpctl is found or not. The config flags is -# used elsewhere inside Numba. -HAS_NON_HOST_DEVICE = _dpctl_has_non_host_device() - - def _readenv(name, ctor, default): """Original version from numba/core/config.py class _EnvReloader(): @@ -68,6 +33,17 @@ def __getattr__(name): return getattr(config, name) +# Ensure dpctl can create a default sycl device. +# Set this config flag based on if dpctl is found or not. +# The config flags is used elsewhere inside Numba. +try: + HAS_NON_HOST_DEVICE = dpctl.select_default_device() +except Exception: + logging.exception( + "dpctl could not find any non-host SYCL device on the system. " + + "A non-host SYCL device is required to use numba_dpex." + ) + # To save intermediate files generated by th compiler SAVE_IR_FILES = _readenv("NUMBA_DPEX_SAVE_IR_FILES", int, 0) diff --git a/numba_dpex/device_init.py b/numba_dpex/device_init.py deleted file mode 100644 index 14bb92aeef..0000000000 --- a/numba_dpex/device_init.py +++ /dev/null @@ -1,38 +0,0 @@ -# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation -# -# SPDX-License-Identifier: Apache-2.0 - -# Re export -from .ocl.stubs import ( - GLOBAL_MEM_FENCE, - LOCAL_MEM_FENCE, - atomic, - barrier, - get_global_id, - get_global_size, - get_group_id, - get_local_id, - get_local_size, - get_num_groups, - get_work_dim, - local, - mem_fence, - private, - sub_group_barrier, -) - -""" -We are importing dpnp stub module to make Numba recognize the -module when we rename Numpy functions. -""" -from .dpnp_iface.stubs import dpnp - -DEFAULT_LOCAL_SIZE = [] - -import dpctl - -from . import initialize -from .core.targets import dpjit_target, kernel_target -from .decorators import dpjit, func, kernel - -initialize.load_dpctl_sycl_interface() diff --git a/numba_dpex/dpctl_support.py b/numba_dpex/dpctl_support.py deleted file mode 100644 index 67353df78a..0000000000 --- a/numba_dpex/dpctl_support.py +++ /dev/null @@ -1,21 +0,0 @@ -# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation -# -# SPDX-License-Identifier: Apache-2.0 - -import dpctl - - -def _parse_version(): - t = dpctl.__version__.split(".") - if len(t) > 1: - try: - return tuple(map(int, t[:2])) - except ValueError: - return (0, 0) - else: - return (0, 0) - - -dpctl_version = _parse_version() - -del _parse_version diff --git a/numba_dpex/dpnp_iface/_intrinsic.py b/numba_dpex/dpnp_iface/_intrinsic.py index 5cb70ecb95..3c2ad0ac98 100644 --- a/numba_dpex/dpnp_iface/_intrinsic.py +++ b/numba_dpex/dpnp_iface/_intrinsic.py @@ -2,158 +2,18 @@ # # SPDX-License-Identifier: Apache-2.0 -from llvmlite import ir -from llvmlite.ir import Constant from numba import types -from numba.core import cgutils from numba.core.typing import signature from numba.extending import intrinsic from numba.np.arrayobj import ( + _empty_nd_impl, _parse_empty_args, _parse_empty_like_args, get_itemsize, - make_array, - populate_array, ) from numba_dpex.core.runtime import context as dpexrt -from ..decorators import dpjit - - -@dpjit -# TODO: rename this to _call_allocator and see below -def _call_usm_allocator(arrtype, size, usm_type, device): - """Trampoline to call the intrinsic used for allocation""" - - return arrtype._usm_allocate(size, usm_type, device) - - -def _empty_nd_impl(context, builder, arrtype, shapes): - """Utility function used for allocating a new array during LLVM - code generation (lowering). - - Given a target context, builder, array type, and a tuple or list - of lowered dimension sizes, returns a LLVM value pointing at a - Numba runtime allocated array. - - Args: - context (numba.core.base.BaseContext): One of the class derived - from numba's BaseContext, e.g. CPUContext - builder (llvmlite.ir.builder.IRBuilder): IR builder object from - llvmlite. - arrtype (numba_dpex.core.types.dpnp_ndarray_type.DpnpNdArray): - An array type info to construct the actual array. - shapes (list): The dimension of the array. - - Raises: - NotImplementedError: If the layout of the array is not known. - - Returns: - numba.np.arrayobj.make_array..ArrayStruct: The constructed - array. - """ - - arycls = make_array(arrtype) - ary = arycls(context, builder) - - datatype = context.get_data_type(arrtype.dtype) - itemsize = context.get_constant(types.intp, get_itemsize(context, arrtype)) - - # compute array length - arrlen = context.get_constant(types.intp, 1) - overflow = Constant(ir.IntType(1), 0) - for s in shapes: - arrlen_mult = builder.smul_with_overflow(arrlen, s) - arrlen = builder.extract_value(arrlen_mult, 0) - overflow = builder.or_(overflow, builder.extract_value(arrlen_mult, 1)) - - if arrtype.ndim == 0: - strides = () - elif arrtype.layout == "C": - strides = [itemsize] - for dimension_size in reversed(shapes[1:]): - strides.append(builder.mul(strides[-1], dimension_size)) - strides = tuple(reversed(strides)) - elif arrtype.layout == "F": - strides = [itemsize] - for dimension_size in shapes[:-1]: - strides.append(builder.mul(strides[-1], dimension_size)) - strides = tuple(strides) - else: - raise NotImplementedError( - "Don't know how to allocate array with layout '{0}'.".format( - arrtype.layout - ) - ) - - # Check overflow, numpy also does this after checking order - allocsize_mult = builder.smul_with_overflow(arrlen, itemsize) - allocsize = builder.extract_value(allocsize_mult, 0) - overflow = builder.or_(overflow, builder.extract_value(allocsize_mult, 1)) - - with builder.if_then(overflow, likely=False): - # Raise same error as numpy, see: - # https://github.com/numpy/numpy/blob/2a488fe76a0f732dc418d03b452caace161673da/numpy/core/src/multiarray/ctors.c#L1095-L1101 # noqa: E501 - context.call_conv.return_user_exc( - builder, - ValueError, - ( - "array is too big; `arr.size * arr.dtype.itemsize` is larger " - "than the maximum possible size.", - ), - ) - - usm_ty = arrtype.usm_type - usm_ty_val = 0 - if usm_ty == "device": - usm_ty_val = 1 - elif usm_ty == "shared": - usm_ty_val = 2 - elif usm_ty == "host": - usm_ty_val = 3 - usm_type = context.get_constant(types.uint64, usm_ty_val) - device = context.insert_const_string(builder.module, arrtype.device) - - args = ( - context.get_dummy_value(), - allocsize, - usm_type, - device, - ) - mip = types.MemInfoPointer(types.voidptr) - arytypeclass = types.TypeRef(type(arrtype)) - sig = signature( - mip, - arytypeclass, - types.intp, - types.uint64, - types.voidptr, - ) - - op = _call_usm_allocator - fnop = context.typing_context.resolve_value_type(op) - # The _call_usm_allocator function will be compiled and added to registry - # when the get_call_type function is invoked. - fnop.get_call_type(context.typing_context, sig.args, {}) - eqfn = context.get_function(fnop, sig) - meminfo = eqfn(builder, args) - data = context.nrt.meminfo_data(builder, meminfo) - intp_t = context.get_value_type(types.intp) - shape_array = cgutils.pack_array(builder, shapes, ty=intp_t) - strides_array = cgutils.pack_array(builder, strides, ty=intp_t) - - populate_array( - ary, - data=builder.bitcast(data, datatype.as_pointer()), - shape=shape_array, - strides=strides_array, - itemsize=itemsize, - meminfo=meminfo, - ) - - return ary - def alloc_empty_arrayobj(context, builder, sig, llargs, is_like=False): """Construct an empty numba.np.arrayobj.make_array..ArrayStruct diff --git a/numba_dpex/initialize.py b/numba_dpex/initialize.py deleted file mode 100644 index f08da26c6b..0000000000 --- a/numba_dpex/initialize.py +++ /dev/null @@ -1,51 +0,0 @@ -# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation -# -# SPDX-License-Identifier: Apache-2.0 - -import os - -import llvmlite.binding as ll -from numba.np.ufunc.decorators import Vectorize - -from numba_dpex.vectorizers import Vectorize as DpexVectorize - - -def load_dpctl_sycl_interface(): - """Permanently loads the ``DPCTLSyclInterface`` library provided by dpctl. - - The ``DPCTLSyclInterface`` library provides C wrappers over SYCL functions - that are directly invoked from the LLVM modules generated by numba_dpex. - We load the library once at the time of initialization using llvmlite's - load_library_permanently function. - - Raises: - ImportError: If the ``DPCTLSyclInterface`` library could not be loaded. - """ - import glob - import platform as plt - - import dpctl - - platform = plt.system() - if platform == "Windows": - paths = glob.glob( - os.path.join( - os.path.dirname(dpctl.__file__), "*DPCTLSyclInterface.dll" - ) - ) - else: - paths = glob.glob( - os.path.join( - os.path.dirname(dpctl.__file__), "*DPCTLSyclInterface.so.0" - ) - ) - - if len(paths) == 1: - ll.load_library_permanently(paths[0]) - else: - raise ImportError - - def init_dpex_vectorize(): - return DpexVectorize - - Vectorize.target_registry.ondemand["dpex"] = init_dpex_vectorize diff --git a/numba_dpex/numba_support.py b/numba_dpex/numba_support.py deleted file mode 100644 index dabe87cc7a..0000000000 --- a/numba_dpex/numba_support.py +++ /dev/null @@ -1,7 +0,0 @@ -# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation -# -# SPDX-License-Identifier: Apache-2.0 - -import numba as nb - -numba_version = tuple(map(int, nb.__version__.split(".")[:2])) diff --git a/numba_dpex/tests/_helper.py b/numba_dpex/tests/_helper.py index aae65b15f1..515a4a31d8 100644 --- a/numba_dpex/tests/_helper.py +++ b/numba_dpex/tests/_helper.py @@ -11,8 +11,7 @@ import pytest from numba.tests.support import captured_stdout -from numba_dpex import config -from numba_dpex.numba_support import numba_version +from numba_dpex import config, numba_version def has_opencl_gpu(): diff --git a/numba_dpex/tests/test_dpctl_api.py b/numba_dpex/tests/test_dpctl_api.py index 6768435865..2a9cc8eeb0 100644 --- a/numba_dpex/tests/test_dpctl_api.py +++ b/numba_dpex/tests/test_dpctl_api.py @@ -5,7 +5,7 @@ import dpctl import pytest -from numba_dpex.dpctl_support import dpctl_version +from numba_dpex import dpctl_version from numba_dpex.tests._helper import filter_strings