Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
* Use lru_cache internally for method_cache
* remove cache kwarg from method cache and use fixed cache name based on the instances type
  • Loading branch information
chriseclectic committed Jan 19, 2023
1 parent e889934 commit 9b9b6ba
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 34 deletions.
59 changes: 26 additions & 33 deletions qiskit_experiments/library/tomography/basis/cache_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,30 @@
Method decorator for caching regular methods in class instances.
"""

from typing import Union, Dict, Callable
from typing import Dict, Callable, Optional
import functools


def cache_method(cache: Union[Dict, str] = "_cache") -> Callable:
def _method_cache_name(instance: any) -> str:
"""Attribute name for storing cache in an instance"""
return "_CACHE_" + type(instance).__name__


def _get_method_cache(instance: any) -> Dict:
"""Return instance cache for cached methods"""
cache_name = _method_cache_name(instance)
try:
return getattr(instance, cache_name)
except AttributeError:
setattr(instance, cache_name, {})
return getattr(instance, cache_name)


def cache_method(maxsize: Optional[int] = None) -> Callable:
"""Decorator for caching class instance methods.
Args:
cache: The cache or cache attribute name to use. If a dict it will
be used directly, if a str a cache dict will be created under
that attribute name if one is not already present.
maxsize: The maximum size of this method's LRU cache.
Returns:
The decorator for caching methods.
Expand All @@ -39,36 +52,16 @@ def cache_method_decorator(method: Callable) -> Callable:
The wrapped cached method.
"""

def _cache_key(*args, **kwargs):
return args + tuple(list(kwargs.items()))

if isinstance(cache, str):

def _get_cache(self):
if not hasattr(self, cache):
setattr(self, cache, {})
return getattr(self, cache)

else:

def _get_cache(_):
return cache

@functools.wraps(method)
def _cached_method(self, *args, **kwargs):
_cache = _get_cache(self)

name = method.__name__
if name not in _cache:
_cache[name] = {}
meth_cache = _cache[name]

key = _cache_key(*args, **kwargs)
if key in meth_cache:
return meth_cache[key]
result = method(self, *args, **kwargs)
meth_cache[key] = result
return result
cache = _get_method_cache(self)
key = method.__name__
try:
meth = cache[key]
except KeyError:
meth = cache[key] = functools.lru_cache(maxsize)(functools.partial(method, self))

return meth(*args, **kwargs)

return _cached_method

Expand Down
14 changes: 13 additions & 1 deletion qiskit_experiments/library/tomography/basis/local_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from qiskit.exceptions import QiskitError

from .base_basis import PreparationBasis, MeasurementBasis
from .cache_method import cache_method
from .cache_method import cache_method, _method_cache_name


# Typing
Expand Down Expand Up @@ -229,6 +229,12 @@ def __json_encode__(self):
value["qubit_states"] = self._qubit_states
return value

def __getstate__(self):
# override get state to skip class cache when pickling
state = self.__dict__.copy()
state.pop(_method_cache_name(self), None)
return state


class LocalMeasurementBasis(MeasurementBasis):
"""Local tensor-product measurement basis.
Expand Down Expand Up @@ -487,6 +493,12 @@ def __json_encode__(self):
value["qubit_povms"] = self._qubit_povms
return value

def __getstate__(self):
# override get state to skip class cache when pickling
state = self.__dict__.copy()
state.pop(_method_cache_name(self), None)
return state


def _tensor_product_circuit(
instructions: Sequence[Instruction],
Expand Down

0 comments on commit 9b9b6ba

Please sign in to comment.