Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dispatcher/caching rewrite to address performance regression #912

Merged
merged 2 commits into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 7 additions & 67 deletions numba_dpex/core/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,77 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0

import hashlib
import sys
from abc import ABCMeta, abstractmethod

from numba.core.caching import CacheImpl, IndexDataCacheFile
from numba.core.serialize import dumps

from numba_dpex import config
from numba_dpex.core.types import USMNdArray


def build_key(
    argtypes, pyfunc, codegen, backend=None, device_type=None, exec_queue=None
):
    """Construct a cache key for a compiled Python function.

    The key combines the abridged argument types, the codegen's magic tuple,
    the optional backend, device type and execution queue, and SHA-256
    hashes of the function's bytecode and of its serialized closure
    variables.

    Args:
        argtypes: An iterable of numba types corresponding to the arguments
            to the compiled function. ``USMNdArray`` entries are replaced by
            ``(ndim, dtype, layout)`` so that the usm_type, device, queue and
            address space attributes do not take part in the key.
        pyfunc: The Python function that is to be compiled and cached.
        codegen (numba.core.codegen.Codegen):
            The codegen object found from the target context; only its
            ``magic_tuple()`` is used.
        backend (enum, optional): A 'backend_type' enum.
            Defaults to None.
        device_type (enum, optional): A 'device_type' enum.
            Defaults to None.
        exec_queue (dpctl._sycl_queue.SyclQueue, optional): A SYCL queue
            object. Defaults to None.

    Returns:
        tuple: ``(argtypes, codegen.magic_tuple(), backend, device_type,
        exec_queue, (code_hash, closure_hash))`` where the last element holds
        the hex digests of the bytecode and of the closure variables.
    """
    codebytes = pyfunc.__code__.co_code

    cvarbytes = b""
    if pyfunc.__closure__ is not None:
        try:
            cvars = tuple(x.cell_contents for x in pyfunc.__closure__)
            # Note: cloudpickle serializes a function differently depending
            # on how the process is launched; e.g. multiprocessing.Process
            cvarbytes = dumps(cvars)
        except Exception:
            # Best effort: closure contents that cannot be read or serialized
            # (a temporary solution for function templates) hash as empty.
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt and SystemExit still propagate.
            cvarbytes = b""

    # Convert each USMNdArray to an abridged type that disregards the
    # usm_type, device, queue, address space attributes.
    argtypes = tuple(
        (argty.ndim, argty.dtype, argty.layout)
        if isinstance(argty, USMNdArray)
        else argty
        for argty in argtypes
    )

    return (
        argtypes,
        codegen.magic_tuple(),
        backend,
        device_type,
        exec_queue,
        (
            hashlib.sha256(codebytes).hexdigest(),
            hashlib.sha256(cvarbytes).hexdigest(),
        ),
    )


class _CacheImpl(CacheImpl):
Expand Down Expand Up @@ -475,8 +410,13 @@ def put(self, key, value):
self._name, len(self._lookup), str(key)
)
)
self._lookup[key].value = value
self.get(key)
chudur-budur marked this conversation as resolved.
Show resolved Hide resolved
node = self._lookup[key]
node.value = value

if node is not self._tail:
self._unlink_node(node)
self._append_tail(node)

return

if key in self._evicted:
Expand Down
41 changes: 23 additions & 18 deletions numba_dpex/core/kernel_interface/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from numba.core.types import void

from numba_dpex import NdRange, Range, config
from numba_dpex.core.caching import LRUCache, NullCache, build_key
from numba_dpex.core.caching import LRUCache, NullCache
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.exceptions import (
ComputeFollowsDataInferenceError,
Expand All @@ -34,6 +34,11 @@
from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer
from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel
from numba_dpex.core.types import USMNdArray
from numba_dpex.core.utils import (
build_key,
create_func_hash,
strip_usm_metadata,
)


def get_ordered_arg_access_types(pyfunc, access_types):
Expand Down Expand Up @@ -85,6 +90,8 @@ def __init__(
self._global_range = None
self._local_range = None

self._func_hash = create_func_hash(pyfunc)

# caching related attributes
if not config.ENABLE_CACHE:
self._cache = NullCache()
Expand Down Expand Up @@ -151,7 +158,7 @@ def cache(self):
def cache_hits(self):
return self._cache_hits

def _compile_and_cache(self, argtypes, cache):
def _compile_and_cache(self, argtypes, cache, key=None):
"""Helper function to compile the Python function or Numba FunctionIR
object passed to a JitKernel and store it in an internal cache.
"""
Expand All @@ -171,11 +178,13 @@ def _compile_and_cache(self, argtypes, cache):
device_driver_ir_module = kernel.device_driver_ir_module
kernel_module_name = kernel.module_name

key = build_key(
tuple(argtypes),
self.pyfunc,
kernel.target_context.codegen(),
)
if not key:
stripped_argtypes = strip_usm_metadata(argtypes)
codegen_magic_tuple = kernel.target_context.codegen().magic_tuple()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I compared with 0.19.0; even this extra hashing of the codegen_magic_tuple may be adding overhead. Can you try removing it and just adding the kernel to the key, like we had in 0.19?

key = build_key(
stripped_argtypes, codegen_magic_tuple, self._func_hash
)

cache.put(key, (device_driver_ir_module, kernel_module_name))

return device_driver_ir_module, kernel_module_name
Expand Down Expand Up @@ -604,12 +613,12 @@ def __call__(self, *args):
self.kernel_name, backend, JitKernel._supported_backends
)

# load the kernel from cache
key = build_key(
tuple(argtypes),
self.pyfunc,
dpex_kernel_target.target_context.codegen(),
# Generate key used for cache lookup
stripped_argtypes = strip_usm_metadata(argtypes)
codegen_magic_tuple = (
dpex_kernel_target.target_context.codegen().magic_tuple()
)
key = build_key(stripped_argtypes, codegen_magic_tuple, self._func_hash)

# If the JitKernel was specialized then raise exception if argtypes
# do not match one of the specialized versions.
Expand All @@ -630,15 +639,11 @@ def __call__(self, *args):
device_driver_ir_module,
kernel_module_name,
) = self._compile_and_cache(
argtypes=argtypes,
cache=self._cache,
argtypes=argtypes, cache=self._cache, key=key
)

kernel_bundle_key = build_key(
tuple(argtypes),
self.pyfunc,
dpex_kernel_target.target_context.codegen(),
exec_queue=exec_queue,
stripped_argtypes, codegen_magic_tuple, exec_queue, self._func_hash
)

artifact = self._kernel_bundle_cache.get(kernel_bundle_key)
Expand Down
21 changes: 15 additions & 6 deletions numba_dpex/core/kernel_interface/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
"""_summary_
"""


from numba.core import sigutils, types
from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate

from numba_dpex import config
from numba_dpex.core.caching import LRUCache, NullCache, build_key
from numba_dpex.core.caching import LRUCache, NullCache
from numba_dpex.core.compiler import compile_with_dpex
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.utils import (
build_key,
create_func_hash,
strip_usm_metadata,
)
from numba_dpex.utils import npytypes_array_to_dpex_array


Expand Down Expand Up @@ -91,6 +95,8 @@ def __init__(self, pyfunc, debug=False, enable_cache=True):
self._debug = debug
self._enable_cache = enable_cache

self._func_hash = create_func_hash(pyfunc)

if not config.ENABLE_CACHE:
self._cache = NullCache()
elif self._enable_cache:
Expand Down Expand Up @@ -132,11 +138,14 @@ def compile(self, args):
dpex_kernel_target.typing_context.resolve_argument_type(arg)
for arg in args
]
key = build_key(
tuple(argtypes),
self._pyfunc,
dpex_kernel_target.target_context.codegen(),

# Generate key used for cache lookup
stripped_argtypes = strip_usm_metadata(argtypes)
codegen_magic_tuple = (
dpex_kernel_target.target_context.codegen().magic_tuple()
)
key = build_key(stripped_argtypes, codegen_magic_tuple, self._func_hash)

cres = self._cache.get(key)
if cres is None:
self._cache_hits += 1
Expand Down
4 changes: 4 additions & 0 deletions numba_dpex/core/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0

from .caching_utils import build_key, create_func_hash, strip_usm_metadata
from .suai_helper import SyclUSMArrayInterface, get_info_from_suai

__all__ = [
"get_info_from_suai",
"SyclUSMArrayInterface",
"create_func_hash",
"strip_usm_metadata",
"build_key",
]
68 changes: 68 additions & 0 deletions numba_dpex/core/utils/caching_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import hashlib

from numba.core.serialize import dumps

from numba_dpex.core.types import USMNdArray


def build_key(*args):
    """Bundle the given components into a single key.

    Args:
        *args: Components that make up the key, in order.

    Returns:
        A tuple holding the components in the order they were passed.
    """
    # ``*args`` is already collected into a tuple by the call machinery.
    key = args
    return key


def create_func_hash(pyfunc):
    """Create SHA-256 hashes identifying a Python function.

    Two digests are computed: one over the function's bytecode
    (``__code__.co_code``) and one over the serialized contents of its
    closure cells. A function without a closure, or whose closure contents
    cannot be read or serialized, gets the digest of the empty byte string
    for the second element.

    Args:
        pyfunc: Python function object.

    Returns:
        A 2-tuple ``(code_hash, closure_hash)`` of hex digest strings.
    """
    codebytes = pyfunc.__code__.co_code

    cvarbytes = b""
    if pyfunc.__closure__ is not None:
        try:
            cvars = tuple(x.cell_contents for x in pyfunc.__closure__)
            # Note: cloudpickle serializes a function differently depending
            # on how the process is launched; e.g. multiprocessing.Process
            cvarbytes = dumps(cvars)
        except Exception:
            # Best effort (a temporary solution for function templates):
            # fall back to the empty byte string. ``Exception`` rather than a
            # bare ``except`` so KeyboardInterrupt/SystemExit still propagate.
            cvarbytes = b""

    return (
        hashlib.sha256(codebytes).hexdigest(),
        hashlib.sha256(cvarbytes).hexdigest(),
    )


def strip_usm_metadata(argtypes):
    """Replace every USMNdArray type with an abridged description.

    Each ``USMNdArray`` entry becomes the tuple ``(ndim, dtype, layout)``,
    which disregards the usm_type, device, queue and address space
    attributes. All other entries are passed through unchanged.

    Args:
        argtypes: Iterable of types.

    Returns:
        Tuple of types with the USM metadata removed from USMNdArray entries.
    """
    return tuple(
        (ty.ndim, ty.dtype, ty.layout) if isinstance(ty, USMNdArray) else ty
        for ty in argtypes
    )