Skip to content

Commit

Permalink
Added examples, test cases and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Jan 19, 2023
1 parent 0f5a240 commit 5e467c5
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 16 deletions.
2 changes: 2 additions & 0 deletions numba_dpex/core/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def __init__(self, name="cache", capacity=10, pyfunc=None):
"""Constructor for LRUCache.
Args:
name (str, optional): The name of the cache, useful for
debugging.
capacity (int, optional): The max capacity of the cache.
Defaults to 10.
pyfunc (NoneType, optional): A python function to be cached.
Expand Down
77 changes: 71 additions & 6 deletions numba_dpex/core/kernel_interface/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,31 @@


class DpexFunction(object):
"""Class to materialize dpex function"""

def __init__(self, pyfunc, debug=None):
"""Constructor for DpexFunction
Args:
pyfunc (function): A python function to be compiled.
debug (object, optional): Debug option for compilation.
Defaults to None.
"""
self._pyfunc = pyfunc
self._debug = debug

def compile(self, arg_types, return_types):
"""The actual compilation function.
Args:
arg_types (tuple): Function argument types in a tuple.
return_types (numba.core.types.scalars.Integer):
An integer value to specify the return type.
Returns:
numba.core.compiler.CompileResult: The compiled result
"""

cres = compile_with_dpex(
pyfunc=self._pyfunc,
pyfunc_name=self._pyfunc.__name__,
Expand All @@ -42,6 +62,15 @@ class DpexFunctionTemplate(object):
"""Unmaterialized dpex function"""

def __init__(self, pyfunc, debug=None, enable_cache=True):
"""AI is creating summary for __init__
Args:
pyfunc (function): A python function to be compiled.
debug (object, optional): Debug option for compilation.
Defaults to None.
enable_cache (bool, optional): Flag to turn on/off caching.
Defaults to True.
"""
self._pyfunc = pyfunc
self._debug = debug
self._enable_cache = enable_cache
Expand All @@ -60,19 +89,29 @@ def __init__(self, pyfunc, debug=None, enable_cache=True):

@property
def cache(self):
"""Cache accessor"""
return self._cache

@property
def cache_hits(self):
"""Cache hit count accessor"""
return self._cache_hits

def compile(self, args):
"""Compile a dpex.func decorated Python function with the given
argument types.
"""Compile a dpex.func decorated Python function
Compile a dpex.func decorated Python function with the given
argument types. Each signature is compiled once by caching the
compiled function inside this object.
Args:
args (tuple): Function argument types in a tuple.
Each signature is compiled once by caching the compiled function inside
this object.
Returns:
numba.core.typing.templates.Signature: Signature of the
compiled result.
"""

argtypes = [
dpex_target.typing_context.resolve_argument_type(arg)
for arg in args
Expand Down Expand Up @@ -102,11 +141,21 @@ def compile(self, args):
cres.target_context.insert_user_function(self, cres.fndesc, libs)
# cres.target_context.add_user_function(self, cres.fndesc, libs)
self._cache.put(key, cres)

return cres.signature


def compile_func(pyfunc, signature, debug=None):
"""AI is creating summary for compile_func
Args:
pyfunc (function): A python function to be compiled.
signature (list): A list of numba.core.typing.templates.Signature's
debug (object, optional): Debug options. Defaults to None.
Returns:
numba_dpex.core.kernel_interface.func.DpexFunction: DpexFunction object
"""

devfn = DpexFunction(pyfunc, debug=debug)

cres = []
Expand Down Expand Up @@ -139,7 +188,20 @@ class _function_template(ConcreteTemplate):


def compile_func_template(pyfunc, debug=None):
"""Compile a DpexFunctionTemplate"""
"""Compile a DpexFunctionTemplate
Args:
pyfunc (function): A python function to be compiled.
debug (object, optional): Debug options. Defaults to None.
Raises:
AssertionError: Raised if keyword arguments are supplied in
the inner generic function.
Returns:
numba_dpex.core.kernel_interface.func.DpexFunctionTemplate:
A DpexFunctionTemplate object.
"""

dft = DpexFunctionTemplate(pyfunc, debug=debug)

Expand All @@ -148,10 +210,13 @@ class _function_template(AbstractTemplate):
exact_match_required = True
key = dft

# TODO: Talk with the numba team and see why this has been
# called twice, could be a bug with numba.
def generic(self, args, kws):
if kws:
raise AssertionError("No keyword arguments allowed.")
return dft.compile(args)

dpex_target.typing_context.insert_user_function(dft, _function_template)

return dft
1 change: 1 addition & 0 deletions numba_dpex/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,6 @@ def _wrapped(pyfunc):

return _wrapped
else:
# no signature
func = func_or_sig
return _func_autojit(func, debug=debug)
101 changes: 91 additions & 10 deletions numba_dpex/examples/kernel/device_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,33 @@
import dpnp as np

import numba_dpex as ndpex
from numba_dpex import float32, int32, int64

# Array size
N = 10


# A device callable function that can be invoked from ``kernel`` and other
# device functions
# A device callable function that can be invoked from
# ``kernel`` and other device functions
@ndpex.func
def a_device_function(a):
return a + 1


# A device callable function with signature that can be invoked
# from ``kernel`` and other device functions
@ndpex.func(int32(int32))
def a_device_function_int32(a):
return a + 1


# A device callable function with list signature that can be invoked
# from ``kernel`` and other device functions
@ndpex.func([int32(int32), float32(float32)])
def a_device_function_int32_float32(a):
return a + 1


# A device callable function can call another device function
@ndpex.func
def another_device_function(a):
Expand All @@ -30,24 +45,90 @@ def a_kernel_function(a, b):
b[i] = another_device_function(a[i])


# Utility function for printing
def driver(a, b, N):
# A kernel function that calls the device function
@ndpex.kernel
def a_kernel_function_int32(a, b):
i = ndpex.get_global_id(0)
b[i] = a_device_function_int32(a[i])


# A kernel function that calls the device function
@ndpex.kernel
def a_kernel_function_int32_float32(a, b):
i = ndpex.get_global_id(0)
b[i] = a_device_function_int32_float32(a[i])


# test function 1: tests basic
def test1():
a = np.ones(N)
b = np.ones(N)

print("Using device ...")
print(a.device)

print("A=", a)
a_kernel_function[N](a, b)
print("B=", b)

print("Done...")

# Main function
def main():
a = np.ones(N)
b = np.ones(N)

# test function 2: test device func with signature
def test2():
a = np.ones(N, dtype=np.int32)
b = np.ones(N, dtype=np.int32)

print("Using device ...")
print(a.device)
driver(a, b, N)

print("A=", a)
a_kernel_function_int32[N](a, b)
print("B=", b)

print("Done...")


# test function 3: test device function with list signature
def test3():
a = np.ones(N, dtype=np.int32)
b = np.ones(N, dtype=np.int32)

print("Using device ...")
print(a.device)

print("A=", a)
a_kernel_function_int32_float32[N](a, b)
print("B=", b)

# with a different dtype
a = np.ones(N, dtype=np.float32)
b = np.ones(N, dtype=np.float32)

print("Using device ...")
print(a.device)

print("A=", a)
a_kernel_function_int32_float32[N](a, b)
print("B=", b)

# this will fail, since int64 is not in
# the signature list: [int32(int32), float32(float32)]
a = np.ones(N, dtype=np.int64)
b = np.ones(N, dtype=np.int64)

print("Using device ...")
print(a.device)

print("A=", a)
a_kernel_function_int32_float32[N](a, b)
print("B=", b)

print("Done...")


# main function
if __name__ == "__main__":
main()
test1()
test2()
test3()
92 changes: 92 additions & 0 deletions numba_dpex/tests/kernel_tests/test_device_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import dpnp as np
import numpy

import numba_dpex as ndpex
from numba_dpex import float32, int32, int64

# Array size
N = 10


# A device callable function that can be invoked from
# ``kernel`` and other device functions
@ndpex.func
def a_device_function(a):
return a + 1


# A device callable function with signature that can be invoked
# from ``kernel`` and other device functions
@ndpex.func(int32(int32))
def a_device_function_int32(a):
return a + 1


# A device callable function with list signature that can be invoked
# from ``kernel`` and other device functions
@ndpex.func([int32(int32), float32(float32)])
def a_device_function_int32_float32(a):
return a + 1


# A device callable function can call another device function
@ndpex.func
def another_device_function(a):
return a_device_function(a * 2)


# A kernel function that calls the device function
@ndpex.kernel
def a_kernel_function(a, b):
i = ndpex.get_global_id(0)
b[i] = a_device_function(a[i])


# A kernel function that calls the device function
@ndpex.kernel
def a_kernel_function_nested(a, b):
i = ndpex.get_global_id(0)
b[i] = another_device_function(a[i])


# A kernel function that calls the device function
@ndpex.kernel
def a_kernel_function_int32(a, b):
i = ndpex.get_global_id(0)
b[i] = a_device_function_int32(a[i])


# A kernel function that calls the device function
@ndpex.kernel
def a_kernel_function_int32_float32(a, b):
i = ndpex.get_global_id(0)
b[i] = a_device_function_int32_float32(a[i])


def test_basic():
a = np.ones(N)
b = np.ones(N)

a_kernel_function[N](a, b)

b = np.asnumpy(b)
expected = numpy.ones(N) + 1

assert numpy.array_equal(b, expected)


def test_nested():
a = np.ones(N)
b = np.ones(N)

a_kernel_function_nested[N](a, b)

b = np.asnumpy(b)
expected = numpy.ones(N) * 3

assert numpy.array_equal(b, expected)


if __name__ == "__main__":
test_basic()
test_nested()

0 comments on commit 5e467c5

Please sign in to comment.