From f51546ec2473ca8ad1ef6e962b0652648b037ff4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20H=C3=B6rsch?= Date: Tue, 2 Jul 2024 12:19:54 +0200 Subject: [PATCH] feat(caching): tie lifetime of cached tables to python refs (#9477) Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com> --- conda/environment-arm64-flink.yml | 1 - conda/environment-arm64.yml | 1 - conda/environment.yml | 1 - docs/posts/run-on-snowflake/index.qmd | 1 - ibis/backends/__init__.py | 7 +- ibis/backends/bigquery/__init__.py | 5 +- ibis/backends/duckdb/__init__.py | 6 -- ibis/backends/oracle/__init__.py | 4 +- ibis/backends/pandas/__init__.py | 4 +- ibis/backends/polars/__init__.py | 4 +- ibis/backends/pyspark/__init__.py | 3 +- ibis/backends/snowflake/tests/test_udf.py | 1 - ibis/backends/sql/__init__.py | 4 +- ibis/backends/tests/test_client.py | 101 ++++++++-------------- ibis/common/caching.py | 90 +++++++++---------- ibis/examples/pixi.lock | 17 ---- ibis/expr/types/relations.py | 6 +- poetry.lock | 13 +-- pyproject.toml | 1 - requirements-dev.txt | 1 - 20 files changed, 101 insertions(+), 170 deletions(-) diff --git a/conda/environment-arm64-flink.yml b/conda/environment-arm64-flink.yml index 88c5b46f2416..477e215e0bf3 100644 --- a/conda/environment-arm64-flink.yml +++ b/conda/environment-arm64-flink.yml @@ -5,7 +5,6 @@ dependencies: # runtime dependencies - python =3.10 - atpublic >=2.3 - - bidict >=0.22.1 - black >=22.1.0,<25 - clickhouse-connect >=0.5.23 - dask >=2022.9.1 diff --git a/conda/environment-arm64.yml b/conda/environment-arm64.yml index 3bcd8edf7333..9b733687aaf3 100644 --- a/conda/environment-arm64.yml +++ b/conda/environment-arm64.yml @@ -5,7 +5,6 @@ dependencies: # runtime dependencies - python >=3.10 - atpublic >=2.3 - - bidict >=0.22.1 - black >=22.1.0,<25 - clickhouse-connect >=0.5.23 - dask >=2022.9.1 diff --git a/conda/environment.yml b/conda/environment.yml index b6beba79a937..ee14f8e7d5b4 100644 --- a/conda/environment.yml +++ 
b/conda/environment.yml @@ -5,7 +5,6 @@ dependencies: # runtime dependencies - apache-flink - atpublic >=2.3 - - bidict >=0.22.1 - black >=22.1.0,<25 - clickhouse-connect >=0.5.23 - dask >=2022.9.1 diff --git a/docs/posts/run-on-snowflake/index.qmd b/docs/posts/run-on-snowflake/index.qmd index 4b240c0d7bde..1afebe5a0007 100644 --- a/docs/posts/run-on-snowflake/index.qmd +++ b/docs/posts/run-on-snowflake/index.qmd @@ -111,7 +111,6 @@ session.sproc.register( "snowflake-snowpark-python", "toolz", "atpublic", - "bidict", "pyarrow", "pandas", "numpy", diff --git a/ibis/backends/__init__.py b/ibis/backends/__init__.py index b91b97ae6367..86987a2a207c 100644 --- a/ibis/backends/__init__.py +++ b/ibis/backends/__init__.py @@ -1228,8 +1228,7 @@ def _cached(self, expr: ir.Table): """ op = expr.op() if (result := self._query_cache.get(op)) is None: - self._query_cache.store(expr) - result = self._query_cache[op] + result = self._query_cache.store(expr) return ir.CachedTable(result) def _release_cached(self, expr: ir.CachedTable) -> None: @@ -1241,12 +1240,12 @@ def _release_cached(self, expr: ir.CachedTable) -> None: Cached expression to release """ - del self._query_cache[expr.op()] + self._query_cache.release(expr.op().name) def _load_into_cache(self, name, expr): raise NotImplementedError(self.name) - def _clean_up_cached_table(self, op): + def _clean_up_cached_table(self, name): raise NotImplementedError(self.name) def _transpile_sql(self, query: str, *, dialect: str | None = None) -> str: diff --git a/ibis/backends/bigquery/__init__.py b/ibis/backends/bigquery/__init__.py index 1861d8da25f3..fb2a0c50188a 100644 --- a/ibis/backends/bigquery/__init__.py +++ b/ibis/backends/bigquery/__init__.py @@ -1142,10 +1142,11 @@ def drop_view( def _load_into_cache(self, name, expr): self.create_table(name, expr, schema=expr.schema(), temp=True) - def _clean_up_cached_table(self, op): + def _clean_up_cached_table(self, name): self.drop_table( - op.name, + name, 
database=(self._session_dataset.project, self._session_dataset.dataset_id), + force=True, ) def _get_udf_source(self, udf_node: ops.ScalarUDF): diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 94f005ececc6..3b8bc0667fb0 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -278,12 +278,6 @@ def create_table( return self.table(name, database=(catalog, database)) - def _load_into_cache(self, name, expr): - self.create_table(name, expr, schema=expr.schema(), temp=True) - - def _clean_up_cached_table(self, op): - self.drop_table(op.name) - def table( self, name: str, schema: str | None = None, database: str | None = None ) -> ir.Table: diff --git a/ibis/backends/oracle/__init__.py b/ibis/backends/oracle/__init__.py index ce4f34765e1c..d2cf6ddda922 100644 --- a/ibis/backends/oracle/__init__.py +++ b/ibis/backends/oracle/__init__.py @@ -629,5 +629,5 @@ def _clean_up_tmp_table(self, name: str) -> None: with contextlib.suppress(oracledb.DatabaseError): bind.execute(f'DROP TABLE "{name}"') - def _clean_up_cached_table(self, op): - self._clean_up_tmp_table(op.name) + def _clean_up_cached_table(self, name): + self._clean_up_tmp_table(name) diff --git a/ibis/backends/pandas/__init__.py b/ibis/backends/pandas/__init__.py index 646bb0c49764..0e63d8d19a88 100644 --- a/ibis/backends/pandas/__init__.py +++ b/ibis/backends/pandas/__init__.py @@ -275,8 +275,8 @@ def _get_operations(cls): def has_operation(cls, operation: type[ops.Value]) -> bool: return operation in cls._get_operations() - def _clean_up_cached_table(self, op): - del self.dictionary[op.name] + def _clean_up_cached_table(self, name): + del self.dictionary[name] def to_pyarrow( self, diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index 72e80dc7ae7a..9f2290ba8414 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -560,8 +560,8 @@ def to_pyarrow_batches( def _load_into_cache(self, 
name, expr): self.create_table(name, self.compile(expr).cache()) - def _clean_up_cached_table(self, op): - self._remove_table(op.name) + def _clean_up_cached_table(self, name): + self._remove_table(name) @lazy_singledispatch diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index 85507d48a24b..ee4f74f6a29f 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -625,8 +625,7 @@ def _load_into_cache(self, name, expr): # asked to, instead of when the session ends self._cached_dataframes[name] = t - def _clean_up_cached_table(self, op): - name = op.name + def _clean_up_cached_table(self, name): self._session.catalog.dropTempView(name) t = self._cached_dataframes.pop(name) assert t.is_cached diff --git a/ibis/backends/snowflake/tests/test_udf.py b/ibis/backends/snowflake/tests/test_udf.py index d0a26b5953e0..4a59013cebec 100644 --- a/ibis/backends/snowflake/tests/test_udf.py +++ b/ibis/backends/snowflake/tests/test_udf.py @@ -276,7 +276,6 @@ def ibis_sproc(session): "snowflake-snowpark-python", "toolz", "atpublic", - "bidict", "pyarrow", "pandas", "numpy", diff --git a/ibis/backends/sql/__init__.py b/ibis/backends/sql/__init__.py index fa99048e1f56..e2d4389c8661 100644 --- a/ibis/backends/sql/__init__.py +++ b/ibis/backends/sql/__init__.py @@ -275,8 +275,8 @@ def drop_view( def _load_into_cache(self, name, expr): self.create_table(name, expr, schema=expr.schema(), temp=True) - def _clean_up_cached_table(self, op): - self.drop_table(op.name) + def _clean_up_cached_table(self, name): + self.drop_table(name, force=True) def execute( self, diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index f3dc7ba00d59..3348c8634c74 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -1334,31 +1334,6 @@ def test_create_table_timestamp(con, temp_table): assert result == schema -@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"]) 
-@mark.never( - ["mssql"], - reason="mssql supports support temporary tables through naming conventions", -) -@mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") -@pytest.mark.never( - ["risingwave"], - raises=com.UnsupportedOperationError, - reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", -) -def test_persist_expression_ref_count(backend, con, alltypes): - non_persisted_table = alltypes.mutate(test_column=ibis.literal("calculation")) - persisted_table = non_persisted_table.cache() - - op = non_persisted_table.op() - - # ref count is unaffected without a context manager - assert con._query_cache.refs[op] == 1 - backend.assert_frame_equal( - non_persisted_table.to_pandas(), persisted_table.to_pandas() - ) - assert con._query_cache.refs[op] == 1 - - @mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"]) @mark.never( ["mssql"], @@ -1391,7 +1366,7 @@ def test_persist_expression(backend, alltypes): raises=com.UnsupportedOperationError, reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", ) -def test_persist_expression_contextmanager(backend, alltypes): +def test_persist_expression_contextmanager(backend, con, alltypes): non_cached_table = alltypes.mutate( test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc") ) @@ -1399,6 +1374,7 @@ def test_persist_expression_contextmanager(backend, alltypes): backend.assert_frame_equal( non_cached_table.to_pandas(), cached_table.to_pandas() ) + assert non_cached_table.op() not in con._query_cache.cache @mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"]) @@ -1417,12 +1393,12 @@ def test_persist_expression_contextmanager_ref_count(backend, con, alltypes): test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 2") ) op = non_cached_table.op() - with non_cached_table.cache() as cached_table: - backend.assert_frame_equal( - non_cached_table.to_pandas(), cached_table.to_pandas() - ) - assert 
con._query_cache.refs[op] == 1 - assert con._query_cache.refs[op] == 0 + cached_table = non_cached_table.cache() + backend.assert_frame_equal(non_cached_table.to_pandas(), cached_table.to_pandas()) + + assert op in con._query_cache.cache + del cached_table + assert op not in con._query_cache.cache @mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"]) @@ -1441,29 +1417,28 @@ def test_persist_expression_multiple_refs(backend, con, alltypes): test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 2") ) op = non_cached_table.op() - with non_cached_table.cache() as cached_table: - backend.assert_frame_equal( - non_cached_table.to_pandas(), cached_table.to_pandas() - ) + cached_table = non_cached_table.cache() - name1 = cached_table.op().name + backend.assert_frame_equal(non_cached_table.to_pandas(), cached_table.to_pandas()) - with non_cached_table.cache() as nested_cached_table: - name2 = nested_cached_table.op().name - assert not nested_cached_table.to_pandas().empty + name = cached_table.op().name + nested_cached_table = non_cached_table.cache() - # there are two refs to the uncached expression - assert con._query_cache.refs[op] == 2 + # cached tables are identical and reusing the same op + assert cached_table.op() is nested_cached_table.op() + # table is cached + assert op in con._query_cache.cache - # one ref to the uncached expression was removed by the context manager - assert con._query_cache.refs[op] == 1 + # deleting the first reference, leaves table in cache + del nested_cached_table + assert op in con._query_cache.cache - # no refs left after the outer context manager exits - assert con._query_cache.refs[op] == 0 + # deleting the last reference, releases table from cache + del cached_table + assert op not in con._query_cache.cache - # assert that tables have been dropped - assert name1 not in con.list_tables() - assert name2 not in con.list_tables() + # assert that table has been dropped + assert name not in 
con.list_tables() @mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"]) @@ -1477,20 +1452,22 @@ def test_persist_expression_multiple_refs(backend, con, alltypes): raises=com.UnsupportedOperationError, reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", ) -def test_persist_expression_repeated_cache(alltypes): +def test_persist_expression_repeated_cache(alltypes, con): non_cached_table = alltypes.mutate( test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 2") ) - with non_cached_table.cache() as cached_table: - with cached_table.cache() as nested_cached_table: - assert not nested_cached_table.to_pandas().empty + cached_table = non_cached_table.cache() + nested_cached_table = cached_table.cache() + name = cached_table.op().name + + assert not nested_cached_table.to_pandas().empty + + del nested_cached_table, cached_table + + assert name not in con.list_tables() @mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"]) -@mark.never( - ["mssql"], - reason="mssql supports support temporary tables through naming conventions", -) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") @pytest.mark.never( ["risingwave"], @@ -1503,13 +1480,11 @@ def test_persist_expression_release(con, alltypes): ) cached_table = non_cached_table.cache() cached_table.release() - assert con._query_cache.refs[non_cached_table.op()] == 0 - with pytest.raises( - com.IbisError, - match=r".+Did you call `\.release\(\)` twice on the same expression\?", - ): - cached_table.release() + assert non_cached_table.op() not in con._query_cache.cache + + # a second release does not hurt + cached_table.release() with pytest.raises(Exception, match=cached_table.op().name): cached_table.execute() diff --git a/ibis/common/caching.py b/ibis/common/caching.py index d2c26fc81b03..0ba453a04ff9 100644 --- a/ibis/common/caching.py +++ b/ibis/common/caching.py @@ -1,12 +1,10 @@ from __future__ import annotations import functools -from 
collections import Counter, defaultdict +import sys +from collections import namedtuple from typing import TYPE_CHECKING, Any - -from bidict import bidict - -from ibis.common.exceptions import IbisError +from weakref import finalize, ref if TYPE_CHECKING: from collections.abc import Callable @@ -29,8 +27,11 @@ def wrapper(*args, **kwargs): return wrapper +CacheEntry = namedtuple("CacheEntry", ["name", "ref", "finalizer"]) + + class RefCountedCache: - """A cache with reference-counted keys. + """A cache with implicitly reference-counted values. We could implement `MutableMapping`, but the `__setitem__` implementation doesn't make sense and the `len` and `__iter__` methods aren't used. @@ -47,56 +48,51 @@ def __init__( generate_name: Callable[[], str], key: Callable[[Any], Any], ) -> None: - self.cache = bidict() - # Somehow mypy needs a type hint here - self.refs: Counter = Counter() self.populate = populate self.lookup = lookup self.finalize = finalize - # Somehow mypy needs a type hint here - self.names: defaultdict = defaultdict(generate_name) + self.generate_name = generate_name self.key = key or (lambda x: x) + self.cache: dict[Any, CacheEntry] = dict() + def get(self, key, default=None): - try: - return self[key] - except KeyError: - return default + if (entry := self.cache.get(key)) is not None: + op = entry.ref() + return op if op is not None else default + return default def __getitem__(self, key): - result = self.cache[key] - self.refs[key] += 1 - return result + op = self.cache[key].ref() + if op is None: + raise KeyError(key) + return op - def store(self, input) -> None: + def store(self, input): """Compute and store a reference to `key`.""" key = self.key(input) - name = self.names[key] + name = self.generate_name() self.populate(name, input) - self.cache[key] = self.lookup(name) - # nothing outside of this instance has referenced this key yet, so the - # refcount is zero - # - # in theory it's possible to call store -> delitem which would raise an - # 
exception, but in practice this doesn't happen because the only call - # to store is immediately followed by a call to getitem. - self.refs[key] = 0 - - def __delitem__(self, key) -> None: - # we need to remove the expression representing the computed physical - # table as well as the expression that was used to create that table - # - # bidict automatically handles this for us; without it we'd have to do - # to the bookkeeping ourselves with two dicts - if (inv_key := self.cache.inverse.get(key)) is None: - raise IbisError( - "Key has already been released. Did you call " - "`.release()` twice on the same expression?" - ) - - self.refs[inv_key] -= 1 - assert self.refs[inv_key] >= 0, f"refcount is negative: {self.refs[inv_key]:d}" - - if not self.refs[inv_key]: - del self.cache[inv_key], self.refs[inv_key] - self.finalize(key) + cached = self.lookup(name) + finalizer = finalize(cached, self._release, key) + + self.cache[key] = CacheEntry(name, ref(cached), finalizer) + + return cached + + def release(self, name: str) -> None: + # Could be sped up with an inverse dictionary + for key, entry in self.cache.items(): + if entry.name == name: + self._release(key) + return + + def _release(self, key) -> None: + entry = self.cache.pop(key) + try: + self.finalize(entry.name) + except Exception: + # suppress exceptions during interpreter shutdown + if not sys.is_finalizing(): + raise + entry.finalizer.detach() diff --git a/ibis/examples/pixi.lock b/ibis/examples/pixi.lock index 6b53f2570137..f0e34e5ee86c 100644 --- a/ibis/examples/pixi.lock +++ b/ibis/examples/pixi.lock @@ -27,7 +27,6 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/aws-checksums-0.1.18-h4466546_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/aws-crt-cpp-0.26.3-h137ae52_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/aws-sdk-cpp-1.11.267-he0cb598_3.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/bidict-0.23.1-pyhd8ed1ab_0.conda - 
conda: https://conda.anaconda.org/conda-forge/linux-64/binutils_impl_linux-64-2.40-hf600244_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/blinker-1.7.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/brotli-python-1.1.0-py311hb755f60_1.conda @@ -744,21 +743,6 @@ packages: license_family: Apache size: 3636564 timestamp: 1710322529863 -- kind: conda - name: bidict - version: 0.23.1 - build: pyhd8ed1ab_0 - subdir: noarch - noarch: python - url: https://conda.anaconda.org/conda-forge/noarch/bidict-0.23.1-pyhd8ed1ab_0.conda - sha256: cc7af340b9c99fb170032e103546329cd7c280d28ceca74468e19dd8b0539531 - md5: c9916c3975b19f470218f415701d6362 - depends: - - python >=3.8 - license: MPL-2.0 - license_family: MOZILLA - size: 31346 - timestamp: 1708298386480 - kind: conda name: binutils_impl_linux-64 version: '2.40' @@ -1752,7 +1736,6 @@ packages: md5: a7b3fc66c6248d2e6617337d025a6877 depends: - atpublic >=2.3 - - bidict >=0.22.1 - filelock >=3.7.0,<4 - multipledispatch >=0.6,<2 - numpy >=1.15,<2 diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 78cc2e4a6663..f577d4177d80 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -3552,9 +3552,11 @@ def cache(self) -> Table: """Cache the provided expression. All subsequent operations on the returned expression will be performed - on the cached data. Use the + on the cached data. The lifetime of the cached table is tied to its + python references (i.e. it is released once the last reference to it is + garbage collected). Alternatively, use the [`with`](https://docs.python.org/3/reference/compound_stmts.html#with) - statement to limit the lifetime of a cached table. + statement or call the `.release()` method for more control. This method is idempotent: calling it multiple times in succession will return the same value as the first call. 
diff --git a/poetry.lock b/poetry.lock index cfb9bb327891..7007315b056e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -423,17 +423,6 @@ charset-normalizer = ["charset-normalizer"] html5lib = ["html5lib"] lxml = ["lxml"] -[[package]] -name = "bidict" -version = "0.23.1" -description = "The bidirectional mapping library for Python." -optional = false -python-versions = ">=3.8" -files = [ - {file = "bidict-0.23.1-py3-none-any.whl", hash = "sha256:5dae8d4d79b552a71cbabc7deb25dfe8ce710b17ff41711e13010ead2abfc3e5"}, - {file = "bidict-0.23.1.tar.gz", hash = "sha256:03069d763bc387bbd20e7d49914e75fc4132a41937fa3405417e1a5a2d006d71"}, -] - [[package]] name = "bitarray" version = "2.9.2" @@ -7682,4 +7671,4 @@ visualization = ["graphviz"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "e538ac00f32af8c9c3040f30542c2eada46299907a787f36faa79898439d5bd5" +content-hash = "1fac34606e813b0339a74f308dfc2b1c15c41186e7f782bca53760e2911fda2b" diff --git a/pyproject.toml b/pyproject.toml index 2034339aae6d..3bd34d6af06b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,6 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.10" atpublic = ">=2.3,<5" -bidict = ">=0.22.1,<1" numpy = ">=1.23.2,<3" pandas = ">=1.5.3,<3" parsy = ">=2,<3" diff --git a/requirements-dev.txt b/requirements-dev.txt index 9e02a13c76e8..a85d4407aeea 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -18,7 +18,6 @@ attrs==23.2.0 ; python_version >= "3.10" and python_version < "4.0" babel==2.15.0 ; python_version >= "3.10" and python_version < "3.13" beartype==0.18.5 ; python_version >= "3.10" and python_version < "3.13" beautifulsoup4==4.12.3 ; python_version >= "3.10" and python_version < "3.13" -bidict==0.23.1 ; python_version >= "3.10" and python_version < "4.0" bitarray==2.9.2 ; python_version >= "3.10" and python_version < "4.0" black==24.4.2 ; python_version >= "3.10" and python_version < "4.0" bleach==6.1.0 ; python_version >= "3.10" and python_version < 
"3.13"