Skip to content

Commit

Permalink
Merge pull request #888 from chudur-budur/github-871
Browse files Browse the repository at this point in the history
A more consistent kernel launch parameter syntax
  • Loading branch information
diptorupd authored Jan 31, 2023
2 parents c0e5be0 + 57c7d10 commit aa4eb5b
Show file tree
Hide file tree
Showing 28 changed files with 427 additions and 106 deletions.
2 changes: 1 addition & 1 deletion numba_dpex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __getattr__(name):
ENABLE_CACHE = _readenv("NUMBA_DPEX_ENABLE_CACHE", int, 1)
# Capacity of the cache, execute it like:
# NUMBA_DPEX_CACHE_SIZE=20 python <code>
CACHE_SIZE = _readenv("NUMBA_DPEX_CACHE_SIZE", int, 10)
CACHE_SIZE = _readenv("NUMBA_DPEX_CACHE_SIZE", int, 128)

TESTING_SKIP_NO_DPNP = _readenv("NUMBA_DPEX_TESTING_SKIP_NO_DPNP", int, 0)
TESTING_SKIP_NO_DEBUGGING = _readenv(
Expand Down
116 changes: 77 additions & 39 deletions numba_dpex/core/kernel_interface/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0


from collections.abc import Iterable
from inspect import signature
from warnings import warn

Expand Down Expand Up @@ -32,6 +33,7 @@
)
from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer
from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel
from numba_dpex.core.kernel_interface.utils import NdRange, Range
from numba_dpex.core.types import USMNdArray


Expand Down Expand Up @@ -468,51 +470,87 @@ def __getitem__(self, args):
global_range and local_range attributes initialized.
"""
if isinstance(args, int):
self._global_range = [args]
self._local_range = None
elif isinstance(args, tuple) or isinstance(args, list):
if len(args) == 1 and all(isinstance(v, int) for v in args):
self._global_range = list(args)
self._local_range = None
elif len(args) == 2:
gr = args[0]
lr = args[1]
if isinstance(gr, int):
self._global_range = [gr]
elif len(gr) != 0 and all(isinstance(v, int) for v in gr):
self._global_range = list(gr)
else:
raise IllegalRangeValueError(kernel_name=self.kernel_name)
if isinstance(args, Range):
# we need inversions, see github issue #889
self._global_range = list(args)[::-1]
elif isinstance(args, NdRange):
# we need inversions, see github issue #889
self._global_range = list(args.global_range)[::-1]
self._local_range = list(args.local_range)[::-1]
else:
if (
isinstance(args, tuple)
and len(args) == 2
and isinstance(args[0], int)
and isinstance(args[1], int)
):
warn(
"Ambiguous kernel launch paramters. If your data have "
+ "dimensions > 1, include a default/empty local_range:\n"
+ " <function>[(X,Y), numba_dpex.DEFAULT_LOCAL_RANGE](<params>)\n"
+ "otherwise your code might produce erroneous results.",
DeprecationWarning,
stacklevel=2,
)
self._global_range = [args[0]]
self._local_range = [args[1]]
return self

warn(
"The current syntax for specification of kernel lauch "
+ "parameters is deprecated. Users should set the kernel "
+ "parameters through Range/NdRange classes.\n"
+ "Example:\n"
+ " from numba_dpex.core.kernel_interface.utils import Range,NdRange\n\n"
+ " # for global range only\n"
+ " <function>[Range(X,Y)](<parameters>)\n"
+ " # or,\n"
+ " # for both global and local ranges\n"
+ " <function>[NdRange((X,Y), (P,Q))](<parameters>)",
DeprecationWarning,
stacklevel=2,
)

if isinstance(lr, int):
self._local_range = [lr]
elif isinstance(lr, list) and len(lr) == 0:
# deprecation warning
args = [args] if not isinstance(args, Iterable) else args
nargs = len(args)

# Check if the kernel enquing arguments are sane
if nargs < 1 or nargs > 2:
raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name)

g_range = (
[args[0]] if not isinstance(args[0], Iterable) else args[0]
)
# If the optional local size argument is provided
l_range = None
if nargs == 2:
if args[1] != []:
l_range = (
[args[1]]
if not isinstance(args[1], Iterable)
else args[1]
)
else:
warn(
"Specifying the local range as an empty list "
"(DEFAULT_LOCAL_SIZE) is deprecated. The kernel will "
"be executed as a basic data-parallel kernel over the "
"global range. Specify a valid local range to execute "
"the kernel as an ND-range kernel.",
"Empty local_range calls are deprecated. Please use Range/NdRange "
+ "to specify the kernel launch parameters:\n"
+ "Example:\n"
+ " from numba_dpex.core.kernel_interface.utils import Range,NdRange\n\n"
+ " # for global range only\n"
+ " <function>[Range(X,Y)](<parameters>)\n"
+ " # or,\n"
+ " # for both global and local ranges\n"
+ " <function>[NdRange((X,Y), (P,Q))](<parameters>)",
DeprecationWarning,
stacklevel=2,
)
self._local_range = None
elif len(lr) != 0 and all(isinstance(v, int) for v in lr):
self._local_range = list(lr)
else:
raise IllegalRangeValueError(kernel_name=self.kernel_name)
else:
raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name)
else:
raise InvalidKernelLaunchArgsError(kernel_name=self.kernel_name)

# FIXME:[::-1] is done as OpenCL and SYCl have different orders when
# it comes to specifying dimensions.
self._global_range = list(self._global_range)[::-1]
if self._local_range:
self._local_range = list(self._local_range)[::-1]
if len(g_range) < 1:
raise IllegalRangeValueError(kernel_name=self.kernel_name)

# we need inversions, see github issue #889
self._global_range = list(g_range)[::-1]
self._local_range = list(l_range)[::-1] if l_range else None

return self

Expand Down
218 changes: 218 additions & 0 deletions numba_dpex/core/kernel_interface/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
from collections.abc import Iterable


class Range(tuple):
"""A data structure to encapsulate a single kernel lauch parameter.
The range is an abstraction that describes the number of elements
in each dimension of buffers and index spaces. It can contain
1, 2, or 3 numbers, dependending on the dimensionality of the
object it describes.
This is just a wrapper class on top of a 3-tuple. The kernel launch
parameter is consisted of three int's. This class basically mimics
the behavior of `sycl::range`.
"""

def __new__(cls, dim0, dim1=None, dim2=None):
"""Constructs a 1, 2, or 3 dimensional range.
Args:
dim0 (int): The range of the first dimension.
dim1 (int, optional): The range of second dimension.
Defaults to None.
dim2 (int, optional): The range of the third dimension.
Defaults to None.
Raises:
TypeError: If dim0 is not an int.
TypeError: If dim1 is not an int.
TypeError: If dim2 is not an int.
"""
if not isinstance(dim0, int):
raise TypeError("dim0 of a Range must be an int.")
_values = [dim0]
if dim1:
if not isinstance(dim1, int):
raise TypeError("dim1 of a Range must be an int.")
_values.append(dim1)
if dim2:
if not isinstance(dim2, int):
raise TypeError("dim2 of a Range must be an int.")
_values.append(dim2)
return super(Range, cls).__new__(cls, tuple(_values))

def get(self, index):
"""Returns the range of a single dimension.
Args:
index (int): The index of the dimension, i.e. [0,2]
Returns:
int: The range of the dimension indexed by `index`.
"""
return self[index]

def size(self):
"""Returns the size of a range.
Returns the size of a range by multiplying
the range of the individual dimensions.
Returns:
int: The size of a range.
"""
n = len(self)
if n > 2:
return self[0] * self[1] * self[2]
elif n > 1:
return self[0] * self[1]
else:
return self[0]


class NdRange:
"""A class to encapsulate all kernel launch parameters.
The NdRange defines the index space for a work group as well as
the global index space. It is passed to parallel_for to execute
a kernel on a set of work items.
This class basically contains two Range object, one for the global_range
and the other for the local_range. The global_range parameter contains
the global index space and the local_range parameter contains the index
space of a work group. This class mimics the behavior of `sycl::nd_range`
class.
"""

def __init__(self, global_size, local_size):
"""Constructor for NdRange class.
Args:
global_size (Range or tuple of int's): The values for
the global_range.
local_size (Range or tuple of int's, optional): The values for
the local_range. Defaults to None.
"""
if isinstance(global_size, Range):
self._global_range = global_size
elif isinstance(global_size, Iterable):
self._global_range = Range(*global_size)
else:
TypeError("Unknwon argument type for NdRange global_size.")

if isinstance(local_size, Range):
self._local_range = local_size
elif isinstance(local_size, Iterable):
self._local_range = Range(*local_size)
else:
TypeError("Unknwon argument type for NdRange local_size.")

@property
def global_range(self):
"""Accessor for global_range.
Returns:
Range: The `global_range` `Range` object.
"""
return self._global_range

@property
def local_range(self):
"""Accessor for local_range.
Returns:
Range: The `local_range` `Range` object.
"""
return self._local_range

def get_global_range(self):
"""Returns a Range defining the index space.
Returns:
Range: A `Range` object defining the index space.
"""
return self._global_range

def get_local_range(self):
"""Returns a Range defining the index space of a work group.
Returns:
Range: A `Range` object to specify index space of a work group.
"""
return self._local_range

def __str__(self):
"""str() function for NdRange class.
Returns:
str: str representation for NdRange class.
"""
return (
"(" + str(self._global_range) + ", " + str(self._local_range) + ")"
)

def __repr__(self):
"""repr() function for NdRange class.
Returns:
str: str representation for NdRange class.
"""
return self.__str__()


if __name__ == "__main__":
r1 = Range(1)
print("r1 =", r1)

r2 = Range(1, 2)
print("r2 =", r2)

r3 = Range(1, 2, 3)
print("r3 =", r3, ", len(r3) =", len(r3))

r3 = Range(*(1, 2, 3))
print("r3 =", r3, ", len(r3) =", len(r3))

r3 = Range(*[1, 2, 3])
print("r3 =", r3, ", len(r3) =", len(r3))

print("r1.get(0) =", r1.get(0))
try:
print("r2.get(2) =", r2.get(2))
except Exception as e:
print(e)

print("r3.get(0) =", r3.get(0))
print("r3.get(1) =", r3.get(1))

print("r1[0] =", r1[0])
try:
print("r2[2] =", r2[2])
except Exception as e:
print(e)

print("r3[0] =", r3[0])
print("r3[1] =", r3[1])

try:
r4 = Range(1, 2, 3, 4)
except Exception as e:
print(e)

try:
r5 = Range(*(1, 2, 3, 4))
except Exception as e:
print(e)

ndr1 = NdRange(Range(1, 2))
print("ndr1 =", ndr1)

ndr2 = NdRange(Range(1, 2), Range(1, 1, 1))
print("ndr2 =", ndr2)

ndr3 = NdRange((1, 2))
print("ndr3 =", ndr3)

ndr4 = NdRange((1, 2), (1, 1, 1))
print("ndr4 =", ndr4)
3 changes: 2 additions & 1 deletion numba_dpex/examples/debug/dpex_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np

import numba_dpex as dpex
from numba_dpex.core.kernel_interface.utils import Range


@dpex.func(debug=True)
Expand All @@ -24,7 +25,7 @@ def driver(a, b, c, global_size):
print("a = ", a)
print("b = ", b)
print("c = ", c)
kernel_sum[global_size, dpex.DEFAULT_LOCAL_SIZE](a, b, c)
kernel_sum[Range(global_size)](a, b, c)
print("a + b = ", c)


Expand Down
Loading

0 comments on commit aa4eb5b

Please sign in to comment.