Skip to content

Commit

Permalink
(1) moving function hash creation, build key and USM metadata strippi…
Browse files Browse the repository at this point in the history
…ng functions to a separate cache utils. (2) added docstrings. (3) Replaced get() with explicit logic to update list in LRUCache
  • Loading branch information
adarshyoga committed Feb 15, 2023
1 parent ab8c9f4 commit 66d3a82
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 90 deletions.
12 changes: 7 additions & 5 deletions numba_dpex/core/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
from numba_dpex import config


def build_key(*args):
return tuple(args)


class _CacheImpl(CacheImpl):
"""Implementation of `CacheImpl` to be used by subclasses of `_Cache`.
Expand Down Expand Up @@ -414,7 +410,13 @@ def put(self, key, value):
self._name, len(self._lookup), str(key)
)
)
self._lookup[key].value = value
node = self._lookup[key]
node.value = value

if node is not self._tail:
self._unlink_node(node)
self._append_tail(node)

return

if key in self._evicted:
Expand Down
51 changes: 9 additions & 42 deletions numba_dpex/core/kernel_interface/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,18 @@
# SPDX-License-Identifier: Apache-2.0


import hashlib
from collections.abc import Iterable
from inspect import signature
from warnings import warn

import dpctl
import dpctl.program as dpctl_prog
from numba.core import sigutils
from numba.core.serialize import dumps
from numba.core.types import Array as NpArrayType
from numba.core.types import void

from numba_dpex import NdRange, Range, config
from numba_dpex.core.caching import LRUCache, NullCache, build_key
from numba_dpex.core.caching import LRUCache, NullCache
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.exceptions import (
ComputeFollowsDataInferenceError,
Expand All @@ -36,6 +34,11 @@
from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer
from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel
from numba_dpex.core.types import USMNdArray
from numba_dpex.core.utils import (
build_key,
create_func_hash,
strip_usm_metadata,
)


def get_ordered_arg_access_types(pyfunc, access_types):
Expand Down Expand Up @@ -87,7 +90,7 @@ def __init__(
self._global_range = None
self._local_range = None

self._func_hash = self._create_func_hash()
self._func_hash = create_func_hash(pyfunc)

# caching related attributes
if not config.ENABLE_CACHE:
Expand Down Expand Up @@ -147,28 +150,6 @@ def __init__(
self._has_specializations = False
self._specialization_cache = NullCache()

def _create_func_hash(self):
"""Creates a tuple of sha256 hashes out of code and variable bytes extracted from the compiled funtion."""

codebytes = self.pyfunc.__code__.co_code
if self.pyfunc.__closure__ is not None:
try:
cvars = tuple(
[x.cell_contents for x in self.pyfunc.__closure__]
)
# Note: cloudpickle serializes a function differently depending
# on how the process is launched; e.g. multiprocessing.Process
cvarbytes = dumps(cvars)
except:
cvarbytes = b"" # a temporary solution for function template
else:
cvarbytes = b""

return (
hashlib.sha256(codebytes).hexdigest(),
hashlib.sha256(cvarbytes).hexdigest(),
)

@property
def cache(self):
return self._cache
Expand Down Expand Up @@ -198,7 +179,7 @@ def _compile_and_cache(self, argtypes, cache, key=None):
kernel_module_name = kernel.module_name

if not key:
stripped_argtypes = self._strip_usm_metadata(argtypes)
stripped_argtypes = strip_usm_metadata(argtypes)
codegen_magic_tuple = kernel.target_context.codegen().magic_tuple()
key = build_key(
stripped_argtypes, codegen_magic_tuple, self._func_hash
Expand Down Expand Up @@ -614,20 +595,6 @@ def _check_ranges(self, device):
device=device,
)

def _strip_usm_metadata(self, argtypes):
stripped_argtypes = []
for argty in argtypes:
if isinstance(argty, USMNdArray):
# Convert the USMNdArray to an abridged type that disregards the
# usm_type, device, queue, address space attributes.
stripped_argtypes.append(
(argty.ndim, argty.dtype, argty.layout)
)
else:
stripped_argtypes.append(argty)

return tuple(stripped_argtypes)

def __call__(self, *args):
"""Functor to launch a kernel."""

Expand All @@ -647,7 +614,7 @@ def __call__(self, *args):
)

# Generate key used for cache lookup
stripped_argtypes = self._strip_usm_metadata(argtypes)
stripped_argtypes = strip_usm_metadata(argtypes)
codegen_magic_tuple = (
dpex_kernel_target.target_context.codegen().magic_tuple()
)
Expand Down
51 changes: 8 additions & 43 deletions numba_dpex/core/kernel_interface/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
"""_summary_
"""

import hashlib

from numba.core import sigutils, types
from numba.core.serialize import dumps
from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate

from numba_dpex import config
from numba_dpex.core.caching import LRUCache, NullCache, build_key
from numba_dpex.core.caching import LRUCache, NullCache
from numba_dpex.core.compiler import compile_with_dpex
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.types import USMNdArray
from numba_dpex.core.utils import (
build_key,
create_func_hash,
strip_usm_metadata,
)
from numba_dpex.utils import npytypes_array_to_dpex_array


Expand Down Expand Up @@ -94,7 +95,7 @@ def __init__(self, pyfunc, debug=False, enable_cache=True):
self._debug = debug
self._enable_cache = enable_cache

self._func_hash = self._create_func_hash()
self._func_hash = create_func_hash(pyfunc)

if not config.ENABLE_CACHE:
self._cache = NullCache()
Expand All @@ -108,28 +109,6 @@ def __init__(self, pyfunc, debug=False, enable_cache=True):
self._cache = NullCache()
self._cache_hits = 0

def _create_func_hash(self):
"""Creates a tuple of sha256 hashes out of code and variable bytes extracted from the compiled funtion."""

codebytes = self._pyfunc.__code__.co_code
if self._pyfunc.__closure__ is not None:
try:
cvars = tuple(
[x.cell_contents for x in self._pyfunc.__closure__]
)
# Note: cloudpickle serializes a function differently depending
# on how the process is launched; e.g. multiprocessing.Process
cvarbytes = dumps(cvars)
except:
cvarbytes = b"" # a temporary solution for function template
else:
cvarbytes = b""

return (
hashlib.sha256(codebytes).hexdigest(),
hashlib.sha256(cvarbytes).hexdigest(),
)

@property
def cache(self):
"""Cache accessor"""
Expand All @@ -140,20 +119,6 @@ def cache_hits(self):
"""Cache hit count accessor"""
return self._cache_hits

def _strip_usm_metadata(self, argtypes):
stripped_argtypes = []
for argty in argtypes:
if isinstance(argty, USMNdArray):
# Convert the USMNdArray to an abridged type that disregards the
# usm_type, device, queue, address space attributes.
stripped_argtypes.append(
(argty.ndim, argty.dtype, argty.layout)
)
else:
stripped_argtypes.append(argty)

return tuple(stripped_argtypes)

def compile(self, args):
"""Compile a `numba_dpex.func` decorated function
Expand All @@ -175,7 +140,7 @@ def compile(self, args):
]

# Generate key used for cache lookup
stripped_argtypes = self._strip_usm_metadata(argtypes)
stripped_argtypes = strip_usm_metadata(argtypes)
codegen_magic_tuple = (
dpex_kernel_target.target_context.codegen().magic_tuple()
)
Expand Down
4 changes: 4 additions & 0 deletions numba_dpex/core/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0

from .caching_utils import build_key, create_func_hash, strip_usm_metadata
from .suai_helper import SyclUSMArrayInterface, get_info_from_suai

__all__ = [
"get_info_from_suai",
"SyclUSMArrayInterface",
"create_func_hash",
"strip_usm_metadata",
"build_key",
]
68 changes: 68 additions & 0 deletions numba_dpex/core/utils/caching_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import hashlib

from numba.core.serialize import dumps

from numba_dpex.core.types import USMNdArray


def build_key(*args):
"""Constructs key from variable list of args
Args:
*args: List of components to construct key
Return:
Tuple of args
"""
return tuple(args)


def create_func_hash(pyfunc):
"""Creates a tuple of sha256 hashes out of code and
variable bytes extracted from the compiled funtion.
Args:
pyfunc: Python function object
Return:
Tuple of hashes of code and variable bytes
"""
codebytes = pyfunc.__code__.co_code
if pyfunc.__closure__ is not None:
try:
cvars = tuple([x.cell_contents for x in pyfunc.__closure__])
# Note: cloudpickle serializes a function differently depending
# on how the process is launched; e.g. multiprocessing.Process
cvarbytes = dumps(cvars)
except:
cvarbytes = b"" # a temporary solution for function template
else:
cvarbytes = b""

return (
hashlib.sha256(codebytes).hexdigest(),
hashlib.sha256(cvarbytes).hexdigest(),
)


def strip_usm_metadata(argtypes):
"""Convert the USMNdArray to an abridged type that disregards the
usm_type, device, queue, address space attributes.
Args:
argtypes: List of types
Return:
Tuple of types after removing USM metadata from USMNdArray type
"""

stripped_argtypes = []
for argty in argtypes:
if isinstance(argty, USMNdArray):
stripped_argtypes.append((argty.ndim, argty.dtype, argty.layout))
else:
stripped_argtypes.append(argty)

return tuple(stripped_argtypes)

0 comments on commit 66d3a82

Please sign in to comment.