diff --git a/numba_dpex/core/caching.py b/numba_dpex/core/caching.py index 445107ddfe..990add9692 100644 --- a/numba_dpex/core/caching.py +++ b/numba_dpex/core/caching.py @@ -10,10 +10,6 @@ from numba_dpex import config -def build_key(*args): - return tuple(args) - - class _CacheImpl(CacheImpl): """Implementation of `CacheImpl` to be used by subclasses of `_Cache`. @@ -414,7 +410,13 @@ def put(self, key, value): self._name, len(self._lookup), str(key) ) ) - self._lookup[key].value = value + node = self._lookup[key] + node.value = value + + if node is not self._tail: + self._unlink_node(node) + self._append_tail(node) + return if key in self._evicted: diff --git a/numba_dpex/core/kernel_interface/dispatcher.py b/numba_dpex/core/kernel_interface/dispatcher.py index de5f7babdb..b1bbef7b63 100644 --- a/numba_dpex/core/kernel_interface/dispatcher.py +++ b/numba_dpex/core/kernel_interface/dispatcher.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 -import hashlib from collections.abc import Iterable from inspect import signature from warnings import warn @@ -11,12 +10,11 @@ import dpctl import dpctl.program as dpctl_prog from numba.core import sigutils -from numba.core.serialize import dumps from numba.core.types import Array as NpArrayType from numba.core.types import void from numba_dpex import NdRange, Range, config -from numba_dpex.core.caching import LRUCache, NullCache, build_key +from numba_dpex.core.caching import LRUCache, NullCache from numba_dpex.core.descriptor import dpex_kernel_target from numba_dpex.core.exceptions import ( ComputeFollowsDataInferenceError, @@ -36,6 +34,11 @@ from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel from numba_dpex.core.types import USMNdArray +from numba_dpex.core.utils import ( + build_key, + create_func_hash, + strip_usm_metadata, +) def get_ordered_arg_access_types(pyfunc, access_types): @@ -87,7 +90,7 @@ def __init__( self._global_range = None self._local_range = None - self._func_hash = self._create_func_hash() + self._func_hash = create_func_hash(pyfunc) # caching related attributes if not config.ENABLE_CACHE: @@ -147,28 +150,6 @@ def __init__( self._has_specializations = False self._specialization_cache = NullCache() - def _create_func_hash(self): - """Creates a tuple of sha256 hashes out of code and variable bytes extracted from the compiled funtion.""" - - codebytes = self.pyfunc.__code__.co_code - if self.pyfunc.__closure__ is not None: - try: - cvars = tuple( - [x.cell_contents for x in self.pyfunc.__closure__] - ) - # Note: cloudpickle serializes a function differently depending - # on how the process is launched; e.g. multiprocessing.Process - cvarbytes = dumps(cvars) - except: - cvarbytes = b"" # a temporary solution for function template - else: - cvarbytes = b"" - - return ( - hashlib.sha256(codebytes).hexdigest(), - hashlib.sha256(cvarbytes).hexdigest(), - ) - @property def cache(self): return self._cache @@ -198,7 +179,7 @@ def _compile_and_cache(self, argtypes, cache, key=None): kernel_module_name = kernel.module_name if not key: - stripped_argtypes = self._strip_usm_metadata(argtypes) + stripped_argtypes = strip_usm_metadata(argtypes) codegen_magic_tuple = kernel.target_context.codegen().magic_tuple() key = build_key( stripped_argtypes, codegen_magic_tuple, self._func_hash @@ -614,20 +595,6 @@ def _check_ranges(self, device): device=device, ) - def _strip_usm_metadata(self, argtypes): - stripped_argtypes = [] - for argty in argtypes: - if isinstance(argty, USMNdArray): - # Convert the USMNdArray to an abridged type that disregards the - # usm_type, device, queue, address space attributes. - stripped_argtypes.append( - (argty.ndim, argty.dtype, argty.layout) - ) - else: - stripped_argtypes.append(argty) - - return tuple(stripped_argtypes) - def __call__(self, *args): """Functor to launch a kernel.""" @@ -647,7 +614,7 @@ def __call__(self, *args): ) # Generate key used for cache lookup - stripped_argtypes = self._strip_usm_metadata(argtypes) + stripped_argtypes = strip_usm_metadata(argtypes) codegen_magic_tuple = ( dpex_kernel_target.target_context.codegen().magic_tuple() ) diff --git a/numba_dpex/core/kernel_interface/func.py b/numba_dpex/core/kernel_interface/func.py index 45941358e5..8537a91742 100644 --- a/numba_dpex/core/kernel_interface/func.py +++ b/numba_dpex/core/kernel_interface/func.py @@ -5,17 +5,18 @@ """_summary_ """ -import hashlib - from numba.core import sigutils, types -from numba.core.serialize import dumps from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate from numba_dpex import config -from numba_dpex.core.caching import LRUCache, NullCache, build_key +from numba_dpex.core.caching import LRUCache, NullCache from numba_dpex.core.compiler import compile_with_dpex from numba_dpex.core.descriptor import dpex_kernel_target -from numba_dpex.core.types import USMNdArray +from numba_dpex.core.utils import ( + build_key, + create_func_hash, + strip_usm_metadata, +) from numba_dpex.utils import npytypes_array_to_dpex_array @@ -94,7 +95,7 @@ def __init__(self, pyfunc, debug=False, enable_cache=True): self._debug = debug self._enable_cache = enable_cache - self._func_hash = self._create_func_hash() + self._func_hash = create_func_hash(pyfunc) if not config.ENABLE_CACHE: self._cache = NullCache() @@ -108,28 +109,6 @@ def __init__(self, pyfunc, debug=False, enable_cache=True): self._cache = NullCache() self._cache_hits = 0 - def _create_func_hash(self): - """Creates a tuple of sha256 hashes out of code and variable bytes extracted from the compiled funtion.""" - - codebytes = self._pyfunc.__code__.co_code - if self._pyfunc.__closure__ is not None: - try: - cvars = tuple( - [x.cell_contents for x in self._pyfunc.__closure__] - ) - # Note: cloudpickle serializes a function differently depending - # on how the process is launched; e.g. multiprocessing.Process - cvarbytes = dumps(cvars) - except: - cvarbytes = b"" # a temporary solution for function template - else: - cvarbytes = b"" - - return ( - hashlib.sha256(codebytes).hexdigest(), - hashlib.sha256(cvarbytes).hexdigest(), - ) - @property def cache(self): """Cache accessor""" @@ -140,20 +119,6 @@ def cache_hits(self): """Cache hit count accessor""" return self._cache_hits - def _strip_usm_metadata(self, argtypes): - stripped_argtypes = [] - for argty in argtypes: - if isinstance(argty, USMNdArray): - # Convert the USMNdArray to an abridged type that disregards the - # usm_type, device, queue, address space attributes. - stripped_argtypes.append( - (argty.ndim, argty.dtype, argty.layout) - ) - else: - stripped_argtypes.append(argty) - - return tuple(stripped_argtypes) - def compile(self, args): """Compile a `numba_dpex.func` decorated function @@ -175,7 +140,7 @@ def compile(self, args): ] # Generate key used for cache lookup - stripped_argtypes = self._strip_usm_metadata(argtypes) + stripped_argtypes = strip_usm_metadata(argtypes) codegen_magic_tuple = ( dpex_kernel_target.target_context.codegen().magic_tuple() ) diff --git a/numba_dpex/core/utils/__init__.py b/numba_dpex/core/utils/__init__.py index 78bf969d57..e736b48ce6 100644 --- a/numba_dpex/core/utils/__init__.py +++ b/numba_dpex/core/utils/__init__.py @@ -2,9 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from .caching_utils import build_key, create_func_hash, strip_usm_metadata from .suai_helper import SyclUSMArrayInterface, get_info_from_suai __all__ = [ "get_info_from_suai", "SyclUSMArrayInterface", + "create_func_hash", + "strip_usm_metadata", + "build_key", ] diff --git a/numba_dpex/core/utils/caching_utils.py b/numba_dpex/core/utils/caching_utils.py new file mode 100644 index 0000000000..b48f460b9b --- /dev/null +++ b/numba_dpex/core/utils/caching_utils.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import hashlib + +from numba.core.serialize import dumps + +from numba_dpex.core.types import USMNdArray + + +def build_key(*args): + """Constructs key from variable list of args + + Args: + *args: List of components to construct key + Return: + Tuple of args + """ + return tuple(args) + + +def create_func_hash(pyfunc): + """Creates a tuple of sha256 hashes out of code and + variable bytes extracted from the compiled funtion. + + Args: + pyfunc: Python function object + Return: + Tuple of hashes of code and variable bytes + """ + codebytes = pyfunc.__code__.co_code + if pyfunc.__closure__ is not None: + try: + cvars = tuple([x.cell_contents for x in pyfunc.__closure__]) + # Note: cloudpickle serializes a function differently depending + # on how the process is launched; e.g. multiprocessing.Process + cvarbytes = dumps(cvars) + except: + cvarbytes = b"" # a temporary solution for function template + else: + cvarbytes = b"" + + return ( + hashlib.sha256(codebytes).hexdigest(), + hashlib.sha256(cvarbytes).hexdigest(), + ) + + +def strip_usm_metadata(argtypes): + """Convert the USMNdArray to an abridged type that disregards the + usm_type, device, queue, address space attributes. + + Args: + argtypes: List of types + + Return: + Tuple of types after removing USM metadata from USMNdArray type + """ + + stripped_argtypes = [] + for argty in argtypes: + if isinstance(argty, USMNdArray): + stripped_argtypes.append((argty.ndim, argty.dtype, argty.layout)) + else: + stripped_argtypes.append(argty) + + return tuple(stripped_argtypes)