Skip to content

Commit

Permalink
feat(pyspark): support basic caching
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and gforsyth committed Mar 23, 2023
1 parent fb06262 commit ab0df7a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 77 deletions.
20 changes: 20 additions & 0 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def _from_url(self, url: str) -> Backend:
session = builder.getOrCreate()
return self.connect(session)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._cached_dataframes = {}

def do_connect(self, session: SparkSession) -> None:
"""Create a PySpark `Backend` for use with Ibis.
Expand Down Expand Up @@ -554,3 +558,19 @@ def compute_stats(

def has_operation(cls, operation: type[ops.Value]) -> bool:
return operation in PySparkExprTranslator._registry

def _load_into_cache(self, name, expr):
t = expr.compile().cache()
assert t.is_cached
t.createOrReplaceTempView(name)
# store the underlying spark dataframe so we can release memory when
# asked to, instead of when the session ends
self._cached_dataframes[name] = t

def _clean_up_cached_table(self, op):
name = op.name
self._session.catalog.dropTempView(name)
t = self._cached_dataframes.pop(name)
assert t.is_cached
t.unpersist()
assert not t.is_cached
84 changes: 7 additions & 77 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,17 +1075,7 @@ def test_create_table_timestamp(con):
con.drop_table(name, force=True)


@mark.notimpl(
[
"clickhouse",
"datafusion",
"bigquery",
"impala",
"pyspark",
"trino",
"druid",
]
)
@mark.notimpl(["clickhouse", "datafusion", "bigquery", "impala", "trino", "druid"])
@mark.notyet(
["sqlite"], reason="sqlite only support temporary tables in temporary databases"
)
Expand All @@ -1105,17 +1095,7 @@ def test_persist_expression_ref_count(con, alltypes):
assert con._query_cache.refs[op] == 1


@mark.notimpl(
[
"clickhouse",
"datafusion",
"bigquery",
"impala",
"pyspark",
"trino",
"druid",
]
)
@mark.notimpl(["clickhouse", "datafusion", "bigquery", "impala", "trino", "druid"])
@mark.notyet(
["sqlite"], reason="sqlite only support temporary tables in temporary databases"
)
Expand All @@ -1129,17 +1109,7 @@ def test_persist_expression(alltypes):
tm.assert_frame_equal(non_persisted_table.to_pandas(), persisted_table.to_pandas())


@mark.notimpl(
[
"clickhouse",
"datafusion",
"bigquery",
"impala",
"pyspark",
"trino",
"druid",
]
)
@mark.notimpl(["clickhouse", "datafusion", "bigquery", "impala", "trino", "druid"])
@mark.notyet(
["sqlite"], reason="sqlite only support temporary tables in temporary databases"
)
Expand All @@ -1155,17 +1125,7 @@ def test_persist_expression_contextmanager(alltypes):
tm.assert_frame_equal(non_cached_table.to_pandas(), cached_table.to_pandas())


@mark.notimpl(
[
"clickhouse",
"datafusion",
"bigquery",
"impala",
"pyspark",
"trino",
"druid",
]
)
@mark.notimpl(["clickhouse", "datafusion", "bigquery", "impala", "trino", "druid"])
@mark.notyet(
["sqlite"], reason="sqlite only support temporary tables in temporary databases"
)
Expand All @@ -1184,17 +1144,7 @@ def test_persist_expression_contextmanager_ref_count(con, alltypes):
assert con._query_cache.refs[op] == 0


@mark.notimpl(
[
"clickhouse",
"datafusion",
"bigquery",
"impala",
"pyspark",
"trino",
"druid",
]
)
@mark.notimpl(["clickhouse", "datafusion", "bigquery", "impala", "trino", "druid"])
@mark.notyet(
["sqlite"], reason="sqlite only support temporary tables in temporary databases"
)
Expand Down Expand Up @@ -1230,17 +1180,7 @@ def test_persist_expression_multiple_refs(con, alltypes):
assert name2 not in con.list_tables()


@mark.notimpl(
[
"clickhouse",
"datafusion",
"bigquery",
"impala",
"pyspark",
"trino",
"druid",
]
)
@mark.notimpl(["clickhouse", "datafusion", "bigquery", "impala", "trino", "druid"])
@mark.notyet(
["sqlite"], reason="sqlite only support temporary tables in temporary databases"
)
Expand All @@ -1257,17 +1197,7 @@ def test_persist_expression_repeated_cache(alltypes):
assert not nested_cached_table.to_pandas().empty


@mark.notimpl(
[
"clickhouse",
"datafusion",
"bigquery",
"impala",
"pyspark",
"trino",
"druid",
]
)
@mark.notimpl(["clickhouse", "datafusion", "bigquery", "impala", "trino", "druid"])
@mark.notyet(
["sqlite"], reason="sqlite only support temporary tables in temporary databases"
)
Expand Down

0 comments on commit ab0df7a

Please sign in to comment.