Skip to content

Commit

Permalink
Merge pull request #918 from chudur-budur/ref/caching.tester
Browse files Browse the repository at this point in the history
Adding a rigorous test for caching
  • Loading branch information
diptorupd authored Feb 17, 2023
2 parents f991464 + de56ff1 commit a0a8871
Showing 1 changed file with 95 additions and 1 deletion.
96 changes: 95 additions & 1 deletion numba_dpex/tests/kernel_tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,113 @@
#
# SPDX-License-Identifier: Apache-2.0

import dpctl
import string

import dpctl.tensor as dpt
import numpy as np
import pytest

import numba_dpex as dpex
from numba_dpex.core.caching import LRUCache
from numba_dpex.core.kernel_interface.dispatcher import (
JitKernel,
get_ordered_arg_access_types,
)
from numba_dpex.tests._helper import filter_strings


def test_LRUcache_operations():
"""Test rigorous caching operations.
Performs different permutations of caching operations
and check if the state of the cache is correct.
"""
alphabet = list(string.ascii_lowercase)
cache = LRUCache(name="testcache", capacity=4, pyfunc=None)
assert str(cache) == "{}" and cache.head is None and cache.tail is None

states = []
for i in range(4):
cache.put(i, alphabet[i])
tail_key = cache.get(cache.tail.key)
head_key = cache.get(cache.head.key)
states.append((cache, tail_key, head_key))
assert (
str(states)
== "["
+ "({(2: c), (1: b), (3: d), (0: a)}, 'a', 'a'), "
+ "({(2: c), (1: b), (3: d), (0: a)}, 'b', 'a'), "
+ "({(2: c), (1: b), (3: d), (0: a)}, 'c', 'b'), "
+ "({(2: c), (1: b), (3: d), (0: a)}, 'd', 'a')"
+ "]"
)

states = []
picking_order = [3, 1, 0, 2, 2]
for index in picking_order:
value = cache.get(index)
states.append((value, cache, cache.head, cache.tail))
assert (
str(states)
== "["
+ "('d', {(3: d), (1: b), (0: a), (2: c)}, (2: c), (3: d)), "
+ "('b', {(3: d), (1: b), (0: a), (2: c)}, (2: c), (1: b)), "
+ "('a', {(3: d), (1: b), (0: a), (2: c)}, (2: c), (0: a)), "
+ "('c', {(3: d), (1: b), (0: a), (2: c)}, (3: d), (2: c)), "
+ "('c', {(3: d), (1: b), (0: a), (2: c)}, (3: d), (2: c))"
+ "]"
)

states = []
for i in range(5, 10):
cache.put(i, alphabet[i])
tail_key = cache.get(cache.tail.key)
head_key = cache.get(cache.head.key)
states.append((cache, tail_key, head_key))
assert (
str(states)
== "["
+ "({(8: i), (2: c), (9: j), (1: b)}, 'f', 'b'), "
+ "({(8: i), (2: c), (9: j), (1: b)}, 'g', 'c'), "
+ "({(8: i), (2: c), (9: j), (1: b)}, 'h', 'b'), "
+ "({(8: i), (2: c), (9: j), (1: b)}, 'i', 'c'), "
+ "({(8: i), (2: c), (9: j), (1: b)}, 'j', 'b')"
+ "]"
)
assert str(cache.evicted) == "{3: 'd', 0: 'a', 5: 'f', 6: 'g', 7: 'h'}"

picking_order = [2, 1, 3]
states = []
for index in picking_order:
value = cache.get(index)
states.append((value, cache, cache.head, cache.tail))
assert (
str(states)
== "["
+ "('c', {(9: j), (2: c), (1: b), (3: d)}, (8: i), (2: c)), "
+ "('b', {(9: j), (2: c), (1: b), (3: d)}, (8: i), (1: b)), "
+ "('d', {(9: j), (2: c), (1: b), (3: d)}, (9: j), (3: d))"
+ "]"
)
assert str(cache.evicted) == "{0: 'a', 5: 'f', 6: 'g', 7: 'h', 8: 'i'}"

cache.put(0, "x")
assert (
str(cache) == "{(2: c), (1: b), (3: d), (0: x)}"
and str(cache.head) == "(2: c)"
and str(cache.tail) == "(0: x)"
)
assert str(cache.evicted) == "{5: 'f', 6: 'g', 7: 'h', 8: 'i', 9: 'j'}"

cache.put(6, "y")
assert (
str(cache) == "{(1: b), (3: d), (0: x), (6: y)}"
and str(cache.head) == "(1: b)"
and str(cache.tail) == "(6: y)"
)
assert str(cache.evicted) == "{5: 'f', 7: 'h', 8: 'i', 9: 'j', 2: 'c'}"


@pytest.mark.parametrize("filter_str", filter_strings)
def test_caching_hit_counts(filter_str):
"""Tests the correct number of cache hits.
Expand Down

0 comments on commit a0a8871

Please sign in to comment.