Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A sycl::local_accessor-like API for numba-dpex kernel #1331

Merged
merged 12 commits into from
Mar 19, 2024
13 changes: 7 additions & 6 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
)


def _get_flattened_member_count(ty):
"""Return the number of fields in an instance of a given StructModel."""
def get_flattened_member_count(ty):
"""Returns the number of fields in an instance of a given StructModel."""

flattened_member_count = 0
members = ty._members
for member in members:
Expand Down Expand Up @@ -109,7 +110,7 @@ def flattened_field_count(self):
"""
Return the number of fields in an instance of a USMArrayDeviceModel.
"""
return _get_flattened_member_count(self)
return get_flattened_member_count(self)


class USMArrayHostModel(StructModel):
Expand Down Expand Up @@ -143,7 +144,7 @@ def __init__(self, dmm, fe_type):
@property
def flattened_field_count(self):
"""Return the number of fields in an instance of a USMArrayHostModel."""
return _get_flattened_member_count(self)
return get_flattened_member_count(self)


class SyclQueueModel(StructModel):
Expand Down Expand Up @@ -223,7 +224,7 @@ def __init__(self, dmm, fe_type):
@property
def flattened_field_count(self):
"""Return the number of fields in an instance of a RangeModel."""
return _get_flattened_member_count(self)
return get_flattened_member_count(self)


class NdRangeModel(StructModel):
Expand All @@ -246,7 +247,7 @@ def __init__(self, dmm, fe_type):
@property
def flattened_field_count(self):
"""Return the number of fields in an instance of a NdRangeModel."""
return _get_flattened_member_count(self)
return get_flattened_member_count(self)


def _init_data_model_manager() -> datamodel.DataModelManager:
Expand Down
84 changes: 84 additions & 0 deletions numba_dpex/core/types/kernel_api/local_accessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from numba.core import cgutils
from numba.core.types import Type, UniTuple, intp
from numba.extending import NativeValue, unbox
from numba.np import numpy_support

from numba_dpex.core.types import USMNdArray
from numba_dpex.utils import address_space as AddressSpace


class DpctlMDLocalAccessorType(Type):
    """numba-dpex internal type to represent a dpctl SyclInterface type
    `MDLocalAccessorTy`.

    The type takes no parameters: every instance is registered under the
    fixed type name ``"DpctlMDLocalAccessor"``.
    """

    def __init__(self):
        type_name = "DpctlMDLocalAccessor"
        super().__init__(name=type_name)


class LocalAccessorType(USMNdArray):
    """numba-dpex internal type for a Python-level ``LocalAccessor`` object.

    The type is a ``USMNdArray`` specialization that always uses the "C"
    layout and the local address space.
    """

    def __init__(self, ndim, dtype):
        """Create the type for a local accessor.

        Args:
            ndim: Number of dimensions of the accessor.
            dtype: Either a numba ``Type`` (used as is) or a value that
                ``numpy_support.from_dtype`` can convert into one.

        Raises:
            ValueError: If ``dtype`` cannot be converted to a numba type.
        """
        if isinstance(dtype, Type):
            # Already a numba type, no conversion needed.
            parsed_dtype = dtype
        else:
            try:
                parsed_dtype = numpy_support.from_dtype(dtype)
            except NotImplementedError as exc:
                raise ValueError(f"Unsupported array dtype: {dtype}") from exc

        # The name encodes every parameter that distinguishes one accessor
        # type from another.
        type_name = (
            f"LocalAccessor(dtype={parsed_dtype}, ndim={ndim}, "
            f"address_space={AddressSpace.LOCAL})"
        )

        super().__init__(
            ndim=ndim,
            layout="C",
            dtype=parsed_dtype,
            addrspace=AddressSpace.LOCAL,
            name=type_name,
        )

    def cast_python_value(self, args):
        """Not supported for a LocalAccessorType.

        The helper is intentionally not overloaded; calling it always
        raises a NotImplementedError.
        """
        raise NotImplementedError


@unbox(LocalAccessorType)
def unbox_local_accessor(typ, obj, c):  # pylint: disable=unused-argument
    """Unboxes a Python LocalAccessor PyObject* into a numba-dpex internal
    representation.

    A LocalAccessor object is represented internally in numba-dpex with the
    same data model as a numpy.ndarray. It is done as a LocalAccessor object
    serves only as a placeholder type when passed to ``call_kernel`` and the
    data buffer should never be accessed inside a host-side compiled function
    such as ``call_kernel``.

    When a LocalAccessor object is passed as an argument to a kernel function
    it uses the USMArrayDeviceModel. Doing so allows numba-dpex to correctly
    generate the kernel signature passing in a pointer in the local address
    space.

    Only the ``_shape`` attribute of the Python object is read; every other
    field of the native struct is left as created by the struct proxy.

    Args:
        typ: The LocalAccessorType being unboxed.
        obj: The LLVM value holding the PyObject* of the LocalAccessor.
        c: The numba unboxing context (provides ``pyapi``, ``builder``,
            ``context`` and ``unbox``).

    Returns:
        NativeValue: The native struct value, with the error flag and the
        cleanup callback taken from the unboxed shape tuple.
    """
    shape = c.pyapi.object_getattr_string(obj, "_shape")
    local_accessor = cgutils.create_struct_proxy(typ)(c.context, c.builder)

    ty_unituple = UniTuple(intp, typ.ndim)
    ll_shape = c.unbox(ty_unituple, shape)
    local_accessor.shape = ll_shape.value

    # object_getattr_string returns a new reference. The shape values have
    # already been copied into the native tuple, so release the reference
    # here to avoid leaking the Python tuple on every unboxing.
    c.pyapi.decref(shape)

    return NativeValue(
        c.builder.load(local_accessor._getpointer()),
        is_error=ll_shape.is_error,
        cleanup=ll_shape.cleanup,
    )
135 changes: 132 additions & 3 deletions numba_dpex/core/utils/kernel_flattened_args_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,21 @@
object.
"""

from functools import reduce
from math import ceil
from typing import NamedTuple

import dpctl
from llvmlite import ir as llvmir
from numba.core import types
from numba.core import cgutils, types
from numba.core.cpu import CPUContext

from numba_dpex import utils
from numba_dpex.core.types import USMNdArray
from numba_dpex.core.types.kernel_api.local_accessor import (
DpctlMDLocalAccessorType,
LocalAccessorType,
)
from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum


Expand Down Expand Up @@ -70,8 +77,14 @@ def add_argument(
arg_type,
arg_packed_llvm_val,
):
"""Add kernel argument that need to be flatten."""
if isinstance(arg_type, USMNdArray):
"""Add flattened representation of a kernel argument."""
if isinstance(arg_type, LocalAccessorType):
self._kernel_arg_list.extend(
self._build_local_accessor_arg(
arg_type, llvm_val=arg_packed_llvm_val
)
)
elif isinstance(arg_type, USMNdArray):
self._kernel_arg_list.extend(
self._build_array_arg(
arg_type, llvm_array_val=arg_packed_llvm_val
Expand Down Expand Up @@ -213,6 +226,121 @@ def _store_val_into_struct(self, struct_ref, index, val):
),
)

def _build_local_accessor_metadata_arg(
    self, llvm_val, arg_type: LocalAccessorType, data_attr_ty
):
    """Build the kernel argument that stands in for the ``data`` attribute
    of a kernel_api.LocalAccessor object.

    A kernel_api.LocalAccessor conceptually represents a device-only memory
    allocation. The mock kernel_api.LocalAccessor uses a numpy.ndarray to
    represent the data allocation. The numpy.ndarray cannot be passed to the
    kernel and is ignored when building the kernel argument. Instead, a
    struct is allocated to store the metadata about the size of the device
    memory allocation and a reference to the struct is passed to the
    DPCTLQueue_Submit call. The DPCTLQueue_Submit then constructs a
    sycl::local_accessor object using the metadata and passes the
    sycl::local_accessor as the kernel argument, letting the DPC++ runtime
    handle proper device memory allocation.

    Args:
        llvm_val: Pointer to the packed LLVM value of the local accessor.
        arg_type: The LocalAccessorType of the argument.
        data_attr_ty: Front-end type of the ``data`` member; its ``dtype``
            supplies the element type recorded in the metadata.

    Returns:
        Whatever ``_build_arg`` returns for the pointer to the metadata
        struct.
    """
    context = self._context
    builder = self._builder
    num_dims = arg_type.ndim

    # Metadata struct handed to the runtime in place of a data pointer.
    metadata = cgutils.create_struct_proxy(DpctlMDLocalAccessorType())(
        context,
        builder,
    )
    # View over the host-side accessor struct, used to read its shape.
    accessor = cgutils.create_struct_proxy(arg_type)(
        context, builder, value=builder.load(llvm_val)
    )

    metadata.ndim = context.get_constant(types.int64, num_dims)
    metadata.dpctl_type_id = numba_type_to_dpctl_typenum(
        context, data_attr_ty.dtype
    )

    # Copy each extent into the matching dim<N> field of the metadata.
    extents = cgutils.unpack_tuple(builder, accessor.shape)
    for axis, extent in enumerate(extents):
        setattr(metadata, f"dim{axis}", extent)

    metadata_ptr = metadata._getpointer()
    element_dtype = dpctl.tensor.dtype(data_attr_ty.dtype.name)
    return self._build_arg(
        llvm_val=metadata_ptr,
        numba_type=LocalAccessorType(num_dims, element_dtype),
    )

def _build_local_accessor_arg(self, arg_type: LocalAccessorType, llvm_val):
    """Create the flattened list of kernel arguments for a local accessor.

    The arguments follow the unpacked USMNdArray order: number of items,
    item size, data (replaced by the local accessor metadata), one entry
    per shape extent and one entry per stride.

    Args:
        arg_type: The LocalAccessorType of the argument.
        llvm_val: Pointer to the packed LLVM value of the local accessor.

    Returns:
        list: The concatenated results of the individual ``_build_arg``
        calls, in the order described above.
    """
    # TODO: move extra values build on device side of codegen.
    builder = self._builder
    accessor = cgutils.create_struct_proxy(arg_type)(
        self._context, builder, value=builder.load(llvm_val)
    )
    extents = cgutils.unpack_tuple(builder, accessor.shape)

    # Total number of items is the product of all extents.
    nitems = reduce(builder.mul, extents)
    nitems_ptr = cgutils.alloca_once_value(builder, nitems)

    # Item size in bytes, rounded up from the dtype's bit width.
    itemsize_bytes = ceil(arg_type.dtype.bitwidth / types.byte.bitwidth)
    itemsize_const = self._context.get_constant(types.intp, itemsize_bytes)
    itemsize_ptr = cgutils.alloca_once_value(builder, itemsize_const)

    flattened_args = []
    kernel_model = self._kernel_dmm.lookup(arg_type)

    # Argument nitems
    flattened_args += self._build_arg(
        llvm_val=nitems_ptr,
        numba_type=kernel_model.get_member_fe_type("nitems"),
    )

    # Argument itemsize
    flattened_args += self._build_arg(
        llvm_val=itemsize_ptr,
        numba_type=kernel_model.get_member_fe_type("itemsize"),
    )

    # Argument data: passed as metadata describing the allocation.
    flattened_args += self._build_local_accessor_metadata_arg(
        llvm_val=llvm_val,
        arg_type=arg_type,
        data_attr_ty=kernel_model.get_member_fe_type("data"),
    )

    # Arguments for shape: one stack slot per extent.
    for extent in extents:
        extent_ptr = cgutils.alloca_once_value(builder, extent)
        flattened_args += self._build_arg(
            llvm_val=extent_ptr,
            numba_type=types.int64,
        )

    # Arguments for strides.
    # NOTE(review): every stride reuses the itemsize value, also for
    # multi-dimensional accessors -- confirm this is intended.
    for _ in range(arg_type.ndim):
        flattened_args += self._build_arg(
            llvm_val=itemsize_ptr,
            numba_type=types.int64,
        )

    return flattened_args

def _build_array_arg(self, arg_type, llvm_array_val):
"""Creates a list of LLVM Values for an unpacked USMNdArray kernel
argument.
Expand Down Expand Up @@ -240,6 +368,7 @@ def _build_array_arg(self, arg_type, llvm_array_val):
# Argument data
data_attr_pos = host_data_model.get_field_position("data")
data_attr_ty = kernel_data_model.get_member_fe_type("data")

kernel_arg_list.extend(
self._build_collections_attr_arg(
llvm_val=llvm_array_val,
Expand Down
5 changes: 4 additions & 1 deletion numba_dpex/core/utils/kernel_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from numba_dpex.core.exceptions import UnreachableError
from numba_dpex.core.runtime.context import DpexRTContext
from numba_dpex.core.types import USMNdArray
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
from numba_dpex.core.types.kernel_api.ranges import NdRangeType, RangeType
from numba_dpex.core.utils.kernel_flattened_args_builder import (
KernelFlattenedArgsBuilder,
Expand Down Expand Up @@ -675,7 +676,9 @@ def get_queue_from_llvm_values(
the queue from the first USMNdArray argument can be extracted.
"""
for arg_num, argty in enumerate(ty_kernel_args):
if isinstance(argty, USMNdArray):
if isinstance(argty, USMNdArray) and not isinstance(
argty, LocalAccessorType
):
llvm_val = ll_kernel_args[arg_num]
datamodel = ctx.data_model_manager.lookup(argty)
sycl_queue_attr_pos = datamodel.get_field_position("sycl_queue")
Expand Down
9 changes: 9 additions & 0 deletions numba_dpex/dpctl_iface/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from numba.core import types

from numba_dpex import dpctl_sem_version
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType


def numba_type_to_dpctl_typenum(context, ty):
Expand Down Expand Up @@ -34,6 +35,10 @@ def numba_type_to_dpctl_typenum(context, ty):
return context.get_constant(
types.int32, kargty.dpctl_void_ptr.value
)
elif isinstance(ty, LocalAccessorType):
return context.get_constant(
types.int32, kargty.dpctl_local_accessor.value
)
else:
raise NotImplementedError
else:
Expand Down Expand Up @@ -61,5 +66,9 @@ def numba_type_to_dpctl_typenum(context, ty):
elif ty == types.voidptr or isinstance(ty, types.CPointer):
# DPCTL_VOID_PTR
return context.get_constant(types.int32, 15)
elif isinstance(ty, LocalAccessorType):
raise NotImplementedError(
"LocalAccessor args for kernels requires dpctl 0.17 or greater."
)
else:
raise NotImplementedError
1 change: 1 addition & 0 deletions numba_dpex/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from numba_dpex.core.boxing import *
from numba_dpex.kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher

from . import typeof
from ._kernel_dpcpp_spirv_overloads import (
_atomic_fence_overloads,
_atomic_ref_overloads,
Expand Down
Loading
Loading