Skip to content

Commit

Permalink
perf(common): improve equality caching by explicitly invalidating the…
Browse files Browse the repository at this point in the history
… entry on `__del__`
  • Loading branch information
kszucs committed Mar 21, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent becdca9 commit 029dd20
Showing 9 changed files with 93 additions and 104 deletions.
37 changes: 16 additions & 21 deletions ibis/common/bases.py
Original file line number Diff line number Diff line change
@@ -5,8 +5,6 @@
from typing import TYPE_CHECKING, Any
from weakref import WeakValueDictionary

from ibis.common.caching import WeakCache

if TYPE_CHECKING:
from collections.abc import Mapping

@@ -141,41 +139,38 @@ class Comparable(Abstract):
Since the class holds a global cache of comparison results, it is important
to make sure that the instances are not kept alive longer than necessary.
This is done automatically by using weak references for the compared objects.
"""

__cache__ = WeakCache()

def __eq__(self, other) -> bool:
try:
return self.__cached_equals__(other)
except TypeError:
return NotImplemented
__cache__ = {}

@abstractmethod
def __equals__(self, other) -> bool: ...

def __cached_equals__(self, other) -> bool:
def __eq__(self, other) -> bool:
if self is other:
return True

# type comparison should be cheap
if type(self) is not type(other):
return False

# reduce space required for commutative operation
if id(self) < id(other):
key = (self, other)
else:
key = (other, self)

id1 = id(self)
id2 = id(other)
try:
result = self.__cache__[key]
return self.__cache__[id1][id2]
except KeyError:
result = self.__equals__(other)
self.__cache__[key] = result

return result
self.__cache__.setdefault(id1, {})[id2] = result
self.__cache__.setdefault(id2, {})[id1] = result
return result

def __del__(self):
id1 = id(self)
for id2 in self.__cache__.pop(id1, ()):
eqs2 = self.__cache__[id2]
del eqs2[id1]
if not eqs2:
del self.__cache__[id2]


class SlottedMeta(AbstractMeta):
52 changes: 1 addition & 51 deletions ibis/common/caching.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
from __future__ import annotations

import functools
import weakref
from collections import Counter, defaultdict
from collections.abc import MutableMapping
from typing import TYPE_CHECKING, Any, Callable
from typing import Any, Callable

from bidict import bidict

from ibis.common.exceptions import IbisError

if TYPE_CHECKING:
from collections.abc import Iterator


def memoize(func: Callable) -> Callable:
"""Memoize a function."""
@@ -31,51 +26,6 @@ def wrapper(*args, **kwargs):
return wrapper


class WeakCache(MutableMapping):
__slots__ = ("_data",)
_data: dict

def __init__(self):
object.__setattr__(self, "_data", {})

def __setattr__(self, name, value):
raise TypeError(f"can't set {name}")

def __len__(self) -> int:
return len(self._data)

def __iter__(self) -> Iterator[Any]:
return iter(self._data)

def __setitem__(self, key, value) -> None:
# construct an alternative representation of the key using the id()
# of the key's components, this prevents infinite recursions
identifiers = tuple(id(item) for item in key)

# create a function which removes the key from the cache
def callback(ref_):
return self._data.pop(identifiers, None)

# create weak references for the key's components with the callback
# to remove the cache entry if any of the key's components gets
# garbage collected
refs = tuple(weakref.ref(item, callback) for item in key)

self._data[identifiers] = (value, refs)

def __getitem__(self, key):
identifiers = tuple(id(item) for item in key)
value, _ = self._data[identifiers]
return value

def __delitem__(self, key):
identifiers = tuple(id(item) for item in key)
del self._data[identifiers]

def __repr__(self):
return f"{self.__class__.__name__}({self._data})"


class RefCountedCache:
"""A cache with reference-counted keys.
1 change: 1 addition & 0 deletions ibis/common/grounds.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
AbstractMeta,
Comparable,
Final,
Hashable,
Immutable,
Singleton,
)
84 changes: 59 additions & 25 deletions ibis/common/tests/test_bases.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,6 @@
Singleton,
Slotted,
)
from ibis.common.caching import WeakCache


def test_classes_are_based_on_abstract():
@@ -79,9 +78,19 @@ def __init__(self, a, b):
assert copy.deepcopy(foo) is foo


class Cache(dict):
def setpair(self, a, b, value):
a, b = id(a), id(b)
self.setdefault(a, {})[b] = value
self.setdefault(b, {})[a] = value

def getpair(self, a, b):
return self.get(id(a), {}).get(id(b))


class Node(Comparable):
# override the default cache object
__cache__ = WeakCache()
__cache__ = Cache()
__slots__ = ("name",)
num_equal_calls = 0

@@ -107,14 +116,6 @@ def cache():
assert not cache


def pair(a, b):
# for same ordering with comparable
if id(a) < id(b):
return (a, b)
else:
return (b, a)


def test_comparable_basic(cache):
a = Node(name="a")
b = Node(name="a")
@@ -133,28 +134,48 @@ def test_comparable_caching(cache):
d = Node(name="d")
e = Node(name="e")

cache[pair(a, b)] = True
cache[pair(a, c)] = False
cache[pair(c, d)] = True
cache[pair(b, d)] = False
assert len(cache) == 4
cache.setpair(a, b, True)
cache.setpair(a, c, False)
cache.setpair(c, d, True)
cache.setpair(b, d, False)
expected = {
id(a): {id(b): True, id(c): False},
id(b): {id(a): True, id(d): False},
id(c): {id(a): False, id(d): True},
id(d): {id(c): True, id(b): False},
}
assert cache == expected

assert a == b
assert b == a
assert a != c
assert c != a
assert c == d
assert d == c
assert b != d
assert d != b
assert Node.num_equal_calls == 0
assert cache == expected

# no cache hit
assert pair(a, e) not in cache
assert cache.getpair(a, e) is None
assert a != e
assert cache.getpair(a, e) is False
assert Node.num_equal_calls == 1
assert len(cache) == 5
expected = {
id(a): {id(b): True, id(c): False, id(e): False},
id(b): {id(a): True, id(d): False},
id(c): {id(a): False, id(d): True},
id(d): {id(c): True, id(b): False},
id(e): {id(a): False},
}
assert cache == expected

# run only once
assert e != a
assert Node.num_equal_calls == 1
assert pair(a, e) in cache
assert cache.getpair(a, e) is False
assert cache == expected


def test_comparable_garbage_collection(cache):
@@ -163,16 +184,29 @@ def test_comparable_garbage_collection(cache):
c = Node(name="c")
d = Node(name="d")

cache[pair(a, b)] = True
cache[pair(a, c)] = False
cache[pair(c, d)] = True
cache[pair(b, d)] = False
cache.setpair(a, b, True)
cache.setpair(a, c, False)
cache.setpair(c, d, True)
cache.setpair(b, d, False)

assert weakref.getweakrefcount(a) == 2
assert cache.getpair(a, c) is False
assert cache.getpair(c, d) is True
del c
assert weakref.getweakrefcount(a) == 1
assert cache == {
id(a): {id(b): True},
id(b): {id(a): True, id(d): False},
id(d): {id(b): False},
}

assert cache.getpair(a, b) is True
assert cache.getpair(b, d) is False
del b
assert weakref.getweakrefcount(a) == 0
assert cache == {}

assert a != d
assert cache == {id(a): {id(d): False}, id(d): {id(a): False}}
del a
assert cache == {}


def test_comparable_cache_reuse(cache):
15 changes: 13 additions & 2 deletions ibis/common/tests/test_graph_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Optional
from typing import Any, Optional

import pytest
from typing_extensions import Self
@@ -19,9 +19,10 @@ class MyNode(Concrete, Node):
d: frozendict[str, int]
e: Optional[Self] = None
f: tuple[Self, ...] = ()
g: Any = None


def generate_node(depth):
def generate_node(depth, g=None):
# generate a nested node object with the given depth
if depth == 0:
return MyNode(10, "20", c=(30, 40), d=frozendict(e=50, f=60))
@@ -32,6 +33,7 @@ def generate_node(depth):
d=frozendict(e=5, f=6),
e=generate_node(0),
f=(generate_node(depth - 1), generate_node(0)),
g=g,
)


@@ -62,3 +64,12 @@ def test_replace_mapping(benchmark):
node = generate_node(500)
subs = {generate_node(1): generate_node(0)}
benchmark(node.replace, subs)


def test_equality_caching(benchmark):
node = generate_node(150)
other = generate_node(150)
assert node == other
assert other == node
assert node is not other
benchmark.pedantic(node.__eq__, args=[other], iterations=100, rounds=200)
2 changes: 0 additions & 2 deletions ibis/common/tests/test_grounds.py
Original file line number Diff line number Diff line change
@@ -417,9 +417,7 @@ def __equals__(self, other):
assert a != c
assert c != a
assert a.__equals__(b)
assert a.__cached_equals__(b)
assert not a.__equals__(c)
assert not a.__cached_equals__(c)


def test_maintain_definition_order():
2 changes: 1 addition & 1 deletion ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
@@ -140,7 +140,7 @@ def equals(self, other):
raise TypeError(
f"invalid equality comparison between DataType and {type(other)}"
)
return super().__cached_equals__(other)
return self == other

def cast(self, other, **kwargs):
# TODO(kszucs): remove it or deprecate it?
2 changes: 1 addition & 1 deletion ibis/expr/operations/core.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ def equals(self, other) -> bool:
raise TypeError(
f"invalid equality comparison between Node and {type(other)}"
)
return self.__cached_equals__(other)
return self == other

# Avoid custom repr for performance reasons
__repr__ = object.__repr__
2 changes: 1 addition & 1 deletion ibis/expr/schema.py
Original file line number Diff line number Diff line change
@@ -89,7 +89,7 @@ def equals(self, other: Schema) -> bool:
raise TypeError(
f"invalid equality comparison between Schema and {type(other)}"
)
return self.__cached_equals__(other)
return self == other

@classmethod
def from_tuples(

0 comments on commit 029dd20

Please sign in to comment.