Skip to content

Commit

Permalink
refactor(cache): factor out ref counted cache
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Mar 22, 2023
1 parent e5df790 commit c816f00
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 64 deletions.
47 changes: 17 additions & 30 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import re
import sys
import urllib.parse
from collections import Counter
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand All @@ -21,14 +20,13 @@
MutableMapping,
)

from bidict import MutableBidirectionalMapping, bidict

import ibis
import ibis.common.exceptions as exc
import ibis.config
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import util
from ibis.common.caching import RefCountedCache

if TYPE_CHECKING:
import pandas as pd
Expand Down Expand Up @@ -460,11 +458,13 @@ def __init__(self, *args, **kwargs):
self._con_args: tuple[Any] = args
self._con_kwargs: dict[str, Any] = kwargs
# expression cache
self._query_cache: MutableBidirectionalMapping[
ops.TableNode, ops.PhysicalTable
] = bidict()

self._refs = Counter()
self._query_cache = RefCountedCache(
populate=self._load_into_cache,
lookup=lambda name: self.table(name).op(),
finalize=self._clean_up_cached_table,
generate_name=functools.partial(util.generate_unique_table_name, "cache"),
key=lambda expr: expr.op(),
)

def __getstate__(self):
return dict(
Expand Down Expand Up @@ -902,7 +902,7 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:
f"{cls.name} backend has not implemented `has_operation` API"
)

def _cached(self, expr):
def _cached(self, expr: ir.Table):
"""Cache the provided expression.
All subsequent operations on the returned expression will be performed on the cached data.
Expand All @@ -920,38 +920,25 @@ def _cached(self, expr):
"""
op = expr.op()
if (result := self._query_cache.get(op)) is None:
name = util.generate_unique_table_name("cache")
self._load_into_cache(name, expr)
self._query_cache[op] = result = self.table(name).op()
self._refs[op] += 1
self._query_cache.store(expr)
result = self._query_cache[op]
return ir.CachedTable(result)

def _release_cached(self, expr):
def _release_cached(self, expr: ir.CachedTable) -> None:
"""Releases the provided cached expression.
Parameters
----------
expr
Cached expression to release
"""
op = expr.op()
# we need to remove the expression representing the temp table as well
# as the expression that was used to create the temp table
#
# bidict automatically handles this for us; without it we'd have to
# do the bookkeeping ourselves with two dicts
if (key := self._query_cache.inverse.get(op)) is None:
raise exc.IbisError(
"This expression has already been released. Did you call "
"`.release()` twice on the same expression?"
)
del self._query_cache[expr.op()]

self._refs[key] -= 1
def _load_into_cache(self, name, expr):
raise NotImplementedError(self.name)

if not self._refs[key]:
del self._query_cache[key]
del self._refs[key]
self._clean_up_cached_table(op)
def _clean_up_cached_table(self, op):
raise NotImplementedError(self.name)


@functools.lru_cache(maxsize=None)
Expand Down
3 changes: 0 additions & 3 deletions ibis/backends/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,3 @@ def _convert_object(cls, obj: dd.DataFrame) -> dd.DataFrame:

def _load_into_cache(self, name, expr):
self.create_table(name, self.compile(expr).persist())

def _clean_up_cached_table(self, op):
del self.dictionary[op.name]
11 changes: 5 additions & 6 deletions ibis/backends/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:
issubclass(operation, op_impl) for op_impl in op_classes
)

def _clean_up_cached_table(self, op):
del self.dictionary[op.name]


class Backend(BasePandasBackend):
name = 'pandas'
Expand Down Expand Up @@ -303,9 +306,5 @@ def execute(self, query, params=None, limit='default', **kwargs):

return execute_and_reset(node, params=params, **kwargs)

def _cached(self, expr):
"""No-op. The expression is already in memory."""
return ir.CachedTable(expr.op())

def _release_cached(self, _):
"""No-op."""
def _load_into_cache(self, name, expr):
self.create_table(name, expr.execute())
46 changes: 22 additions & 24 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,21 +1093,16 @@ def test_create_table_timestamp(con):
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@mark.never(
["pandas"],
raises=AssertionError,
reason="pandas doesn't cache anything and therefore does no ref counting",
)
def test_persist_expression_ref_count(con, alltypes):
non_persisted_table = alltypes.mutate(test_column="calculation")
persisted_table = non_persisted_table.cache()

op = non_persisted_table.op()

# ref count is unaffected without a context manager
assert con._refs[op] == 1
assert con._query_cache.refs[op] == 1
tm.assert_frame_equal(non_persisted_table.to_pandas(), persisted_table.to_pandas())
assert con._refs[op] == 1
assert con._query_cache.refs[op] == 1


@mark.notimpl(
Expand Down Expand Up @@ -1178,20 +1173,15 @@ def test_persist_expression_contextmanager(alltypes):
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@mark.never(
["pandas"],
raises=AssertionError,
reason="pandas doesn't cache anything and therefore does no ref counting",
)
def test_persist_expression_contextmanager_ref_count(con, alltypes):
non_cached_table = alltypes.mutate(
test_column="calculation", other_column="big calc 2"
)
op = non_cached_table.op()
with non_cached_table.cache() as cached_table:
tm.assert_frame_equal(non_cached_table.to_pandas(), cached_table.to_pandas())
assert con._refs[op] == 1
assert con._refs[op] == 0
assert con._query_cache.refs[op] == 1
assert con._query_cache.refs[op] == 0


@mark.notimpl(
Expand All @@ -1212,23 +1202,32 @@ def test_persist_expression_contextmanager_ref_count(con, alltypes):
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@mark.never(
["pandas"],
raises=AssertionError,
reason="pandas doesn't cache anything and therefore does no ref counting",
)
def test_persist_expression_multiple_refs(con, alltypes):
non_cached_table = alltypes.mutate(
test_column="calculation", other_column="big calc 2"
)
op = non_cached_table.op()
with non_cached_table.cache() as cached_table:
tm.assert_frame_equal(non_cached_table.to_pandas(), cached_table.to_pandas())

name1 = cached_table.op().name

with non_cached_table.cache() as nested_cached_table:
name2 = nested_cached_table.op().name
assert not nested_cached_table.to_pandas().empty
assert con._refs[op] == 2
assert con._refs[op] == 1
assert con._refs[op] == 0

# there are two refs to the uncached expression
assert con._query_cache.refs[op] == 2

# one ref to the uncached expression was removed by the context manager
assert con._query_cache.refs[op] == 1

# no refs left after the outer context manager exits
assert con._query_cache.refs[op] == 0

# assert that tables have been dropped
assert name1 not in con.list_tables()
assert name2 not in con.list_tables()


@mark.notimpl(
Expand Down Expand Up @@ -1276,14 +1275,13 @@ def test_persist_expression_repeated_cache(alltypes):
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@mark.never(["pandas"], reason="pandas does not need to release anything")
def test_persist_expression_release(con, alltypes):
non_cached_table = alltypes.mutate(
test_column="calculation", other_column="big calc 3"
)
cached_table = non_cached_table.cache()
cached_table.release()
assert con._refs[non_cached_table.op()] == 0
assert con._query_cache.refs[non_cached_table.op()] == 0

with pytest.raises(
com.IbisError,
Expand Down
79 changes: 78 additions & 1 deletion ibis/common/caching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from __future__ import annotations

import weakref
from typing import MutableMapping
from collections import Counter, defaultdict
from typing import Any, Callable, MutableMapping

import toolz
from bidict import bidict

from ibis.common.exceptions import IbisError


class WeakCache(MutableMapping):
Expand Down Expand Up @@ -46,3 +52,74 @@ def __delitem__(self, key):

def __repr__(self):
return f"{self.__class__.__name__}({self._data})"


class RefCountedCache:
    """A cache with reference-counted keys.

    Every successful lookup (`cache[key]` or `cache.get(key)`) takes one
    reference to `key`; every `del cache[value]` gives one back. When the last
    reference is released the entry is dropped and `finalize` is called with
    the cached value.

    We could implement `MutableMapping`, but the `__setitem__` implementation
    doesn't make sense and the `len` and `__iter__` methods aren't used.

    We can implement that interface if and when we need to.

    Parameters
    ----------
    populate
        Called as `populate(name, input)` by `store` to compute `input` under
        the generated `name`.
    lookup
        Called as `lookup(name)` by `store`; its return value is what gets
        cached.
    finalize
        Called as `finalize(value)` once the last reference to the cached
        `value` has been released.
    generate_name
        Zero-argument callable producing the name for a newly stored key.
    key
        Maps the object passed to `store` to its cache key. A falsy value
        means the input itself is used as the key.
    """

    def __init__(
        self,
        *,
        populate: Callable[[str, Any], None],
        lookup: Callable[[str], Any],
        finalize: Callable[[Any], None],
        generate_name: Callable[[], str],
        key: Callable[[Any], Any],
    ) -> None:
        # key <-> cached value; the inverse direction is what lets
        # `__delitem__` be driven by the cached value
        self.cache = bidict()
        # key -> number of outstanding references
        self.refs = Counter()
        self.populate = populate
        self.lookup = lookup
        self.finalize = finalize
        # key -> generated name, created lazily on first access
        self.names = defaultdict(generate_name)
        self.key = key or toolz.identity

    def get(self, key, default=None):
        """Return the value cached for `key`, taking a reference to it.

        If `key` is not cached, `default` is returned and no reference is
        taken.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def __getitem__(self, key):
        """Return the value cached for `key` and take a reference to it.

        Raises
        ------
        KeyError
            If `key` is not cached; the reference count is left untouched.
        """
        result = self.cache[key]
        self.refs[key] += 1
        return result

    def store(self, input) -> None:
        """Compute `input` and cache the result under `self.key(input)`.

        The new entry starts with a reference count of zero; callers take
        their reference by looking the key up afterwards. Storing a key that
        is already cached is a no-op, so references that are currently held
        on it are never discarded.
        """
        key = self.key(input)
        if key in self.cache:
            # already computed: recomputing would reuse the same name and
            # resetting the refcount would orphan the outstanding references
            return

        name = self.names[key]
        self.populate(name, input)
        self.cache[key] = self.lookup(name)
        # nothing outside of this instance has referenced this key yet, so the
        # refcount is zero
        #
        # in theory it's possible to call store -> delitem which would raise an
        # exception, but in practice this doesn't happen because the only call
        # to store is immediately followed by a call to getitem.
        self.refs[key] = 0

    def __delitem__(self, key) -> None:
        """Release one reference held on the cached value `key`.

        Note that `key` here is the *cached value* (what `__getitem__`
        returned), not the key it was stored under. When the last reference
        is released the entry is removed and `finalize(key)` is called.

        Raises
        ------
        IbisError
            If `key` is not a currently cached value.
        """
        # we need to remove the expression representing the computed physical
        # table as well as the expression that was used to create that table
        #
        # bidict automatically handles this for us; without it we'd have to
        # do the bookkeeping ourselves with two dicts
        if (inv_key := self.cache.inverse.get(key)) is None:
            raise IbisError(
                "Key has already been released. Did you call "
                "`.release()` twice on the same expression?"
            )

        self.refs[inv_key] -= 1
        assert self.refs[inv_key] >= 0, f"refcount is negative: {self.refs[inv_key]:d}"

        if not self.refs[inv_key]:
            del self.cache[inv_key], self.refs[inv_key]
            # drop the generated name too, otherwise every key ever cached
            # stays referenced by `names` for the lifetime of this instance
            self.names.pop(inv_key, None)
            self.finalize(key)
6 changes: 6 additions & 0 deletions ibis/tests/expr/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,12 @@ def create_view(self, *_, **__) -> ir.Table:
def drop_view(self, *_, **__) -> ir.Table:
raise NotImplementedError(self.name)

def _load_into_cache(self, *_):
raise NotImplementedError(self.name)

def _clean_up_cached_table(self, _):
raise NotImplementedError(self.name)


def table_from_schema(name, meta, schema, *, database: str | None = None):
# Convert Ibis schema to SQLA table
Expand Down

0 comments on commit c816f00

Please sign in to comment.