From a06f8f086fdc0414a5b1fba075d46e89363fc3da Mon Sep 17 00:00:00 2001 From: "akmkhale@ansatnuc04" Date: Thu, 2 Feb 2023 16:49:38 -0600 Subject: [PATCH] Caching kernel_bundle after create_program_from_spirv() --- .pre-commit-config.yaml | 2 +- .../core/kernel_interface/dispatcher.py | 29 +++++++++++++++---- numba_dpex/tests/test_device_array_args.py | 4 +-- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9663207243..d2139c4d48 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: - id: blacken-docs additional_dependencies: [black==22.10] - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort name: isort (python) diff --git a/numba_dpex/core/kernel_interface/dispatcher.py b/numba_dpex/core/kernel_interface/dispatcher.py index feb6ba6bbc..bd9924e4d1 100644 --- a/numba_dpex/core/kernel_interface/dispatcher.py +++ b/numba_dpex/core/kernel_interface/dispatcher.py @@ -95,6 +95,11 @@ def __init__( capacity=config.CACHE_SIZE, pyfunc=self.pyfunc, ) + self._kernel_bundle_cache = LRUCache( + name="KernelBundleCache", + capacity=config.CACHE_SIZE, + pyfunc=self.pyfunc, + ) else: self._cache = NullCache() self._cache_hits = 0 @@ -587,6 +592,7 @@ def __call__(self, *args): # redundant. We should avoid these checks for the specialized case. exec_queue = self._determine_kernel_launch_queue(args, argtypes) backend = exec_queue.backend + device = exec_queue.sycl_device if exec_queue.backend not in [ dpctl.backend_type.opencl, @@ -626,12 +632,25 @@ def __call__(self, *args): cache=self._cache, ) - # create a sycl::KernelBundle - kernel_bundle = dpctl_prog.create_program_from_spirv( - exec_queue, - device_driver_ir_module, - " ".join(self._create_sycl_kernel_bundle_flags), + kernel_bundle_key = build_key( + tuple(argtypes), + self.pyfunc, + dpex_kernel_target.target_context.codegen(), + backend=backend, + device_type=device.device_type, ) + + kernel_bundle = self._kernel_bundle_cache.get(kernel_bundle_key) + + if kernel_bundle is None: + # create a sycl::KernelBundle + kernel_bundle = dpctl_prog.create_program_from_spirv( + exec_queue, + device_driver_ir_module, + " ".join(self._create_sycl_kernel_bundle_flags), + ) + self._kernel_bundle_cache.put(kernel_bundle_key, kernel_bundle) + # get the sycl::kernel sycl_kernel = kernel_bundle.get_sycl_kernel(kernel_module_name) diff --git a/numba_dpex/tests/test_device_array_args.py b/numba_dpex/tests/test_device_array_args.py index cc50c48854..80f4c19fee 100644 --- a/numba_dpex/tests/test_device_array_args.py +++ b/numba_dpex/tests/test_device_array_args.py @@ -26,7 +26,7 @@ def data_parallel_sum(a, b, c): @skip_no_opencl_cpu -class TestArrayArgsGPU: +class TestArrayArgsCPU: def test_device_array_args_cpu(self): c = np.ones_like(a) @@ -37,7 +37,7 @@ def test_device_array_args_cpu(self): @skip_no_opencl_gpu -class TestArrayArgsCPU: +class TestArrayArgsGPU: def test_device_array_args_gpu(self): c = np.ones_like(a)