From e11c86d7f487efb76252f1c50184e39475587478 Mon Sep 17 00:00:00 2001 From: Adarsh Yoga Date: Wed, 15 Feb 2023 06:13:28 +0000 Subject: [PATCH 1/2] (1)moved function hash computation to kernel and func __init__ from caching.build_key(). This avoids computing hash on every call. (2) moved argtypes list building logic to func.py and dispatcher. Again, avoids list building on every call; (3) Rewrote build_key to take variable args and return tuple. (4) Removed unnecessary call to LRUCache.get() inside LRUCache.put() --- numba_dpex/core/caching.py | 66 +---------------- .../core/kernel_interface/dispatcher.py | 72 ++++++++++++++----- numba_dpex/core/kernel_interface/func.py | 52 ++++++++++++-- 3 files changed, 105 insertions(+), 85 deletions(-) diff --git a/numba_dpex/core/caching.py b/numba_dpex/core/caching.py index ddd7972b17..445107ddfe 100644 --- a/numba_dpex/core/caching.py +++ b/numba_dpex/core/caching.py @@ -2,77 +2,16 @@ # # SPDX-License-Identifier: Apache-2.0 -import hashlib import sys from abc import ABCMeta, abstractmethod from numba.core.caching import CacheImpl, IndexDataCacheFile -from numba.core.serialize import dumps from numba_dpex import config -from numba_dpex.core.types import USMNdArray -def build_key( - argtypes, pyfunc, codegen, backend=None, device_type=None, exec_queue=None -): - """Constructs a key from python function, context, backend, the device - type and execution queue. - - Compute index key for the given argument types and codegen. It includes a - description of the OS, target architecture and hashes of the bytecode for - the function and, if the function has a __closure__, a hash of the - cell_contents.type - - Args: - argtypes : A tuple of numba types corresponding to the arguments to the - compiled function. - pyfunc : The Python function that is to be compiled and cached. - codegen (numba.core.codegen.Codegen): - The codegen object found from the target context. - backend (enum, optional): A 'backend_type' enum. - Defaults to None. - device_type (enum, optional): A 'device_type' enum. - Defaults to None. - exec_queue (dpctl._sycl_queue.SyclQueue', optional): A SYCL queue object. - - Returns: - tuple: A tuple of return type, argtpes, magic_tuple of codegen - and another tuple of hashcodes from bytecode and cell_contents. - """ - - codebytes = pyfunc.__code__.co_code - if pyfunc.__closure__ is not None: - try: - cvars = tuple([x.cell_contents for x in pyfunc.__closure__]) - # Note: cloudpickle serializes a function differently depending - # on how the process is launched; e.g. multiprocessing.Process - cvarbytes = dumps(cvars) - except: - cvarbytes = b"" # a temporary solution for function template - else: - cvarbytes = b"" - - argtylist = list(argtypes) - for i, argty in enumerate(argtylist): - if isinstance(argty, USMNdArray): - # Convert the USMNdArray to an abridged type that disregards the - # usm_type, device, queue, address space attributes. - argtylist[i] = (argty.ndim, argty.dtype, argty.layout) - - argtypes = tuple(argtylist) - - return ( - argtypes, - codegen.magic_tuple(), - backend, - device_type, - exec_queue, - ( - hashlib.sha256(codebytes).hexdigest(), - hashlib.sha256(cvarbytes).hexdigest(), - ), - ) +def build_key(*args): + return tuple(args) class _CacheImpl(CacheImpl): @@ -476,7 +415,6 @@ def put(self, key, value): ) ) self._lookup[key].value = value - self.get(key) return if key in self._evicted: diff --git a/numba_dpex/core/kernel_interface/dispatcher.py b/numba_dpex/core/kernel_interface/dispatcher.py index cc6b61bcaa..de5f7babdb 100644 --- a/numba_dpex/core/kernel_interface/dispatcher.py +++ b/numba_dpex/core/kernel_interface/dispatcher.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 +import hashlib from collections.abc import Iterable from inspect import signature from warnings import warn @@ -10,6 +11,7 @@ import dpctl import dpctl.program as dpctl_prog from numba.core import sigutils +from numba.core.serialize import dumps from numba.core.types import Array as NpArrayType from numba.core.types import void @@ -85,6 +87,8 @@ def __init__( self._global_range = None self._local_range = None + self._func_hash = self._create_func_hash() + # caching related attributes if not config.ENABLE_CACHE: self._cache = NullCache() @@ -143,6 +147,28 @@ def __init__( self._has_specializations = False self._specialization_cache = NullCache() + def _create_func_hash(self): + """Creates a tuple of sha256 hashes out of code and variable bytes extracted from the compiled funtion.""" + + codebytes = self.pyfunc.__code__.co_code + if self.pyfunc.__closure__ is not None: + try: + cvars = tuple( + [x.cell_contents for x in self.pyfunc.__closure__] + ) + # Note: cloudpickle serializes a function differently depending + # on how the process is launched; e.g. multiprocessing.Process + cvarbytes = dumps(cvars) + except: + cvarbytes = b"" # a temporary solution for function template + else: + cvarbytes = b"" + + return ( + hashlib.sha256(codebytes).hexdigest(), + hashlib.sha256(cvarbytes).hexdigest(), + ) + @property def cache(self): return self._cache @@ -151,7 +177,7 @@ def cache(self): def cache_hits(self): return self._cache_hits - def _compile_and_cache(self, argtypes, cache): + def _compile_and_cache(self, argtypes, cache, key=None): """Helper function to compile the Python function or Numba FunctionIR object passed to a JitKernel and store it in an internal cache. """ @@ -171,11 +197,13 @@ def _compile_and_cache(self, argtypes, cache): device_driver_ir_module = kernel.device_driver_ir_module kernel_module_name = kernel.module_name - key = build_key( - tuple(argtypes), - self.pyfunc, - kernel.target_context.codegen(), - ) + if not key: + stripped_argtypes = self._strip_usm_metadata(argtypes) + codegen_magic_tuple = kernel.target_context.codegen().magic_tuple() + key = build_key( + stripped_argtypes, codegen_magic_tuple, self._func_hash + ) + cache.put(key, (device_driver_ir_module, kernel_module_name)) return device_driver_ir_module, kernel_module_name @@ -586,6 +614,20 @@ def _check_ranges(self, device): device=device, ) + def _strip_usm_metadata(self, argtypes): + stripped_argtypes = [] + for argty in argtypes: + if isinstance(argty, USMNdArray): + # Convert the USMNdArray to an abridged type that disregards the + # usm_type, device, queue, address space attributes. + stripped_argtypes.append( + (argty.ndim, argty.dtype, argty.layout) + ) + else: + stripped_argtypes.append(argty) + + return tuple(stripped_argtypes) + def __call__(self, *args): """Functor to launch a kernel.""" @@ -604,12 +646,12 @@ def __call__(self, *args): self.kernel_name, backend, JitKernel._supported_backends ) - # load the kernel from cache - key = build_key( - tuple(argtypes), - self.pyfunc, - dpex_kernel_target.target_context.codegen(), + # Generate key used for cache lookup + stripped_argtypes = self._strip_usm_metadata(argtypes) + codegen_magic_tuple = ( + dpex_kernel_target.target_context.codegen().magic_tuple() ) + key = build_key(stripped_argtypes, codegen_magic_tuple, self._func_hash) # If the JitKernel was specialized then raise exception if argtypes # do not match one of the specialized versions. @@ -630,15 +672,11 @@ def __call__(self, *args): device_driver_ir_module, kernel_module_name, ) = self._compile_and_cache( - argtypes=argtypes, - cache=self._cache, + argtypes=argtypes, cache=self._cache, key=key ) kernel_bundle_key = build_key( - tuple(argtypes), - self.pyfunc, - dpex_kernel_target.target_context.codegen(), - exec_queue=exec_queue, + stripped_argtypes, codegen_magic_tuple, exec_queue, self._func_hash ) artifact = self._kernel_bundle_cache.get(kernel_bundle_key) diff --git a/numba_dpex/core/kernel_interface/func.py b/numba_dpex/core/kernel_interface/func.py index acd9b09bb1..45941358e5 100644 --- a/numba_dpex/core/kernel_interface/func.py +++ b/numba_dpex/core/kernel_interface/func.py @@ -5,14 +5,17 @@ """_summary_ """ +import hashlib from numba.core import sigutils, types +from numba.core.serialize import dumps from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate from numba_dpex import config from numba_dpex.core.caching import LRUCache, NullCache, build_key from numba_dpex.core.compiler import compile_with_dpex from numba_dpex.core.descriptor import dpex_kernel_target +from numba_dpex.core.types import USMNdArray from numba_dpex.utils import npytypes_array_to_dpex_array @@ -91,6 +94,8 @@ def __init__(self, pyfunc, debug=False, enable_cache=True): self._debug = debug self._enable_cache = enable_cache + self._func_hash = self._create_func_hash() + if not config.ENABLE_CACHE: self._cache = NullCache() elif self._enable_cache: @@ -103,6 +108,28 @@ def __init__(self, pyfunc, debug=False, enable_cache=True): self._cache = NullCache() self._cache_hits = 0 + def _create_func_hash(self): + """Creates a tuple of sha256 hashes out of code and variable bytes extracted from the compiled funtion.""" + + codebytes = self._pyfunc.__code__.co_code + if self._pyfunc.__closure__ is not None: + try: + cvars = tuple( + [x.cell_contents for x in self._pyfunc.__closure__] + ) + # Note: cloudpickle serializes a function differently depending + # on how the process is launched; e.g. multiprocessing.Process + cvarbytes = dumps(cvars) + except: + cvarbytes = b"" # a temporary solution for function template + else: + cvarbytes = b"" + + return ( + hashlib.sha256(codebytes).hexdigest(), + hashlib.sha256(cvarbytes).hexdigest(), + ) + @property def cache(self): """Cache accessor""" @@ -113,6 +140,20 @@ def cache_hits(self): """Cache hit count accessor""" return self._cache_hits + def _strip_usm_metadata(self, argtypes): + stripped_argtypes = [] + for argty in argtypes: + if isinstance(argty, USMNdArray): + # Convert the USMNdArray to an abridged type that disregards the + # usm_type, device, queue, address space attributes. + stripped_argtypes.append( + (argty.ndim, argty.dtype, argty.layout) + ) + else: + stripped_argtypes.append(argty) + + return tuple(stripped_argtypes) + def compile(self, args): """Compile a `numba_dpex.func` decorated function @@ -132,11 +173,14 @@ def compile(self, args): dpex_kernel_target.typing_context.resolve_argument_type(arg) for arg in args ] - key = build_key( - tuple(argtypes), - self._pyfunc, - dpex_kernel_target.target_context.codegen(), + + # Generate key used for cache lookup + stripped_argtypes = self._strip_usm_metadata(argtypes) + codegen_magic_tuple = ( + dpex_kernel_target.target_context.codegen().magic_tuple() ) + key = build_key(stripped_argtypes, codegen_magic_tuple, self._func_hash) + cres = self._cache.get(key) if cres is None: self._cache_hits += 1 From 1056a58dd14fc5f9f2fc590efa8facd2c0f4d705 Mon Sep 17 00:00:00 2001 From: Adarsh Yoga Date: Wed, 15 Feb 2023 20:53:48 +0000 Subject: [PATCH 2/2] (1) moving function hash creation, build key and USM metadata stripping functions to a separate cache utils. (2) added docstrings. (3) Replaced get() with explicit logic to update list in LRUCache --- numba_dpex/core/caching.py | 12 ++-- .../core/kernel_interface/dispatcher.py | 51 +++----------- numba_dpex/core/kernel_interface/func.py | 51 +++----------- numba_dpex/core/utils/__init__.py | 4 ++ numba_dpex/core/utils/caching_utils.py | 68 +++++++++++++++++++ 5 files changed, 96 insertions(+), 90 deletions(-) create mode 100644 numba_dpex/core/utils/caching_utils.py diff --git a/numba_dpex/core/caching.py b/numba_dpex/core/caching.py index 445107ddfe..990add9692 100644 --- a/numba_dpex/core/caching.py +++ b/numba_dpex/core/caching.py @@ -10,10 +10,6 @@ from numba_dpex import config -def build_key(*args): - return tuple(args) - - class _CacheImpl(CacheImpl): """Implementation of `CacheImpl` to be used by subclasses of `_Cache`. @@ -414,7 +410,13 @@ def put(self, key, value): self._name, len(self._lookup), str(key) ) ) - self._lookup[key].value = value + node = self._lookup[key] + node.value = value + + if node is not self._tail: + self._unlink_node(node) + self._append_tail(node) + return if key in self._evicted: diff --git a/numba_dpex/core/kernel_interface/dispatcher.py b/numba_dpex/core/kernel_interface/dispatcher.py index de5f7babdb..b1bbef7b63 100644 --- a/numba_dpex/core/kernel_interface/dispatcher.py +++ b/numba_dpex/core/kernel_interface/dispatcher.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 -import hashlib from collections.abc import Iterable from inspect import signature from warnings import warn @@ -11,12 +10,11 @@ import dpctl import dpctl.program as dpctl_prog from numba.core import sigutils -from numba.core.serialize import dumps from numba.core.types import Array as NpArrayType from numba.core.types import void from numba_dpex import NdRange, Range, config -from numba_dpex.core.caching import LRUCache, NullCache, build_key +from numba_dpex.core.caching import LRUCache, NullCache from numba_dpex.core.descriptor import dpex_kernel_target from numba_dpex.core.exceptions import ( ComputeFollowsDataInferenceError, @@ -36,6 +34,11 @@ from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel from numba_dpex.core.types import USMNdArray +from numba_dpex.core.utils import ( + build_key, + create_func_hash, + strip_usm_metadata, +) def get_ordered_arg_access_types(pyfunc, access_types): @@ -87,7 +90,7 @@ def __init__( self._global_range = None self._local_range = None - self._func_hash = self._create_func_hash() + self._func_hash = create_func_hash(pyfunc) # caching related attributes if not config.ENABLE_CACHE: @@ -147,28 +150,6 @@ def __init__( self._has_specializations = False self._specialization_cache = NullCache() - def _create_func_hash(self): - """Creates a tuple of sha256 hashes out of code and variable bytes extracted from the compiled funtion.""" - - codebytes = self.pyfunc.__code__.co_code - if self.pyfunc.__closure__ is not None: - try: - cvars = tuple( - [x.cell_contents for x in self.pyfunc.__closure__] - ) - # Note: cloudpickle serializes a function differently depending - # on how the process is launched; e.g. multiprocessing.Process - cvarbytes = dumps(cvars) - except: - cvarbytes = b"" # a temporary solution for function template - else: - cvarbytes = b"" - - return ( - hashlib.sha256(codebytes).hexdigest(), - hashlib.sha256(cvarbytes).hexdigest(), - ) - @property def cache(self): return self._cache @@ -198,7 +179,7 @@ def _compile_and_cache(self, argtypes, cache, key=None): kernel_module_name = kernel.module_name if not key: - stripped_argtypes = self._strip_usm_metadata(argtypes) + stripped_argtypes = strip_usm_metadata(argtypes) codegen_magic_tuple = kernel.target_context.codegen().magic_tuple() key = build_key( stripped_argtypes, codegen_magic_tuple, self._func_hash @@ -614,20 +595,6 @@ def _check_ranges(self, device): device=device, ) - def _strip_usm_metadata(self, argtypes): - stripped_argtypes = [] - for argty in argtypes: - if isinstance(argty, USMNdArray): - # Convert the USMNdArray to an abridged type that disregards the - # usm_type, device, queue, address space attributes. - stripped_argtypes.append( - (argty.ndim, argty.dtype, argty.layout) - ) - else: - stripped_argtypes.append(argty) - - return tuple(stripped_argtypes) - def __call__(self, *args): """Functor to launch a kernel.""" @@ -647,7 +614,7 @@ def __call__(self, *args): ) # Generate key used for cache lookup - stripped_argtypes = self._strip_usm_metadata(argtypes) + stripped_argtypes = strip_usm_metadata(argtypes) codegen_magic_tuple = ( dpex_kernel_target.target_context.codegen().magic_tuple() ) diff --git a/numba_dpex/core/kernel_interface/func.py b/numba_dpex/core/kernel_interface/func.py index 45941358e5..8537a91742 100644 --- a/numba_dpex/core/kernel_interface/func.py +++ b/numba_dpex/core/kernel_interface/func.py @@ -5,17 +5,18 @@ """_summary_ """ -import hashlib - from numba.core import sigutils, types -from numba.core.serialize import dumps from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate from numba_dpex import config -from numba_dpex.core.caching import LRUCache, NullCache, build_key +from numba_dpex.core.caching import LRUCache, NullCache from numba_dpex.core.compiler import compile_with_dpex from numba_dpex.core.descriptor import dpex_kernel_target -from numba_dpex.core.types import USMNdArray +from numba_dpex.core.utils import ( + build_key, + create_func_hash, + strip_usm_metadata, +) from numba_dpex.utils import npytypes_array_to_dpex_array @@ -94,7 +95,7 @@ def __init__(self, pyfunc, debug=False, enable_cache=True): self._debug = debug self._enable_cache = enable_cache - self._func_hash = self._create_func_hash() + self._func_hash = create_func_hash(pyfunc) if not config.ENABLE_CACHE: self._cache = NullCache() @@ -108,28 +109,6 @@ def __init__(self, pyfunc, debug=False, enable_cache=True): self._cache = NullCache() self._cache_hits = 0 - def _create_func_hash(self): - """Creates a tuple of sha256 hashes out of code and variable bytes extracted from the compiled funtion.""" - - codebytes = self._pyfunc.__code__.co_code - if self._pyfunc.__closure__ is not None: - try: - cvars = tuple( - [x.cell_contents for x in self._pyfunc.__closure__] - ) - # Note: cloudpickle serializes a function differently depending - # on how the process is launched; e.g. multiprocessing.Process - cvarbytes = dumps(cvars) - except: - cvarbytes = b"" # a temporary solution for function template - else: - cvarbytes = b"" - - return ( - hashlib.sha256(codebytes).hexdigest(), - hashlib.sha256(cvarbytes).hexdigest(), - ) - @property def cache(self): """Cache accessor""" @@ -140,20 +119,6 @@ def cache_hits(self): """Cache hit count accessor""" return self._cache_hits - def _strip_usm_metadata(self, argtypes): - stripped_argtypes = [] - for argty in argtypes: - if isinstance(argty, USMNdArray): - # Convert the USMNdArray to an abridged type that disregards the - # usm_type, device, queue, address space attributes. - stripped_argtypes.append( - (argty.ndim, argty.dtype, argty.layout) - ) - else: - stripped_argtypes.append(argty) - - return tuple(stripped_argtypes) - def compile(self, args): """Compile a `numba_dpex.func` decorated function @@ -175,7 +140,7 @@ def compile(self, args): ] # Generate key used for cache lookup - stripped_argtypes = self._strip_usm_metadata(argtypes) + stripped_argtypes = strip_usm_metadata(argtypes) codegen_magic_tuple = ( dpex_kernel_target.target_context.codegen().magic_tuple() ) diff --git a/numba_dpex/core/utils/__init__.py b/numba_dpex/core/utils/__init__.py index 78bf969d57..e736b48ce6 100644 --- a/numba_dpex/core/utils/__init__.py +++ b/numba_dpex/core/utils/__init__.py @@ -2,9 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from .caching_utils import build_key, create_func_hash, strip_usm_metadata from .suai_helper import SyclUSMArrayInterface, get_info_from_suai __all__ = [ "get_info_from_suai", "SyclUSMArrayInterface", + "create_func_hash", + "strip_usm_metadata", + "build_key", ] diff --git a/numba_dpex/core/utils/caching_utils.py b/numba_dpex/core/utils/caching_utils.py new file mode 100644 index 0000000000..b48f460b9b --- /dev/null +++ b/numba_dpex/core/utils/caching_utils.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import hashlib + +from numba.core.serialize import dumps + +from numba_dpex.core.types import USMNdArray + + +def build_key(*args): + """Constructs key from variable list of args + + Args: + *args: List of components to construct key + Return: + Tuple of args + """ + return tuple(args) + + +def create_func_hash(pyfunc): + """Creates a tuple of sha256 hashes out of code and + variable bytes extracted from the compiled funtion. + + Args: + pyfunc: Python function object + Return: + Tuple of hashes of code and variable bytes + """ + codebytes = pyfunc.__code__.co_code + if pyfunc.__closure__ is not None: + try: + cvars = tuple([x.cell_contents for x in pyfunc.__closure__]) + # Note: cloudpickle serializes a function differently depending + # on how the process is launched; e.g. multiprocessing.Process + cvarbytes = dumps(cvars) + except: + cvarbytes = b"" # a temporary solution for function template + else: + cvarbytes = b"" + + return ( + hashlib.sha256(codebytes).hexdigest(), + hashlib.sha256(cvarbytes).hexdigest(), + ) + + +def strip_usm_metadata(argtypes): + """Convert the USMNdArray to an abridged type that disregards the + usm_type, device, queue, address space attributes. + + Args: + argtypes: List of types + + Return: + Tuple of types after removing USM metadata from USMNdArray type + """ + + stripped_argtypes = [] + for argty in argtypes: + if isinstance(argty, USMNdArray): + stripped_argtypes.append((argty.ndim, argty.dtype, argty.layout)) + else: + stripped_argtypes.append(argty) + + return tuple(stripped_argtypes)