Skip to content

Commit

Permalink
Use lower instead of overload for private array
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Mar 8, 2024
1 parent a0a4c8c commit 66cd2c9
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

import llvmlite.ir as llvmir
from llvmlite.ir.builder import IRBuilder
from numba.core import cgutils
from numba.core import cgutils, types
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
from numba.core.typing.templates import Signature
from numba.extending import intrinsic, overload
from numba.extending import type_callable

from numba_dpex.core.types import USMNdArray
from numba_dpex.experimental.target import DpexExpKernelTypingContext
Expand All @@ -24,67 +24,12 @@
)
from numba_dpex.utils import address_space as AddressSpace

from ..target import DPEX_KERNEL_EXP_TARGET_NAME
from ._registy import lower


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_private_array_ctor(
ty_context, # pylint: disable=unused-argument
ty_shape,
ty_dtype,
ty_fill_zeros,
):
require_literal(ty_shape)
require_literal(ty_fill_zeros)

ty_array = USMNdArray(
dtype=_ty_parse_dtype(ty_dtype),
ndim=_ty_parse_shape(ty_shape),
layout="C",
addrspace=AddressSpace.PRIVATE,
)

sig = ty_array(ty_shape, ty_dtype, ty_fill_zeros)

def codegen(
context: DpexExpKernelTypingContext,
builder: IRBuilder,
sig: Signature,
args: list[llvmir.Value],
):
shape = args[0]
ty_shape = sig.args[0]
ty_fill_zeros = sig.args[-1]
ty_array = sig.return_type

ary = make_spirv_generic_array_on_stack(
context, builder, ty_array, ty_shape, shape
)

if ty_fill_zeros.literal_value:
cgutils.memset(
builder, ary.data, builder.mul(ary.itemsize, ary.nitems), 0
)

return ary._getvalue() # pylint: disable=protected-access

return (
sig,
codegen,
)


@overload(
PrivateArray,
prefer_literal=True,
target=DPEX_KERNEL_EXP_TARGET_NAME,
)
def ol_private_array_ctor(
shape,
dtype,
fill_zeros=False,
):
"""Overload of the constructor for the class
@type_callable(PrivateArray)
def type_interval(context): # pylint: disable=unused-argument
"""Sets type of the constructor for the class
class:`numba_dpex.kernel_api.PrivateArray`.
Raises:
Expand All @@ -94,12 +39,48 @@ def ol_private_array_ctor(
type.
"""

def ol_private_array_ctor_impl(
shape,
dtype,
fill_zeros=False,
):
# pylint: disable=no-value-for-parameter
return _intrinsic_private_array_ctor(shape, dtype, fill_zeros)
def typer(shape, dtype, fill_zeros=types.BooleanLiteral(False)):
require_literal(shape)
require_literal(fill_zeros)

return USMNdArray(
dtype=_ty_parse_dtype(dtype),
ndim=_ty_parse_shape(shape),
layout="C",
addrspace=AddressSpace.PRIVATE,
)

return typer


@lower(PrivateArray, types.IntegerLiteral, types.Any, types.BooleanLiteral)
@lower(PrivateArray, types.Tuple, types.Any, types.BooleanLiteral)
@lower(PrivateArray, types.UniTuple, types.Any, types.BooleanLiteral)
@lower(PrivateArray, types.IntegerLiteral, types.Any)
@lower(PrivateArray, types.Tuple, types.Any)
@lower(PrivateArray, types.UniTuple, types.Any)
def dpex_private_array_lower(
context: DpexExpKernelTypingContext,
builder: IRBuilder,
sig: Signature,
args: list[llvmir.Value],
):
"""Implements lower for the class:`numba_dpex.kernel_api.PrivateArray`"""
shape = args[0]
ty_shape = sig.args[0]
if len(sig.args) == 3:
fill_zeros = sig.args[-1].literal_value
else:
fill_zeros = False
ty_array = sig.return_type

ary = make_spirv_generic_array_on_stack(
context, builder, ty_array, ty_shape, shape
)

if fill_zeros:
cgutils.memset(
builder, ary.data, builder.mul(ary.itemsize, ary.nitems), 0
)

return ol_private_array_ctor_impl
return ary._getvalue() # pylint: disable=protected-access
12 changes: 12 additions & 0 deletions numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_registy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
Implements the SPIR-V overloads for the kernel_api.PrivateArray class.
"""

from numba.core.imputils import Registry

registry = Registry()
lower = registry.lower
4 changes: 3 additions & 1 deletion numba_dpex/kernel_api_impl/spirv/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def require_literal(literal_type: types.Type):

for i, _ in enumerate(literal_type):
if not isinstance(literal_type[i], types.Literal):
raise errors.TypingError("requires literal type")
raise errors.TypingError(
"requires each element of tuple literal type"
)


def make_spirv_array( # pylint: disable=too-many-arguments
Expand Down
4 changes: 4 additions & 0 deletions numba_dpex/kernel_api_impl/spirv/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,12 +383,16 @@ def load_additional_registries(self):
# pylint: disable=import-outside-toplevel
from numba_dpex import printimpl
from numba_dpex.dpnp_iface import dpnpimpl
from numba_dpex.experimental._kernel_dpcpp_spirv_overloads._registy import (
registry as spirv_registry,
)
from numba_dpex.ocl import mathimpl, oclimpl

self.insert_func_defn(oclimpl.registry.functions)
self.insert_func_defn(mathimpl.registry.functions)
self.insert_func_defn(dpnpimpl.registry.functions)
self.install_registry(printimpl.registry)
self.install_registry(spirv_registry)
# Replace dpnp math functions with their OpenCL versions.
self.replace_dpnp_ufunc_with_ocl_intrinsics()

Expand Down

0 comments on commit 66cd2c9

Please sign in to comment.