Skip to content

Commit

Permalink
Fix #159: Pass self to @cachedmethod key function.
Browse files Browse the repository at this point in the history
  • Loading branch information
tkem committed Dec 21, 2021
1 parent d4b1569 commit 749a2c5
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 9 deletions.
14 changes: 9 additions & 5 deletions src/cachetools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
import random
import time

from .keys import hashkey
from .keys import hashkey as _defaultkey


def _methodkey(_, *args, **kwargs):
return _defaultkey(*args, **kwargs)


class _DefaultSize:
Expand Down Expand Up @@ -615,7 +619,7 @@ def __getitem(self, key):
return value


def cached(cache, key=hashkey, lock=None):
def cached(cache, key=_defaultkey, lock=None):
"""Decorator to wrap a function with a memoizing callable that saves
results in a cache.
Expand Down Expand Up @@ -664,7 +668,7 @@ def wrapper(*args, **kwargs):
return decorator


def cachedmethod(cache, key=hashkey, lock=None):
def cachedmethod(cache, key=_methodkey, lock=None):
"""Decorator to wrap a class or instance method with a memoizing
callable that saves results in a cache.
Expand All @@ -677,7 +681,7 @@ def wrapper(self, *args, **kwargs):
c = cache(self)
if c is None:
return method(self, *args, **kwargs)
k = key(*args, **kwargs)
k = key(self, *args, **kwargs)
try:
return c[k]
except KeyError:
Expand All @@ -695,7 +699,7 @@ def wrapper(self, *args, **kwargs):
c = cache(self)
if c is None:
return method(self, *args, **kwargs)
k = key(*args, **kwargs)
k = key(self, *args, **kwargs)
try:
with lock(self):
return c[k]
Expand Down
30 changes: 26 additions & 4 deletions tests/test_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ def get_typed(self, value):
self.count += 1
return count

# https://github.com/tkem/cachetools/issues/107
def __hash__(self):
raise TypeError("unhashable type")


class Locked:
def __init__(self, cache):
Expand All @@ -42,6 +38,23 @@ def __exit__(self, *exc):
pass


class Unhashable:
def __init__(self, cache):
self.cache = cache

@cachedmethod(operator.attrgetter("cache"))
def get_default(self, value):
return value

@cachedmethod(operator.attrgetter("cache"), key=keys.hashkey)
def get_hashkey(self, value):
return value

# https://github.com/tkem/cachetools/issues/107
def __hash__(self):
raise TypeError("unhashable type")


class CachedMethodTest(unittest.TestCase):
def test_dict(self):
cached = Cached({})
Expand Down Expand Up @@ -163,3 +176,12 @@ def test_locked_nospace(self):
self.assertEqual(cached.get(1), 5)
self.assertEqual(cached.get(1.0), 7)
self.assertEqual(cached.get(1.0), 9)

def test_unhashable(self):
cached = Unhashable(LRUCache(maxsize=0))

self.assertEqual(cached.get_default(0), 0)
self.assertEqual(cached.get_default(1), 1)

with self.assertRaises(TypeError):
cached.get_hashkey(0)

0 comments on commit 749a2c5

Please sign in to comment.