feat(pyspark): add catalog support to pyspark (#9042)
Resolves #9038 

Adds support for specifying the `catalog` in various `pyspark` calls.  

BREAKING CHANGE: Arguments to `create_database`, `drop_database`, and `get_schema` are now keyword-only except for the `name` argument. Calls to these functions that relied on positional argument ordering need to be updated (see the migration sketch below).
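
For migration, a minimal before/after sketch of the calling convention (assumes a connected PySpark backend `con`; the database name, path, and catalog values are illustrative):

```python
# Before: positional arguments were accepted after `name`.
# con.create_database("analytics", "/tmp/analytics_db", True)

# After: everything except `name` must be passed by keyword.
con.create_database("analytics", path="/tmp/analytics_db", force=True)

# The new keyword-only `catalog` argument targets a specific catalog.
con.create_database("analytics", catalog="spark_catalog")
con.drop_database("analytics", catalog="spark_catalog", force=True)
```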
gforsyth authored Apr 25, 2024
1 parent 02c6607 commit 2c1a58e
Showing 4 changed files with 104 additions and 27 deletions.
102 changes: 78 additions & 24 deletions ibis/backends/pyspark/__init__.py
@@ -8,6 +8,7 @@
import pyspark
import sqlglot as sg
import sqlglot.expressions as sge
from packaging.version import parse as vparse
from pyspark import SparkConf
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import PandasUDFType, pandas_udf
@@ -19,7 +20,7 @@
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis import util
from ibis.backends import CanCreateDatabase
from ibis.backends import CanCreateDatabase, CanListCatalog
from ibis.backends.pyspark.compiler import PySparkCompiler
from ibis.backends.pyspark.converter import PySparkPandasData
from ibis.backends.pyspark.datatypes import PySparkSchema, PySparkType
@@ -33,6 +34,8 @@
import pandas as pd
import pyarrow as pa

PYSPARK_LT_34 = vparse(pyspark.__version__) < vparse("3.4")


def normalize_filenames(source_list):
# Promote to list
@@ -127,7 +130,7 @@ def __exit__(self, exc_type, exc_value, traceback):
"""No-op for compatibility."""


class Backend(SQLBackend, CanCreateDatabase):
class Backend(SQLBackend, CanListCatalog, CanCreateDatabase):
name = "pyspark"
compiler = PySparkCompiler()

@@ -221,6 +224,11 @@ def current_database(self) -> str:
[(db,)] = self._session.sql("SELECT CURRENT_DATABASE()").collect()
return db

@property
def current_catalog(self) -> str:
[(catalog,)] = self._session.sql("SELECT CURRENT_CATALOG()").collect()
return catalog

@contextlib.contextmanager
def _active_database(self, name: str | None):
if name is None:
@@ -233,10 +241,29 @@ def _active_database(self, name: str | None):
finally:
self._session.catalog.setCurrentDatabase(current)

def list_databases(self, like: str | None = None) -> list[str]:
databases = [
db.namespace for db in self._session.sql("SHOW DATABASES").collect()
]
@contextlib.contextmanager
def _active_catalog(self, name: str | None):
if name is None or PYSPARK_LT_34:
yield
return
current = self.current_catalog
try:
self._session.catalog.setCurrentCatalog(name)
yield
finally:
self._session.catalog.setCurrentCatalog(current)

def list_catalogs(self, like: str | None = None) -> list[str]:
catalogs = [res.catalog for res in self._session.sql("SHOW CATALOGS").collect()]
return self._filter_with_like(catalogs, like)

def list_databases(
self, like: str | None = None, catalog: str | None = None
) -> list[str]:
with self._active_catalog(catalog):
databases = [
db.namespace for db in self._session.sql("SHOW DATABASES").collect()
]
return self._filter_with_like(databases, like)

def list_tables(
@@ -250,14 +277,21 @@ def list_tables(
A pattern to use for listing tables.
database
Database to list tables from. Default behavior is to show tables in
the current database.
the current catalog and database.
To specify a table in a separate catalog, you can pass in the
catalog and database as a string `"catalog.database"`, or as a tuple of
strings `("catalog", "database")`.
"""
tables = [
row.tableName
for row in self._session.sql(
f"SHOW TABLES IN {database or self.current_database}"
).collect()
]
table_loc = self._to_sqlglot_table(database)
catalog, db = self._to_catalog_db_tuple(table_loc)
with self._active_catalog(catalog):
tables = [
row.tableName
for row in self._session.sql(
f"SHOW TABLES IN {db or self.current_database}"
).collect()
]
return self._filter_with_like(tables, like)

def _wrap_udf_to_return_pandas(self, func, output_dtype):
@@ -319,6 +353,8 @@ def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> _PySparkCursor:
def create_database(
self,
name: str,
*,
catalog: str | None = None,
path: str | Path | None = None,
force: bool = False,
) -> Any:
@@ -328,6 +364,8 @@ def create_database(
----------
name
Database name
catalog
Catalog to create database in (defaults to ``current_catalog``)
path
Path where to store the database data; otherwise uses Spark default
force
@@ -347,16 +385,21 @@ def create_database(
this=sg.to_identifier(name),
properties=properties,
)
with self._safe_raw_sql(sql):
pass
with self._active_catalog(catalog):
with self._safe_raw_sql(sql):
pass

def drop_database(self, name: str, force: bool = False) -> Any:
def drop_database(
self, name: str, *, catalog: str | None = None, force: bool = False
) -> Any:
"""Drop a Spark database.
Parameters
----------
name
Database name
catalog
Catalog containing database to drop (defaults to ``current_catalog``)
force
If False, Spark throws exception if database is not empty or
database does not exist
@@ -365,8 +408,9 @@ def drop_database(self, name: str, force: bool = False) -> Any:
sql = sge.Drop(
kind="DATABASE", exist=force, this=sg.to_identifier(name), cascade=force
)
with self._safe_raw_sql(sql):
pass
with self._active_catalog(catalog):
with self._safe_raw_sql(sql):
pass

def get_schema(
self,
@@ -382,7 +426,7 @@ def get_schema(
table_name
Table name. May be fully qualified
catalog
Unsupported in PySpark backend.
Catalog to use
database
Database to use to get the active database.
@@ -392,7 +436,10 @@ def get_schema(
An ibis schema
"""
with self._active_database(database):

table_loc = self._to_sqlglot_table((catalog, database))
catalog, db = self._to_catalog_db_tuple(table_loc)
with self._active_catalog(catalog), self._active_database(db):
df = self._session.table(table_name)
struct = PySparkType.to_ibis(df.schema)

@@ -421,8 +468,12 @@ def create_table(
Mutually exclusive with `obj`, creates an empty table with a schema
database
Database name
To specify a table in a separate catalog, you can pass in the
catalog and database as a string `"catalog.database"`, or as a tuple of
strings `("catalog", "database")`.
temp
Whether the new table is temporary
Whether the new table is temporary (unsupported)
overwrite
If `True`, overwrite existing data
format
@@ -443,22 +494,25 @@ def create_table(
"PySpark backend does not yet support temporary tables"
)

table_loc = self._to_sqlglot_table(database)
catalog, db = self._to_catalog_db_tuple(table_loc)

if obj is not None:
table = obj if isinstance(obj, ir.Expr) else ibis.memtable(obj)
query = self.compile(table)
mode = "overwrite" if overwrite else "error"
with self._active_database(database):
with self._active_catalog(catalog), self._active_database(db):
self._run_pre_execute_hooks(table)
df = self._session.sql(query)
df.write.saveAsTable(name, format=format, mode=mode)
elif schema is not None:
schema = PySparkSchema.from_ibis(schema)
with self._active_database(database):
with self._active_catalog(catalog), self._active_database(db):
self._session.catalog.createTable(name, schema=schema, format=format)
else:
raise com.IbisError("The schema or obj parameter is required")

return self.table(name, database=database)
return self.table(name, database=db)

def create_view(
self,
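As a quick orientation for reviewers, a hedged usage sketch of the catalog-aware methods added in this file (assumes an `ibis.pyspark` connection `con` on Spark >= 3.4 with the default `spark_catalog`; the table and database names are illustrative):

```python
con.current_catalog                           # e.g. "spark_catalog"
con.list_catalogs()                           # ["spark_catalog", ...]
con.list_databases(catalog="spark_catalog")   # databases within a given catalog

# Database arguments accept "catalog.database" strings or ("catalog", "database") tuples.
con.list_tables(database=("spark_catalog", "default"))
con.get_schema("some_table", catalog="spark_catalog", database="default")
```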
22 changes: 22 additions & 0 deletions ibis/backends/pyspark/tests/test_client.py
@@ -0,0 +1,22 @@
from __future__ import annotations

import ibis


def test_catalog_db_args(con, monkeypatch):
monkeypatch.setattr(ibis.options, "default_backend", con)
t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})

# create a table in specified catalog and db
con.create_table(
"t2", database=(con.current_catalog, "default"), obj=t, overwrite=True
)

assert "t2" not in con.list_tables()
assert "t2" in con.list_tables(database="default")
assert "t2" in con.list_tables(database="spark_catalog.default")
assert "t2" in con.list_tables(database=("spark_catalog", "default"))

con.drop_table("t2", database="spark_catalog.default")

assert "t2" not in con.list_tables(database="default")
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_api.py
@@ -32,7 +32,6 @@ def test_version(backend):
"bigquery",
"mysql",
"impala",
"pyspark",
"flink",
],
reason="backend does not support catalogs",
@@ -43,6 +42,7 @@ def test_version(backend):
raises=NotImplementedError,
reason="current_catalog isn't implemented",
)
@pytest.mark.xfail_version(pyspark=["pyspark<3.4"])
def test_catalog_consistency(backend, con):
catalogs = con.list_catalogs()
assert isinstance(catalogs, list)
5 changes: 3 additions & 2 deletions ibis/backends/tests/test_client.py
@@ -615,12 +615,12 @@ def test_insert_from_memtable(con, temp_table):
"pandas",
"polars",
"flink",
"pyspark",
"sqlite",
],
raises=AttributeError,
reason="doesn't support the common notion of a catalog",
)
@pytest.mark.xfail_version(pyspark=["pyspark<3.4"])
def test_list_catalogs(con):
# Every backend has its own databases
test_catalogs = {
@@ -634,6 +634,7 @@ def test_list_catalogs(con):
"risingwave": {"dev"},
"snowflake": {"IBIS_TESTING"},
"trino": {"memory"},
"pyspark": {"spark_catalog"},
}
result = set(con.list_catalogs())
assert test_catalogs[con.name] <= result
@@ -647,7 +648,7 @@ def test_list_catalogs(con):
"polars",
],
raises=AttributeError,
reason="doesn't support the common notion of a catalog",
reason="doesn't support the common notion of a database",
)
def test_list_database_contents(con):
# Every backend has its own databases
