From 749a2c578b36cccf2c6c3366d3bfe2dd803142e5 Mon Sep 17 00:00:00 2001 From: Thomas Kemmer Date: Sun, 19 Dec 2021 21:45:34 +0100 Subject: [PATCH] Fix #159: Pass self to @cachedmethod key function. --- src/cachetools/__init__.py | 14 +++++++++----- tests/test_method.py | 30 ++++++++++++++++++++++++++---- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/cachetools/__init__.py b/src/cachetools/__init__.py index d8b4922..478186c 100644 --- a/src/cachetools/__init__.py +++ b/src/cachetools/__init__.py @@ -22,7 +22,11 @@ import random import time -from .keys import hashkey +from .keys import hashkey as _defaultkey + + +def _methodkey(_, *args, **kwargs): + return _defaultkey(*args, **kwargs) class _DefaultSize: @@ -615,7 +619,7 @@ def __getitem(self, key): return value -def cached(cache, key=hashkey, lock=None): +def cached(cache, key=_defaultkey, lock=None): """Decorator to wrap a function with a memoizing callable that saves results in a cache. @@ -664,7 +668,7 @@ def wrapper(*args, **kwargs): return decorator -def cachedmethod(cache, key=hashkey, lock=None): +def cachedmethod(cache, key=_methodkey, lock=None): """Decorator to wrap a class or instance method with a memoizing callable that saves results in a cache. @@ -677,7 +681,7 @@ def wrapper(self, *args, **kwargs): c = cache(self) if c is None: return method(self, *args, **kwargs) - k = key(*args, **kwargs) + k = key(self, *args, **kwargs) try: return c[k] except KeyError: @@ -695,7 +699,7 @@ def wrapper(self, *args, **kwargs): c = cache(self) if c is None: return method(self, *args, **kwargs) - k = key(*args, **kwargs) + k = key(self, *args, **kwargs) try: with lock(self): return c[k] diff --git a/tests/test_method.py b/tests/test_method.py index b41dac0..44dfcaf 100644 --- a/tests/test_method.py +++ b/tests/test_method.py @@ -21,10 +21,6 @@ def get_typed(self, value): self.count += 1 return count - # https://github.com/tkem/cachetools/issues/107 - def __hash__(self): - raise TypeError("unhashable type") - class Locked: def __init__(self, cache): @@ -42,6 +38,23 @@ def __exit__(self, *exc): pass +class Unhashable: + def __init__(self, cache): + self.cache = cache + + @cachedmethod(operator.attrgetter("cache")) + def get_default(self, value): + return value + + @cachedmethod(operator.attrgetter("cache"), key=keys.hashkey) + def get_hashkey(self, value): + return value + + # https://github.com/tkem/cachetools/issues/107 + def __hash__(self): + raise TypeError("unhashable type") + + class CachedMethodTest(unittest.TestCase): def test_dict(self): cached = Cached({}) @@ -163,3 +176,12 @@ def test_locked_nospace(self): self.assertEqual(cached.get(1), 5) self.assertEqual(cached.get(1.0), 7) self.assertEqual(cached.get(1.0), 9) + + def test_unhashable(self): + cached = Unhashable(LRUCache(maxsize=0)) + + self.assertEqual(cached.get_default(0), 0) + self.assertEqual(cached.get_default(1), 1) + + with self.assertRaises(TypeError): + cached.get_hashkey(0)