Skip to content

Commit

Permalink
(1)moved function hash computation to kernel and func __init__ from c…
Browse files Browse the repository at this point in the history
…aching.build_key(). This avoids computing hash on every call. (2) moved argtypes list building logic to func.py and dispatcher. Again, avoids list building on every call; (3) Rewrote build_key to take variable args and return tuple. (4) Removed unnecessary call to LRUCache.get() inside LRUCache.put()
  • Loading branch information
adarshyoga committed Feb 15, 2023
1 parent 8d0c8ea commit ab8c9f4
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 85 deletions.
66 changes: 2 additions & 64 deletions numba_dpex/core/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,77 +2,16 @@
#
# SPDX-License-Identifier: Apache-2.0

import hashlib
import sys
from abc import ABCMeta, abstractmethod

from numba.core.caching import CacheImpl, IndexDataCacheFile
from numba.core.serialize import dumps

from numba_dpex import config
from numba_dpex.core.types import USMNdArray


def build_key(
argtypes, pyfunc, codegen, backend=None, device_type=None, exec_queue=None
):
"""Constructs a key from python function, context, backend, the device
type and execution queue.
Compute index key for the given argument types and codegen. It includes a
description of the OS, target architecture and hashes of the bytecode for
the function and, if the function has a __closure__, a hash of the
cell_contents.type
Args:
argtypes : A tuple of numba types corresponding to the arguments to the
compiled function.
pyfunc : The Python function that is to be compiled and cached.
codegen (numba.core.codegen.Codegen):
The codegen object found from the target context.
backend (enum, optional): A 'backend_type' enum.
Defaults to None.
device_type (enum, optional): A 'device_type' enum.
Defaults to None.
exec_queue (dpctl._sycl_queue.SyclQueue', optional): A SYCL queue object.
Returns:
tuple: A tuple of return type, argtpes, magic_tuple of codegen
and another tuple of hashcodes from bytecode and cell_contents.
"""

codebytes = pyfunc.__code__.co_code
if pyfunc.__closure__ is not None:
try:
cvars = tuple([x.cell_contents for x in pyfunc.__closure__])
# Note: cloudpickle serializes a function differently depending
# on how the process is launched; e.g. multiprocessing.Process
cvarbytes = dumps(cvars)
except:
cvarbytes = b"" # a temporary solution for function template
else:
cvarbytes = b""

argtylist = list(argtypes)
for i, argty in enumerate(argtylist):
if isinstance(argty, USMNdArray):
# Convert the USMNdArray to an abridged type that disregards the
# usm_type, device, queue, address space attributes.
argtylist[i] = (argty.ndim, argty.dtype, argty.layout)

argtypes = tuple(argtylist)

return (
argtypes,
codegen.magic_tuple(),
backend,
device_type,
exec_queue,
(
hashlib.sha256(codebytes).hexdigest(),
hashlib.sha256(cvarbytes).hexdigest(),
),
)
def build_key(*args):
return tuple(args)


class _CacheImpl(CacheImpl):
Expand Down Expand Up @@ -476,7 +415,6 @@ def put(self, key, value):
)
)
self._lookup[key].value = value
self.get(key)
return

if key in self._evicted:
Expand Down
72 changes: 55 additions & 17 deletions numba_dpex/core/kernel_interface/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
# SPDX-License-Identifier: Apache-2.0


import hashlib
from collections.abc import Iterable
from inspect import signature
from warnings import warn

import dpctl
import dpctl.program as dpctl_prog
from numba.core import sigutils
from numba.core.serialize import dumps
from numba.core.types import Array as NpArrayType
from numba.core.types import void

Expand Down Expand Up @@ -85,6 +87,8 @@ def __init__(
self._global_range = None
self._local_range = None

self._func_hash = self._create_func_hash()

# caching related attributes
if not config.ENABLE_CACHE:
self._cache = NullCache()
Expand Down Expand Up @@ -143,6 +147,28 @@ def __init__(
self._has_specializations = False
self._specialization_cache = NullCache()

def _create_func_hash(self):
"""Creates a tuple of sha256 hashes out of code and variable bytes extracted from the compiled funtion."""

codebytes = self.pyfunc.__code__.co_code
if self.pyfunc.__closure__ is not None:
try:
cvars = tuple(
[x.cell_contents for x in self.pyfunc.__closure__]
)
# Note: cloudpickle serializes a function differently depending
# on how the process is launched; e.g. multiprocessing.Process
cvarbytes = dumps(cvars)
except:
cvarbytes = b"" # a temporary solution for function template
else:
cvarbytes = b""

return (
hashlib.sha256(codebytes).hexdigest(),
hashlib.sha256(cvarbytes).hexdigest(),
)

@property
def cache(self):
return self._cache
Expand All @@ -151,7 +177,7 @@ def cache(self):
def cache_hits(self):
return self._cache_hits

def _compile_and_cache(self, argtypes, cache):
def _compile_and_cache(self, argtypes, cache, key=None):
"""Helper function to compile the Python function or Numba FunctionIR
object passed to a JitKernel and store it in an internal cache.
"""
Expand All @@ -171,11 +197,13 @@ def _compile_and_cache(self, argtypes, cache):
device_driver_ir_module = kernel.device_driver_ir_module
kernel_module_name = kernel.module_name

key = build_key(
tuple(argtypes),
self.pyfunc,
kernel.target_context.codegen(),
)
if not key:
stripped_argtypes = self._strip_usm_metadata(argtypes)
codegen_magic_tuple = kernel.target_context.codegen().magic_tuple()
key = build_key(
stripped_argtypes, codegen_magic_tuple, self._func_hash
)

cache.put(key, (device_driver_ir_module, kernel_module_name))

return device_driver_ir_module, kernel_module_name
Expand Down Expand Up @@ -586,6 +614,20 @@ def _check_ranges(self, device):
device=device,
)

def _strip_usm_metadata(self, argtypes):
stripped_argtypes = []
for argty in argtypes:
if isinstance(argty, USMNdArray):
# Convert the USMNdArray to an abridged type that disregards the
# usm_type, device, queue, address space attributes.
stripped_argtypes.append(
(argty.ndim, argty.dtype, argty.layout)
)
else:
stripped_argtypes.append(argty)

return tuple(stripped_argtypes)

def __call__(self, *args):
"""Functor to launch a kernel."""

Expand All @@ -604,12 +646,12 @@ def __call__(self, *args):
self.kernel_name, backend, JitKernel._supported_backends
)

# load the kernel from cache
key = build_key(
tuple(argtypes),
self.pyfunc,
dpex_kernel_target.target_context.codegen(),
# Generate key used for cache lookup
stripped_argtypes = self._strip_usm_metadata(argtypes)
codegen_magic_tuple = (
dpex_kernel_target.target_context.codegen().magic_tuple()
)
key = build_key(stripped_argtypes, codegen_magic_tuple, self._func_hash)

# If the JitKernel was specialized then raise exception if argtypes
# do not match one of the specialized versions.
Expand All @@ -630,15 +672,11 @@ def __call__(self, *args):
device_driver_ir_module,
kernel_module_name,
) = self._compile_and_cache(
argtypes=argtypes,
cache=self._cache,
argtypes=argtypes, cache=self._cache, key=key
)

kernel_bundle_key = build_key(
tuple(argtypes),
self.pyfunc,
dpex_kernel_target.target_context.codegen(),
exec_queue=exec_queue,
stripped_argtypes, codegen_magic_tuple, exec_queue, self._func_hash
)

artifact = self._kernel_bundle_cache.get(kernel_bundle_key)
Expand Down
52 changes: 48 additions & 4 deletions numba_dpex/core/kernel_interface/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
"""_summary_
"""

import hashlib

from numba.core import sigutils, types
from numba.core.serialize import dumps
from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate

from numba_dpex import config
from numba_dpex.core.caching import LRUCache, NullCache, build_key
from numba_dpex.core.compiler import compile_with_dpex
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.types import USMNdArray
from numba_dpex.utils import npytypes_array_to_dpex_array


Expand Down Expand Up @@ -91,6 +94,8 @@ def __init__(self, pyfunc, debug=False, enable_cache=True):
self._debug = debug
self._enable_cache = enable_cache

self._func_hash = self._create_func_hash()

if not config.ENABLE_CACHE:
self._cache = NullCache()
elif self._enable_cache:
Expand All @@ -103,6 +108,28 @@ def __init__(self, pyfunc, debug=False, enable_cache=True):
self._cache = NullCache()
self._cache_hits = 0

def _create_func_hash(self):
"""Creates a tuple of sha256 hashes out of code and variable bytes extracted from the compiled funtion."""

codebytes = self._pyfunc.__code__.co_code
if self._pyfunc.__closure__ is not None:
try:
cvars = tuple(
[x.cell_contents for x in self._pyfunc.__closure__]
)
# Note: cloudpickle serializes a function differently depending
# on how the process is launched; e.g. multiprocessing.Process
cvarbytes = dumps(cvars)
except:
cvarbytes = b"" # a temporary solution for function template
else:
cvarbytes = b""

return (
hashlib.sha256(codebytes).hexdigest(),
hashlib.sha256(cvarbytes).hexdigest(),
)

@property
def cache(self):
"""Cache accessor"""
Expand All @@ -113,6 +140,20 @@ def cache_hits(self):
"""Cache hit count accessor"""
return self._cache_hits

def _strip_usm_metadata(self, argtypes):
stripped_argtypes = []
for argty in argtypes:
if isinstance(argty, USMNdArray):
# Convert the USMNdArray to an abridged type that disregards the
# usm_type, device, queue, address space attributes.
stripped_argtypes.append(
(argty.ndim, argty.dtype, argty.layout)
)
else:
stripped_argtypes.append(argty)

return tuple(stripped_argtypes)

def compile(self, args):
"""Compile a `numba_dpex.func` decorated function
Expand All @@ -132,11 +173,14 @@ def compile(self, args):
dpex_kernel_target.typing_context.resolve_argument_type(arg)
for arg in args
]
key = build_key(
tuple(argtypes),
self._pyfunc,
dpex_kernel_target.target_context.codegen(),

# Generate key used for cache lookup
stripped_argtypes = self._strip_usm_metadata(argtypes)
codegen_magic_tuple = (
dpex_kernel_target.target_context.codegen().magic_tuple()
)
key = build_key(stripped_argtypes, codegen_magic_tuple, self._func_hash)

cres = self._cache.get(key)
if cres is None:
self._cache_hits += 1
Expand Down

0 comments on commit ab8c9f4

Please sign in to comment.