From de56ff1659d753c22ac5f7b7b72791a184c85ab4 Mon Sep 17 00:00:00 2001 From: "akmkhale@ansatnuc04" Date: Fri, 17 Feb 2023 10:52:07 -0600 Subject: [PATCH] Adding a rigorous test for caching Breaking strings into 80 cols Rename function --- numba_dpex/tests/kernel_tests/test_caching.py | 96 ++++++++++++++++++- 1 file changed, 95 insertions(+), 1 deletion(-) diff --git a/numba_dpex/tests/kernel_tests/test_caching.py b/numba_dpex/tests/kernel_tests/test_caching.py index 9d54b14479..b7e92838b7 100644 --- a/numba_dpex/tests/kernel_tests/test_caching.py +++ b/numba_dpex/tests/kernel_tests/test_caching.py @@ -2,12 +2,14 @@ # # SPDX-License-Identifier: Apache-2.0 -import dpctl +import string + import dpctl.tensor as dpt import numpy as np import pytest import numba_dpex as dpex +from numba_dpex.core.caching import LRUCache from numba_dpex.core.kernel_interface.dispatcher import ( JitKernel, get_ordered_arg_access_types, @@ -15,6 +17,98 @@ from numba_dpex.tests._helper import filter_strings +def test_LRUcache_operations(): + """Test rigorous caching operations. + + Performs different permutations of caching operations + and check if the state of the cache is correct. + """ + alphabet = list(string.ascii_lowercase) + cache = LRUCache(name="testcache", capacity=4, pyfunc=None) + assert str(cache) == "{}" and cache.head is None and cache.tail is None + + states = [] + for i in range(4): + cache.put(i, alphabet[i]) + tail_key = cache.get(cache.tail.key) + head_key = cache.get(cache.head.key) + states.append((cache, tail_key, head_key)) + assert ( + str(states) + == "[" + + "({(2: c), (1: b), (3: d), (0: a)}, 'a', 'a'), " + + "({(2: c), (1: b), (3: d), (0: a)}, 'b', 'a'), " + + "({(2: c), (1: b), (3: d), (0: a)}, 'c', 'b'), " + + "({(2: c), (1: b), (3: d), (0: a)}, 'd', 'a')" + + "]" + ) + + states = [] + picking_order = [3, 1, 0, 2, 2] + for index in picking_order: + value = cache.get(index) + states.append((value, cache, cache.head, cache.tail)) + assert ( + str(states) + == "[" + + "('d', {(3: d), (1: b), (0: a), (2: c)}, (2: c), (3: d)), " + + "('b', {(3: d), (1: b), (0: a), (2: c)}, (2: c), (1: b)), " + + "('a', {(3: d), (1: b), (0: a), (2: c)}, (2: c), (0: a)), " + + "('c', {(3: d), (1: b), (0: a), (2: c)}, (3: d), (2: c)), " + + "('c', {(3: d), (1: b), (0: a), (2: c)}, (3: d), (2: c))" + + "]" + ) + + states = [] + for i in range(5, 10): + cache.put(i, alphabet[i]) + tail_key = cache.get(cache.tail.key) + head_key = cache.get(cache.head.key) + states.append((cache, tail_key, head_key)) + assert ( + str(states) + == "[" + + "({(8: i), (2: c), (9: j), (1: b)}, 'f', 'b'), " + + "({(8: i), (2: c), (9: j), (1: b)}, 'g', 'c'), " + + "({(8: i), (2: c), (9: j), (1: b)}, 'h', 'b'), " + + "({(8: i), (2: c), (9: j), (1: b)}, 'i', 'c'), " + + "({(8: i), (2: c), (9: j), (1: b)}, 'j', 'b')" + + "]" + ) + assert str(cache.evicted) == "{3: 'd', 0: 'a', 5: 'f', 6: 'g', 7: 'h'}" + + picking_order = [2, 1, 3] + states = [] + for index in picking_order: + value = cache.get(index) + states.append((value, cache, cache.head, cache.tail)) + assert ( + str(states) + == "[" + + "('c', {(9: j), (2: c), (1: b), (3: d)}, (8: i), (2: c)), " + + "('b', {(9: j), (2: c), (1: b), (3: d)}, (8: i), (1: b)), " + + "('d', {(9: j), (2: c), (1: b), (3: d)}, (9: j), (3: d))" + + "]" + ) + assert str(cache.evicted) == "{0: 'a', 5: 'f', 6: 'g', 7: 'h', 8: 'i'}" + + cache.put(0, "x") + assert ( + str(cache) == "{(2: c), (1: b), (3: d), (0: x)}" + and str(cache.head) == "(2: c)" + and str(cache.tail) == "(0: x)" + ) + assert str(cache.evicted) == "{5: 'f', 6: 'g', 7: 'h', 8: 'i', 9: 'j'}" + + cache.put(6, "y") + assert ( + str(cache) == "{(1: b), (3: d), (0: x), (6: y)}" + and str(cache.head) == "(1: b)" + and str(cache.tail) == "(6: y)" + ) + assert str(cache.evicted) == "{5: 'f', 7: 'h', 8: 'i', 9: 'j', 2: 'c'}" + + @pytest.mark.parametrize("filter_str", filter_strings) def test_caching_hit_counts(filter_str): """Tests the correct number of cache hits.