Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moves USMNdArrayType into numba_dpex/core. #851

Merged
merged 2 commits into from
Dec 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion numba_dpex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
"""
import numba.testing

# Re-export types itself
import numba_dpex.core.types as types

# Re-export all type names
from numba_dpex.core.types import *
from numba_dpex.interop import asarray
from numba_dpex.retarget import offload_to_sycl_device

Expand All @@ -23,4 +28,4 @@
__version__ = get_versions()["version"]
del get_versions

__all__ = ["offload_to_sycl_device"]
__all__ = ["offload_to_sycl_device"] + types.__all__
28 changes: 15 additions & 13 deletions numba_dpex/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

from numba_dpex import config
from numba_dpex.core.exceptions import KernelHasReturnValueError
from numba_dpex.core.types import Array
from numba_dpex.dpctl_iface import USMNdArrayType
from numba_dpex.core.types import Array, USMNdArray
from numba_dpex.dpctl_support import dpctl_version
from numba_dpex.parfor_diagnostics import ExtendedParforDiagnostics
from numba_dpex.utils import (
Expand Down Expand Up @@ -181,10 +180,13 @@ def compile_with_depx(pyfunc, return_type, args, is_kernel, debug=None):
def compile_kernel(sycl_queue, pyfunc, args, access_types, debug=None):
# For any array we only accept numba_dpex.types.Array
for arg in args:
if isinstance(arg, types.npytypes.Array) and not isinstance(arg, Array):
if isinstance(arg, types.npytypes.Array) and not (
isinstance(arg, Array) or isinstance(arg, USMNdArray)
):
raise TypeError(
"Only numba_dpex.core.types.Array objects are supported as "
+ "kernel arguments. Received %s" % (type(arg))
"Only numba_dpex.core.types.USMNdArray "
+ "objects are supported as kernel arguments. "
+ "Received %s" % (type(arg))
)

if config.DEBUG:
Expand Down Expand Up @@ -565,7 +567,7 @@ def _unpack_device_array_argument(
for ax in range(ndim):
kernelargs.append(ctypes.c_longlong(strides[ax]))

def _unpack_USMNdArrayType(self, val, kernelargs):
def _unpack_USMNdArray(self, val, kernelargs):
(
usm_mem,
total_size,
Expand Down Expand Up @@ -658,8 +660,8 @@ def _unpack_argument(

device_arrs.append(None)

if isinstance(ty, USMNdArrayType):
self._unpack_USMNdArrayType(val, kernelargs)
if isinstance(ty, USMNdArray):
self._unpack_USMNdArray(val, kernelargs)
elif isinstance(ty, types.Array):
self._unpack_Array(
val, sycl_queue, kernelargs, device_arrs, access_type
Expand Down Expand Up @@ -746,9 +748,9 @@ def _datatype_is_same(self, argtypes):
"""
array_type = None
for i, argtype in enumerate(argtypes):
arg_is_array_type = isinstance(
argtype, USMNdArrayType
) or isinstance(argtype, types.Array)
arg_is_array_type = isinstance(argtype, USMNdArray) or isinstance(
argtype, types.Array
)
if array_type is None and arg_is_array_type:
array_type = argtype
elif (
Expand All @@ -770,13 +772,13 @@ def __call__(self, *args, **kwargs):
if not uniform:
_raise_datatype_mixed_error(argtypes)

if type(array_type) == USMNdArrayType:
if type(array_type) == USMNdArray:
if dpctl.is_in_device_context():
warnings.warn(cfd_ctx_mgr_wrng_msg)

queues = []
for i, argtype in enumerate(argtypes):
if type(argtype) == USMNdArrayType:
if type(argtype) == USMNdArray:
memory = dpctl.memory.as_usm_memory(args[i])
if dpctl_version < (0, 12):
queue = memory._queue
Expand Down
3 changes: 3 additions & 0 deletions numba_dpex/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from .types import *
from .typing import *
3 changes: 3 additions & 0 deletions numba_dpex/core/datamodel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2022 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
66 changes: 66 additions & 0 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# SPDX-FileCopyrightText: 2022 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from numba.core import datamodel, types
from numba.core.datamodel.models import PrimitiveModel, StructModel
from numba.core.extending import register_model

from numba_dpex.core.types import Array, USMNdArray
from numba_dpex.utils import address_space


class GenericPointerModel(PrimitiveModel):
def __init__(self, dmm, fe_type):
adrsp = (
fe_type.addrspace
if fe_type.addrspace is not None
else address_space.GLOBAL
)
be_type = dmm.lookup(fe_type.dtype).get_data_type().as_pointer(adrsp)
super(GenericPointerModel, self).__init__(dmm, fe_type, be_type)


class ArrayModel(StructModel):
"""A data model to represent a Dpex's array types in LLVM IR.

Dpex's ArrayModel is based on Numba's ArrayModel for NumPy arrays. The
dpex model adds an extra address space attribute to all pointer members
in the array.
"""

def __init__(self, dmm, fe_type):
ndim = fe_type.ndim
members = [
(
"meminfo",
types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace),
),
(
"parent",
types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace),
),
("nitems", types.intp),
("itemsize", types.intp),
(
"data",
types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace),
),
("shape", types.UniTuple(types.intp, ndim)),
("strides", types.UniTuple(types.intp, ndim)),
]
super(ArrayModel, self).__init__(dmm, fe_type, members)


def _init_data_model_manager():
dmm = datamodel.default_manager.copy()
dmm.register(types.CPointer, GenericPointerModel)
dmm.register(Array, ArrayModel)
return dmm


dpex_data_model_manager = _init_data_model_manager()

# Register the USMNdArray type with the dpex ArrayModel
register_model(USMNdArray)(ArrayModel)
dpex_data_model_manager.register(USMNdArray, ArrayModel)
23 changes: 1 addition & 22 deletions numba_dpex/core/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from numba.core.target_extension import GPU, target_registry
from numba.core.utils import cached_property

from numba_dpex.core.types import Array, ArrayModel
from numba_dpex.core.datamodel.models import _init_data_model_manager
from numba_dpex.utils import (
address_space,
calling_conv,
Expand Down Expand Up @@ -95,27 +95,6 @@ def load_additional_registries(self):
self.install_registry(npydecl.registry)


class GenericPointerModel(datamodel.PrimitiveModel):
def __init__(self, dmm, fe_type):
adrsp = (
fe_type.addrspace
if fe_type.addrspace is not None
else address_space.GLOBAL
)
be_type = dmm.lookup(fe_type.dtype).get_data_type().as_pointer(adrsp)
super(GenericPointerModel, self).__init__(dmm, fe_type, be_type)


def _init_data_model_manager():
dmm = datamodel.default_manager.copy()
dmm.register(types.CPointer, GenericPointerModel)
dmm.register(Array, ArrayModel)
return dmm


spirv_data_model_manager = _init_data_model_manager()


class SyclDevice(GPU):
"""Mark the hardware target as SYCL Device."""

Expand Down
51 changes: 50 additions & 1 deletion numba_dpex/core/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,53 @@
#
# SPDX-License-Identifier: Apache-2.0

from .array_type import *
from .array_type import Array
from .numba_types_short_names import (
b1,
bool_,
boolean,
double,
f4,
f8,
float32,
float64,
float_,
i4,
i8,
int32,
int64,
none,
u4,
u8,
uint32,
uint64,
void,
)
from .usm_ndarray_type import USMNdArray

usm_ndarray = USMNdArray

__all__ = [
"Array",
"USMNdArray",
"none",
"boolean",
"bool_",
"uint32",
"uint64",
"int32",
"int64",
"float32",
"float64",
"b1",
"i4",
"i8",
"u4",
"u8",
"f4",
"f8",
"float_",
"double",
"void",
"usm_ndarray",
]
24 changes: 0 additions & 24 deletions numba_dpex/core/types/array_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,3 @@ def box_type(self):

def is_precise(self):
return self.dtype.is_precise()


class ArrayModel(StructModel):
def __init__(self, dmm, fe_type):
ndim = fe_type.ndim
members = [
(
"meminfo",
types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace),
),
(
"parent",
types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace),
),
("nitems", types.intp),
("itemsize", types.intp),
(
"data",
types.CPointer(fe_type.dtype, addrspace=fe_type.addrspace),
),
("shape", types.UniTuple(types.intp, ndim)),
("strides", types.UniTuple(types.intp, ndim)),
]
super(ArrayModel, self).__init__(dmm, fe_type, members)
33 changes: 33 additions & 0 deletions numba_dpex/core/types/numba_types_short_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from numba.core.types import Boolean, Float, Integer, NoneType

# Short names for numba types supported in dpex kernel

none = NoneType("none")

boolean = bool_ = Boolean("bool")

uint32 = Integer("uint32")
uint64 = Integer("uint64")
int32 = Integer("int32")
int64 = Integer("int64")
float32 = Float("float32")
float64 = Float("float64")


# Aliases to NumPy type names

b1 = bool_
i4 = int32
i8 = int64
u4 = uint32
u8 = uint64
f4 = float32
f8 = float64

float_ = float32
double = float64
void = none
Loading