From 5e467c5f67c9840a039b35e6823494fc6db09156 Mon Sep 17 00:00:00 2001 From: "akmkhale@ansatnuc04" Date: Thu, 19 Jan 2023 13:08:20 -0600 Subject: [PATCH] Added examples, test cases and docs --- numba_dpex/core/caching.py | 2 + numba_dpex/core/kernel_interface/func.py | 77 +++++++++++-- numba_dpex/decorators.py | 1 + numba_dpex/examples/kernel/device_func.py | 101 ++++++++++++++++-- .../tests/kernel_tests/test_device_func.py | 92 ++++++++++++++++ 5 files changed, 257 insertions(+), 16 deletions(-) create mode 100644 numba_dpex/tests/kernel_tests/test_device_func.py diff --git a/numba_dpex/core/caching.py b/numba_dpex/core/caching.py index 94b8afe111..807703188e 100644 --- a/numba_dpex/core/caching.py +++ b/numba_dpex/core/caching.py @@ -228,6 +228,8 @@ def __init__(self, name="cache", capacity=10, pyfunc=None): """Constructor for LRUCache. Args: + name (str, optional): The name of the cache, useful for + debugging. capacity (int, optional): The max capacity of the cache. Defaults to 10. pyfunc (NoneType, optional): A python function to be cached. diff --git a/numba_dpex/core/kernel_interface/func.py b/numba_dpex/core/kernel_interface/func.py index 7e088adc34..e2f9f0c119 100644 --- a/numba_dpex/core/kernel_interface/func.py +++ b/numba_dpex/core/kernel_interface/func.py @@ -17,11 +17,31 @@ class DpexFunction(object): + """Class to materialize dpex function""" + def __init__(self, pyfunc, debug=None): + """Constructor for DpexFunction + + Args: + pyfunc (function): A python function to be compiled. + debug (object, optional): Debug option for compilation. + Defaults to None. + """ self._pyfunc = pyfunc self._debug = debug def compile(self, arg_types, return_types): + """The actual compilation function. + + Args: + arg_types (tuple): Function argument types in a tuple. + return_types (numba.core.types.scalars.Integer): + An integer value to specify the return type. + + Returns: + numba.core.compiler.CompileResult: The compiled result + """ + cres = compile_with_dpex( pyfunc=self._pyfunc, pyfunc_name=self._pyfunc.__name__, @@ -42,6 +62,15 @@ class DpexFunctionTemplate(object): """Unmaterialized dpex function""" def __init__(self, pyfunc, debug=None, enable_cache=True): + """AI is creating summary for __init__ + + Args: + pyfunc (function): A python function to be compiled. + debug (object, optional): Debug option for compilation. + Defaults to None. + enable_cache (bool, optional): Flag to turn on/off caching. + Defaults to True. + """ self._pyfunc = pyfunc self._debug = debug self._enable_cache = enable_cache @@ -60,19 +89,29 @@ def __init__(self, pyfunc, debug=None, enable_cache=True): @property def cache(self): + """Cache accessor""" return self._cache @property def cache_hits(self): + """Cache hit count accessor""" return self._cache_hits def compile(self, args): - """Compile a dpex.func decorated Python function with the given - argument types. + """Compile a dpex.func decorated Python function + + Compile a dpex.func decorated Python function with the given + argument types. Each signature is compiled once by caching the + compiled function inside this object. + + Args: + args (tuple): Function argument types in a tuple. - Each signature is compiled once by caching the compiled function inside - this object. + Returns: + numba.core.typing.templates.Signature: Signature of the + compiled result. """ + argtypes = [ dpex_target.typing_context.resolve_argument_type(arg) for arg in args @@ -102,11 +141,21 @@ def compile(self, args): cres.target_context.insert_user_function(self, cres.fndesc, libs) # cres.target_context.add_user_function(self, cres.fndesc, libs) self._cache.put(key, cres) - return cres.signature def compile_func(pyfunc, signature, debug=None): + """AI is creating summary for compile_func + + Args: + pyfunc (function): A python function to be compiled. + signature (list): A list of numba.core.typing.templates.Signature's + debug (object, optional): Debug options. Defaults to None. + + Returns: + numba_dpex.core.kernel_interface.func.DpexFunction: DpexFunction object + """ + devfn = DpexFunction(pyfunc, debug=debug) cres = [] @@ -139,7 +188,20 @@ class _function_template(ConcreteTemplate): def compile_func_template(pyfunc, debug=None): - """Compile a DpexFunctionTemplate""" + """Compile a DpexFunctionTemplate + + Args: + pyfunc (function): A python function to be compiled. + debug (object, optional): Debug options. Defaults to None. + + Raises: + AssertionError: Raised if keyword arguments are supplied in + the inner generic function. + + Returns: + numba_dpex.core.kernel_interface.func.DpexFunctionTemplate: + A DpexFunctionTemplate object. + """ dft = DpexFunctionTemplate(pyfunc, debug=debug) @@ -148,10 +210,13 @@ class _function_template(AbstractTemplate): exact_match_required = True key = dft + # TODO: Talk with the numba team and see why this has been + # called twice, could be a bug with numba. def generic(self, args, kws): if kws: raise AssertionError("No keyword arguments allowed.") return dft.compile(args) dpex_target.typing_context.insert_user_function(dft, _function_template) + return dft diff --git a/numba_dpex/decorators.py b/numba_dpex/decorators.py index 0abb0c37c5..d02fad9340 100644 --- a/numba_dpex/decorators.py +++ b/numba_dpex/decorators.py @@ -134,5 +134,6 @@ def _wrapped(pyfunc): return _wrapped else: + # no signature func = func_or_sig return _func_autojit(func, debug=debug) diff --git a/numba_dpex/examples/kernel/device_func.py b/numba_dpex/examples/kernel/device_func.py index 939a79336b..8ead475943 100644 --- a/numba_dpex/examples/kernel/device_func.py +++ b/numba_dpex/examples/kernel/device_func.py @@ -5,18 +5,33 @@ import dpnp as np import numba_dpex as ndpex +from numba_dpex import float32, int32, int64 # Array size N = 10 -# A device callable function that can be invoked from ``kernel`` and other -# device functions +# A device callable function that can be invoked from +# ``kernel`` and other device functions @ndpex.func def a_device_function(a): return a + 1 +# A device callable function with signature that can be invoked +# from ``kernel`` and other device functions +@ndpex.func(int32(int32)) +def a_device_function_int32(a): + return a + 1 + + +# A device callable function with list signature that can be invoked +# from ``kernel`` and other device functions +@ndpex.func([int32(int32), float32(float32)]) +def a_device_function_int32_float32(a): + return a + 1 + + # A device callable function can call another device function @ndpex.func def another_device_function(a): @@ -30,24 +45,90 @@ def a_kernel_function(a, b): b[i] = another_device_function(a[i]) -# Utility function for printing -def driver(a, b, N): +# A kernel function that calls the device function +@ndpex.kernel +def a_kernel_function_int32(a, b): + i = ndpex.get_global_id(0) + b[i] = a_device_function_int32(a[i]) + + +# A kernel function that calls the device function +@ndpex.kernel +def a_kernel_function_int32_float32(a, b): + i = ndpex.get_global_id(0) + b[i] = a_device_function_int32_float32(a[i]) + + +# test function 1: tests basic +def test1(): + a = np.ones(N) + b = np.ones(N) + + print("Using device ...") + print(a.device) + print("A=", a) a_kernel_function[N](a, b) print("B=", b) + print("Done...") -# Main function -def main(): - a = np.ones(N) - b = np.ones(N) + +# test function 2: test device func with signature +def test2(): + a = np.ones(N, dtype=np.int32) + b = np.ones(N, dtype=np.int32) print("Using device ...") print(a.device) - driver(a, b, N) + + print("A=", a) + a_kernel_function_int32[N](a, b) + print("B=", b) + + print("Done...") + + +# test function 3: test device function with list signature +def test3(): + a = np.ones(N, dtype=np.int32) + b = np.ones(N, dtype=np.int32) + + print("Using device ...") + print(a.device) + + print("A=", a) + a_kernel_function_int32_float32[N](a, b) + print("B=", b) + + # with a different dtype + a = np.ones(N, dtype=np.float32) + b = np.ones(N, dtype=np.float32) + + print("Using device ...") + print(a.device) + + print("A=", a) + a_kernel_function_int32_float32[N](a, b) + print("B=", b) + + # this will fail, since int64 is not in + # the signature list: [int32(int32), float32(float32)] + a = np.ones(N, dtype=np.int64) + b = np.ones(N, dtype=np.int64) + + print("Using device ...") + print(a.device) + + print("A=", a) + a_kernel_function_int32_float32[N](a, b) + print("B=", b) print("Done...") +# main function if __name__ == "__main__": - main() + test1() + test2() + test3() diff --git a/numba_dpex/tests/kernel_tests/test_device_func.py b/numba_dpex/tests/kernel_tests/test_device_func.py new file mode 100644 index 0000000000..965e676ab2 --- /dev/null +++ b/numba_dpex/tests/kernel_tests/test_device_func.py @@ -0,0 +1,92 @@ +import dpnp as np +import numpy + +import numba_dpex as ndpex +from numba_dpex import float32, int32, int64 + +# Array size +N = 10 + + +# A device callable function that can be invoked from +# ``kernel`` and other device functions +@ndpex.func +def a_device_function(a): + return a + 1 + + +# A device callable function with signature that can be invoked +# from ``kernel`` and other device functions +@ndpex.func(int32(int32)) +def a_device_function_int32(a): + return a + 1 + + +# A device callable function with list signature that can be invoked +# from ``kernel`` and other device functions +@ndpex.func([int32(int32), float32(float32)]) +def a_device_function_int32_float32(a): + return a + 1 + + +# A device callable function can call another device function +@ndpex.func +def another_device_function(a): + return a_device_function(a * 2) + + +# A kernel function that calls the device function +@ndpex.kernel +def a_kernel_function(a, b): + i = ndpex.get_global_id(0) + b[i] = a_device_function(a[i]) + + +# A kernel function that calls the device function +@ndpex.kernel +def a_kernel_function_nested(a, b): + i = ndpex.get_global_id(0) + b[i] = another_device_function(a[i]) + + +# A kernel function that calls the device function +@ndpex.kernel +def a_kernel_function_int32(a, b): + i = ndpex.get_global_id(0) + b[i] = a_device_function_int32(a[i]) + + +# A kernel function that calls the device function +@ndpex.kernel +def a_kernel_function_int32_float32(a, b): + i = ndpex.get_global_id(0) + b[i] = a_device_function_int32_float32(a[i]) + + +def test_basic(): + a = np.ones(N) + b = np.ones(N) + + a_kernel_function[N](a, b) + + b = np.asnumpy(b) + expected = numpy.ones(N) + 1 + + assert numpy.array_equal(b, expected) + + +def test_nested(): + a = np.ones(N) + b = np.ones(N) + + a_kernel_function_nested[N](a, b) + + b = np.asnumpy(b) + expected = numpy.ones(N) * 3 + + assert numpy.array_equal(b, expected) + + +if __name__ == "__main__": + test_basic() + test_nested()