Skip to content

Commit

Permalink
feat(caching): tie lifetime of cached tables to python refs (#9477)
Browse files Browse the repository at this point in the history
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
  • Loading branch information
coroa and cpcloud authored Jul 2, 2024
1 parent 9d7d48f commit f51546e
Show file tree
Hide file tree
Showing 20 changed files with 101 additions and 170 deletions.
1 change: 0 additions & 1 deletion conda/environment-arm64-flink.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ dependencies:
# runtime dependencies
- python =3.10
- atpublic >=2.3
- bidict >=0.22.1
- black >=22.1.0,<25
- clickhouse-connect >=0.5.23
- dask >=2022.9.1
Expand Down
1 change: 0 additions & 1 deletion conda/environment-arm64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ dependencies:
# runtime dependencies
- python >=3.10
- atpublic >=2.3
- bidict >=0.22.1
- black >=22.1.0,<25
- clickhouse-connect >=0.5.23
- dask >=2022.9.1
Expand Down
1 change: 0 additions & 1 deletion conda/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ dependencies:
# runtime dependencies
- apache-flink
- atpublic >=2.3
- bidict >=0.22.1
- black >=22.1.0,<25
- clickhouse-connect >=0.5.23
- dask >=2022.9.1
Expand Down
1 change: 0 additions & 1 deletion docs/posts/run-on-snowflake/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ session.sproc.register(
"snowflake-snowpark-python",
"toolz",
"atpublic",
"bidict",
"pyarrow",
"pandas",
"numpy",
Expand Down
7 changes: 3 additions & 4 deletions ibis/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,8 +1228,7 @@ def _cached(self, expr: ir.Table):
"""
op = expr.op()
if (result := self._query_cache.get(op)) is None:
self._query_cache.store(expr)
result = self._query_cache[op]
result = self._query_cache.store(expr)
return ir.CachedTable(result)

def _release_cached(self, expr: ir.CachedTable) -> None:
Expand All @@ -1241,12 +1240,12 @@ def _release_cached(self, expr: ir.CachedTable) -> None:
Cached expression to release
"""
del self._query_cache[expr.op()]
self._query_cache.release(expr.op().name)

def _load_into_cache(self, name, expr):
raise NotImplementedError(self.name)

def _clean_up_cached_table(self, op):
def _clean_up_cached_table(self, name):
raise NotImplementedError(self.name)

def _transpile_sql(self, query: str, *, dialect: str | None = None) -> str:
Expand Down
5 changes: 3 additions & 2 deletions ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,10 +1142,11 @@ def drop_view(
def _load_into_cache(self, name, expr):
self.create_table(name, expr, schema=expr.schema(), temp=True)

def _clean_up_cached_table(self, op):
def _clean_up_cached_table(self, name):
self.drop_table(
op.name,
name,
database=(self._session_dataset.project, self._session_dataset.dataset_id),
force=True,
)

def _get_udf_source(self, udf_node: ops.ScalarUDF):
Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,6 @@ def create_table(

return self.table(name, database=(catalog, database))

def _load_into_cache(self, name, expr):
self.create_table(name, expr, schema=expr.schema(), temp=True)

def _clean_up_cached_table(self, op):
self.drop_table(op.name)

def table(
self, name: str, schema: str | None = None, database: str | None = None
) -> ir.Table:
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/oracle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,5 +629,5 @@ def _clean_up_tmp_table(self, name: str) -> None:
with contextlib.suppress(oracledb.DatabaseError):
bind.execute(f'DROP TABLE "{name}"')

def _clean_up_cached_table(self, op):
self._clean_up_tmp_table(op.name)
def _clean_up_cached_table(self, name):
self._clean_up_tmp_table(name)
4 changes: 2 additions & 2 deletions ibis/backends/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ def _get_operations(cls):
def has_operation(cls, operation: type[ops.Value]) -> bool:
return operation in cls._get_operations()

def _clean_up_cached_table(self, op):
del self.dictionary[op.name]
def _clean_up_cached_table(self, name):
del self.dictionary[name]

def to_pyarrow(
self,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,8 @@ def to_pyarrow_batches(
def _load_into_cache(self, name, expr):
self.create_table(name, self.compile(expr).cache())

def _clean_up_cached_table(self, op):
self._remove_table(op.name)
def _clean_up_cached_table(self, name):
self._remove_table(name)


@lazy_singledispatch
Expand Down
3 changes: 1 addition & 2 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,7 @@ def _load_into_cache(self, name, expr):
# asked to, instead of when the session ends
self._cached_dataframes[name] = t

def _clean_up_cached_table(self, op):
name = op.name
def _clean_up_cached_table(self, name):
self._session.catalog.dropTempView(name)
t = self._cached_dataframes.pop(name)
assert t.is_cached
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/snowflake/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ def ibis_sproc(session):
"snowflake-snowpark-python",
"toolz",
"atpublic",
"bidict",
"pyarrow",
"pandas",
"numpy",
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ def drop_view(
def _load_into_cache(self, name, expr):
self.create_table(name, expr, schema=expr.schema(), temp=True)

def _clean_up_cached_table(self, op):
self.drop_table(op.name)
def _clean_up_cached_table(self, name):
self.drop_table(name, force=True)

def execute(
self,
Expand Down
101 changes: 38 additions & 63 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,31 +1334,6 @@ def test_create_table_timestamp(con, temp_table):
assert result == schema


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
@mark.never(
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
@pytest.mark.never(
["risingwave"],
raises=com.UnsupportedOperationError,
reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
)
def test_persist_expression_ref_count(backend, con, alltypes):
non_persisted_table = alltypes.mutate(test_column=ibis.literal("calculation"))
persisted_table = non_persisted_table.cache()

op = non_persisted_table.op()

# ref count is unaffected without a context manager
assert con._query_cache.refs[op] == 1
backend.assert_frame_equal(
non_persisted_table.to_pandas(), persisted_table.to_pandas()
)
assert con._query_cache.refs[op] == 1


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
@mark.never(
["mssql"],
Expand Down Expand Up @@ -1391,14 +1366,15 @@ def test_persist_expression(backend, alltypes):
raises=com.UnsupportedOperationError,
reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
)
def test_persist_expression_contextmanager(backend, alltypes):
def test_persist_expression_contextmanager(backend, con, alltypes):
non_cached_table = alltypes.mutate(
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc")
)
with non_cached_table.cache() as cached_table:
backend.assert_frame_equal(
non_cached_table.to_pandas(), cached_table.to_pandas()
)
assert non_cached_table.op() not in con._query_cache.cache


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
Expand All @@ -1417,12 +1393,12 @@ def test_persist_expression_contextmanager_ref_count(backend, con, alltypes):
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 2")
)
op = non_cached_table.op()
with non_cached_table.cache() as cached_table:
backend.assert_frame_equal(
non_cached_table.to_pandas(), cached_table.to_pandas()
)
assert con._query_cache.refs[op] == 1
assert con._query_cache.refs[op] == 0
cached_table = non_cached_table.cache()
backend.assert_frame_equal(non_cached_table.to_pandas(), cached_table.to_pandas())

assert op in con._query_cache.cache
del cached_table
assert op not in con._query_cache.cache


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
Expand All @@ -1441,29 +1417,28 @@ def test_persist_expression_multiple_refs(backend, con, alltypes):
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 2")
)
op = non_cached_table.op()
with non_cached_table.cache() as cached_table:
backend.assert_frame_equal(
non_cached_table.to_pandas(), cached_table.to_pandas()
)
cached_table = non_cached_table.cache()

name1 = cached_table.op().name
backend.assert_frame_equal(non_cached_table.to_pandas(), cached_table.to_pandas())

with non_cached_table.cache() as nested_cached_table:
name2 = nested_cached_table.op().name
assert not nested_cached_table.to_pandas().empty
name = cached_table.op().name
nested_cached_table = non_cached_table.cache()

# there are two refs to the uncached expression
assert con._query_cache.refs[op] == 2
# cached tables are identical and reusing the same op
assert cached_table.op() is nested_cached_table.op()
# table is cached
assert op in con._query_cache.cache

# one ref to the uncached expression was removed by the context manager
assert con._query_cache.refs[op] == 1
# deleting the first reference leaves the table in the cache
del nested_cached_table
assert op in con._query_cache.cache

# no refs left after the outer context manager exits
assert con._query_cache.refs[op] == 0
# deleting the last reference releases the table from the cache
del cached_table
assert op not in con._query_cache.cache

# assert that tables have been dropped
assert name1 not in con.list_tables()
assert name2 not in con.list_tables()
# assert that table has been dropped
assert name not in con.list_tables()


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
Expand All @@ -1477,20 +1452,22 @@ def test_persist_expression_multiple_refs(backend, con, alltypes):
raises=com.UnsupportedOperationError,
reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
)
def test_persist_expression_repeated_cache(alltypes):
def test_persist_expression_repeated_cache(alltypes, con):
non_cached_table = alltypes.mutate(
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 2")
)
with non_cached_table.cache() as cached_table:
with cached_table.cache() as nested_cached_table:
assert not nested_cached_table.to_pandas().empty
cached_table = non_cached_table.cache()
nested_cached_table = cached_table.cache()
name = cached_table.op().name

assert not nested_cached_table.to_pandas().empty

del nested_cached_table, cached_table

assert name not in con.list_tables()


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
@mark.never(
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
@pytest.mark.never(
["risingwave"],
Expand All @@ -1503,13 +1480,11 @@ def test_persist_expression_release(con, alltypes):
)
cached_table = non_cached_table.cache()
cached_table.release()
assert con._query_cache.refs[non_cached_table.op()] == 0

with pytest.raises(
com.IbisError,
match=r".+Did you call `\.release\(\)` twice on the same expression\?",
):
cached_table.release()
assert non_cached_table.op() not in con._query_cache.cache

# a second release does not hurt
cached_table.release()

with pytest.raises(Exception, match=cached_table.op().name):
cached_table.execute()
Expand Down
Loading

0 comments on commit f51546e

Please sign in to comment.