Skip to content

Commit

Permalink
perf(common): improve equality caching by explicitly invalidating the…
Browse files Browse the repository at this point in the history
… entry on `__del__` (#8708)

Previously we used a rather complicated mechanism to implement the global
equality cache for operation nodes, involving tricky weak reference
tracking and registering callbacks to invalidate cache entries.

While this has greatly improved the overall performance of ibis
internals, we can have a simpler and more lightweight implementation by
storing the equality comparison results in a `dict[object_id,
dict[object_id, bool]]` data structure, which allows quick lookups and
quick deletions. The caching is also specialized to a pair of objects, in
contrast to the previous `WeakCache` implementation, which supported an
arbitrary number of key elements and required multiple iterations over the
key tuple.
  • Loading branch information
kszucs authored Mar 21, 2024
1 parent 3d52904 commit ac86f91
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 104 deletions.
37 changes: 16 additions & 21 deletions ibis/common/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from typing import TYPE_CHECKING, Any
from weakref import WeakValueDictionary

from ibis.common.caching import WeakCache

if TYPE_CHECKING:
from collections.abc import Mapping

Expand Down Expand Up @@ -141,41 +139,38 @@ class Comparable(Abstract):
Since the class holds a global cache of comparison results, it is important
to make sure that the instances are not kept alive longer than necessary.
This is done automatically by using weak references for the compared objects.
"""

__cache__ = WeakCache()

def __eq__(self, other) -> bool:
try:
return self.__cached_equals__(other)
except TypeError:
return NotImplemented
__cache__ = {}

@abstractmethod
def __equals__(self, other) -> bool: ...

def __cached_equals__(self, other) -> bool:
def __eq__(self, other) -> bool:
if self is other:
return True

# type comparison should be cheap
if type(self) is not type(other):
return False

# reduce space required for commutative operation
if id(self) < id(other):
key = (self, other)
else:
key = (other, self)

id1 = id(self)
id2 = id(other)
try:
result = self.__cache__[key]
return self.__cache__[id1][id2]
except KeyError:
result = self.__equals__(other)
self.__cache__[key] = result

return result
self.__cache__.setdefault(id1, {})[id2] = result
self.__cache__.setdefault(id2, {})[id1] = result
return result

def __del__(self):
id1 = id(self)
for id2 in self.__cache__.pop(id1, ()):
eqs2 = self.__cache__[id2]
del eqs2[id1]
if not eqs2:
del self.__cache__[id2]


class SlottedMeta(AbstractMeta):
Expand Down
52 changes: 1 addition & 51 deletions ibis/common/caching.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
from __future__ import annotations

import functools
import weakref
from collections import Counter, defaultdict
from collections.abc import MutableMapping
from typing import TYPE_CHECKING, Any, Callable
from typing import Any, Callable

from bidict import bidict

from ibis.common.exceptions import IbisError

if TYPE_CHECKING:
from collections.abc import Iterator


def memoize(func: Callable) -> Callable:
"""Memoize a function."""
Expand All @@ -31,51 +26,6 @@ def wrapper(*args, **kwargs):
return wrapper


class WeakCache(MutableMapping):
__slots__ = ("_data",)
_data: dict

def __init__(self):
object.__setattr__(self, "_data", {})

def __setattr__(self, name, value):
raise TypeError(f"can't set {name}")

def __len__(self) -> int:
return len(self._data)

def __iter__(self) -> Iterator[Any]:
return iter(self._data)

def __setitem__(self, key, value) -> None:
# construct an alternative representation of the key using the id()
# of the key's components, this prevents infinite recursions
identifiers = tuple(id(item) for item in key)

# create a function which removes the key from the cache
def callback(ref_):
return self._data.pop(identifiers, None)

# create weak references for the key's components with the callback
# to remove the cache entry if any of the key's components gets
# garbage collected
refs = tuple(weakref.ref(item, callback) for item in key)

self._data[identifiers] = (value, refs)

def __getitem__(self, key):
identifiers = tuple(id(item) for item in key)
value, _ = self._data[identifiers]
return value

def __delitem__(self, key):
identifiers = tuple(id(item) for item in key)
del self._data[identifiers]

def __repr__(self):
return f"{self.__class__.__name__}({self._data})"


class RefCountedCache:
"""A cache with reference-counted keys.
Expand Down
1 change: 1 addition & 0 deletions ibis/common/grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
AbstractMeta,
Comparable,
Final,
Hashable,
Immutable,
Singleton,
)
Expand Down
84 changes: 59 additions & 25 deletions ibis/common/tests/test_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Singleton,
Slotted,
)
from ibis.common.caching import WeakCache


def test_classes_are_based_on_abstract():
Expand Down Expand Up @@ -79,9 +78,19 @@ def __init__(self, a, b):
assert copy.deepcopy(foo) is foo


class Cache(dict):
def setpair(self, a, b, value):
a, b = id(a), id(b)
self.setdefault(a, {})[b] = value
self.setdefault(b, {})[a] = value

def getpair(self, a, b):
return self.get(id(a), {}).get(id(b))


class Node(Comparable):
# override the default cache object
__cache__ = WeakCache()
__cache__ = Cache()
__slots__ = ("name",)
num_equal_calls = 0

Expand All @@ -107,14 +116,6 @@ def cache():
assert not cache


def pair(a, b):
# for same ordering with comparable
if id(a) < id(b):
return (a, b)
else:
return (b, a)


def test_comparable_basic(cache):
a = Node(name="a")
b = Node(name="a")
Expand All @@ -133,28 +134,48 @@ def test_comparable_caching(cache):
d = Node(name="d")
e = Node(name="e")

cache[pair(a, b)] = True
cache[pair(a, c)] = False
cache[pair(c, d)] = True
cache[pair(b, d)] = False
assert len(cache) == 4
cache.setpair(a, b, True)
cache.setpair(a, c, False)
cache.setpair(c, d, True)
cache.setpair(b, d, False)
expected = {
id(a): {id(b): True, id(c): False},
id(b): {id(a): True, id(d): False},
id(c): {id(a): False, id(d): True},
id(d): {id(c): True, id(b): False},
}
assert cache == expected

assert a == b
assert b == a
assert a != c
assert c != a
assert c == d
assert d == c
assert b != d
assert d != b
assert Node.num_equal_calls == 0
assert cache == expected

# no cache hit
assert pair(a, e) not in cache
assert cache.getpair(a, e) is None
assert a != e
assert cache.getpair(a, e) is False
assert Node.num_equal_calls == 1
assert len(cache) == 5
expected = {
id(a): {id(b): True, id(c): False, id(e): False},
id(b): {id(a): True, id(d): False},
id(c): {id(a): False, id(d): True},
id(d): {id(c): True, id(b): False},
id(e): {id(a): False},
}
assert cache == expected

# run only once
assert e != a
assert Node.num_equal_calls == 1
assert pair(a, e) in cache
assert cache.getpair(a, e) is False
assert cache == expected


def test_comparable_garbage_collection(cache):
Expand All @@ -163,16 +184,29 @@ def test_comparable_garbage_collection(cache):
c = Node(name="c")
d = Node(name="d")

cache[pair(a, b)] = True
cache[pair(a, c)] = False
cache[pair(c, d)] = True
cache[pair(b, d)] = False
cache.setpair(a, b, True)
cache.setpair(a, c, False)
cache.setpair(c, d, True)
cache.setpair(b, d, False)

assert weakref.getweakrefcount(a) == 2
assert cache.getpair(a, c) is False
assert cache.getpair(c, d) is True
del c
assert weakref.getweakrefcount(a) == 1
assert cache == {
id(a): {id(b): True},
id(b): {id(a): True, id(d): False},
id(d): {id(b): False},
}

assert cache.getpair(a, b) is True
assert cache.getpair(b, d) is False
del b
assert weakref.getweakrefcount(a) == 0
assert cache == {}

assert a != d
assert cache == {id(a): {id(d): False}, id(d): {id(a): False}}
del a
assert cache == {}


def test_comparable_cache_reuse(cache):
Expand Down
15 changes: 13 additions & 2 deletions ibis/common/tests/test_graph_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Optional
from typing import Any, Optional

import pytest
from typing_extensions import Self
Expand All @@ -19,9 +19,10 @@ class MyNode(Concrete, Node):
d: frozendict[str, int]
e: Optional[Self] = None
f: tuple[Self, ...] = ()
g: Any = None


def generate_node(depth):
def generate_node(depth, g=None):
# generate a nested node object with the given depth
if depth == 0:
return MyNode(10, "20", c=(30, 40), d=frozendict(e=50, f=60))
Expand All @@ -32,6 +33,7 @@ def generate_node(depth):
d=frozendict(e=5, f=6),
e=generate_node(0),
f=(generate_node(depth - 1), generate_node(0)),
g=g,
)


Expand Down Expand Up @@ -62,3 +64,12 @@ def test_replace_mapping(benchmark):
node = generate_node(500)
subs = {generate_node(1): generate_node(0)}
benchmark(node.replace, subs)


def test_equality_caching(benchmark):
node = generate_node(150)
other = generate_node(150)
assert node == other
assert other == node
assert node is not other
benchmark.pedantic(node.__eq__, args=[other], iterations=100, rounds=200)
2 changes: 0 additions & 2 deletions ibis/common/tests/test_grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,7 @@ def __equals__(self, other):
assert a != c
assert c != a
assert a.__equals__(b)
assert a.__cached_equals__(b)
assert not a.__equals__(c)
assert not a.__cached_equals__(c)


def test_maintain_definition_order():
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def equals(self, other):
raise TypeError(
f"invalid equality comparison between DataType and {type(other)}"
)
return super().__cached_equals__(other)
return self == other

def cast(self, other, **kwargs):
# TODO(kszucs): remove it or deprecate it?
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/operations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def equals(self, other) -> bool:
raise TypeError(
f"invalid equality comparison between Node and {type(other)}"
)
return self.__cached_equals__(other)
return self == other

# Avoid custom repr for performance reasons
__repr__ = object.__repr__
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def equals(self, other: Schema) -> bool:
raise TypeError(
f"invalid equality comparison between Schema and {type(other)}"
)
return self.__cached_equals__(other)
return self == other

@classmethod
def from_tuples(
Expand Down

0 comments on commit ac86f91

Please sign in to comment.