Skip to content

Commit

Permalink
feat(sql): add database argument to list_schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Aug 8, 2023
1 parent a711435 commit 22ceba7
Show file tree
Hide file tree
Showing 13 changed files with 119 additions and 53 deletions.
16 changes: 9 additions & 7 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,14 +605,19 @@ def drop_schema(
"""

@abc.abstractmethod
def list_schemas(self, like: str | None = None) -> list[str]:
def list_schemas(
self, like: str | None = None, database: str | None = None
) -> list[str]:
"""List existing schemas in the current connection.
Parameters
----------
like
A pattern in Python's regex format to filter returned schema
names.
database
The database to list schemas from. If `None`, the current database
is searched.
Returns
-------
Expand Down Expand Up @@ -745,10 +750,7 @@ def database(self, name: str | None = None) -> Database:
return Database(name=name or self.current_database, client=self)

@staticmethod
def _filter_with_like(
values: Iterable[str],
like: str | None = None,
) -> list[str]:
def _filter_with_like(values: Iterable[str], like: str | None = None) -> list[str]:
"""Filter names with a `like` pattern (regex).
The methods `list_databases` and `list_tables` accept a `like`
Expand All @@ -771,10 +773,10 @@ def _filter_with_like(
Names filtered by the `like` pattern.
"""
if like is None:
return list(values)
return sorted(values)

pattern = re.compile(like)
return sorted(filter(lambda t: pattern.findall(t), values))
return sorted(filter(pattern.findall, values))

@abc.abstractmethod
def list_tables(
Expand Down
23 changes: 16 additions & 7 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,22 @@ def _create_table_as(element, compiler, **kw):


class AlchemyCanCreateSchema(CanCreateSchema):
def list_schemas(self, like: str | None = None) -> list[str]:
return self._filter_with_like(self.inspector.get_schema_names(), like)
def list_schemas(
self, like: str | None = None, database: str | None = None
) -> list[str]:
schema = ".".join(filter(None, (database, "information_schema")))
sch = sa.table(
"schemata",
sa.column("catalog_name", sa.TEXT()),
sa.column("schema_name", sa.TEXT()),
schema=schema,
)

query = sa.select(sch.c.schema_name)

with self.begin() as con:
schemas = list(con.execute(query).scalars())
return self._filter_with_like(schemas, like=like)


class BaseAlchemyBackend(BaseSQLBackend):
Expand Down Expand Up @@ -153,11 +167,6 @@ def list_tables(self, like=None, database=None):
views = self.inspector.get_view_names(schema=database)
return self._filter_with_like(tables + views, like)

def list_databases(self, like=None):
"""List databases in the current server."""
databases = self.inspector.get_schema_names()
return self._filter_with_like(databases, like)

@property
def inspector(self):
if self._inspector is None:
Expand Down
12 changes: 9 additions & 3 deletions ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,10 +459,14 @@ def get_schema(self, name, database=None):
table = self.client.get_table(table_ref)
return schema_from_bigquery_table(table)

def list_schemas(self, like=None):
def list_schemas(
self, like: str | None = None, database: str | None = None
) -> list[str]:
results = [
dataset.dataset_id
for dataset in self.client.list_datasets(project=self.data_project)
for dataset in self.client.list_datasets(
project=database if database is not None else self.data_project
)
]
return self._filter_with_like(results, like)

Expand All @@ -472,7 +476,9 @@ def list_schemas(self, like=None):
def list_databases(self, like=None):
return self.list_schemas(like=like)

def list_tables(self, like=None, database=None):
def list_tables(
self, like: str | None = None, database: str | None = None
) -> list[str]:
project, dataset = self._parse_project_and_dataset(database)
dataset_ref = bq.DatasetReference(project, dataset)
result = [table.table_id for table in self.client.list_tables(dataset_ref)]
Expand Down
11 changes: 9 additions & 2 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,15 @@ def drop_database(self, name: str, force: bool = False) -> None:
"DataFusion does not support dropping databases"
)

def list_schemas(self, like: str | None = None) -> list[str]:
return self._filter_with_like(self._context.catalog().names(), like=like)
def list_schemas(
self, like: str | None = None, database: str | None = None
) -> list[str]:
return self._filter_with_like(
self._context.catalog(
database if database is not None else "datafusion"
).names(),
like=like,
)

def create_schema(
self, name: str, database: str | None = None, force: bool = False
Expand Down
35 changes: 17 additions & 18 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,31 +87,30 @@ def list_databases(self, like: str | None = None) -> list[str]:
schema="information_schema",
)

query = sa.select(sa.distinct(s.c.catalog_name)).order_by(s.c.catalog_name)
query = sa.select(sa.distinct(s.c.catalog_name))
with self.begin() as con:
results = list(con.execute(query).scalars())
return self._filter_with_like(results, like=like)

@property
def current_schema(self) -> str:
return self._scalar_query(sa.select(sa.func.current_schema()))

def list_schemas(self, like: str | None = None) -> list[str]:
s = sa.table(
"schemata",
sa.column("catalog_name", sa.TEXT()),
sa.column("schema_name", sa.TEXT()),
schema="information_schema",
def list_schemas(
self, like: str | None = None, database: str | None = None
) -> list[str]:
# override duckdb because all databases are always visible
text = """\
SELECT schema_name
FROM information_schema.schemata
WHERE catalog_name = :database"""
query = sa.text(text).bindparams(
database=database if database is not None else self.current_database
)

query = (
sa.select(s.c.schema_name)
.where(s.c.catalog_name == sa.func.current_database())
.order_by(s.c.schema_name)
)
with self.begin() as con:
results = list(con.execute(query).scalars())
return self._filter_with_like(results, like=like)
schemas = list(con.execute(query).scalars())
return self._filter_with_like(schemas, like=like)

@property
def current_schema(self) -> str:
return self._scalar_query(sa.select(sa.func.current_schema()))

@staticmethod
def _convert_kwargs(kwargs: MutableMapping) -> None:
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def current_database(self) -> str:

def list_databases(self, like: str | None = None) -> list[str]:
s = sa.table("databases", sa.column("name", sa.VARCHAR()), schema="sys")
query = sa.select(sa.distinct(s.c.name)).select_from(s).order_by(s.c.name)
query = sa.select(s.c.name)

with self.begin() as con:
results = list(con.execute(query).scalars())
Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ def column_reflect(inspector, table, column_info):

return meta

def list_databases(self, like: str | None = None) -> list[str]:
# In MySQL, "database" and "schema" are synonymous
databases = self.inspector.get_schema_names()
return self._filter_with_like(databases, like)

def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
if (
re.search(r"^\s*SELECT\s", query, flags=re.MULTILINE | re.IGNORECASE)
Expand Down
12 changes: 9 additions & 3 deletions ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,16 @@ def connect(dbapi_connection, connection_record):
super().do_connect(engine)

def list_databases(self, like=None) -> list[str]:
query = "SELECT datname FROM pg_catalog.pg_database WHERE NOT datistemplate"
# http://dba.stackexchange.com/a/1304/58517
dbs = sa.table(
"pg_database",
sa.column("datname", sa.TEXT()),
sa.column("datistemplate", sa.BOOLEAN()),
schema="pg_catalog",
)
query = sa.select(dbs.c.datname).where(sa.not_(dbs.c.datistemplate))
with self.begin() as con:
# http://dba.stackexchange.com/a/1304/58517
databases = list(con.exec_driver_sql(query).scalars())
databases = list(con.execute(query).scalars())

return self._filter_with_like(databases, like)

Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,13 @@ def version(self):
def current_database(self) -> str:
return self._catalog.currentDatabase()

def list_databases(self, like=None):
def list_databases(self, like: str | None = None) -> list[str]:
databases = [db.name for db in self._catalog.listDatabases()]
return self._filter_with_like(databases, like)

def list_tables(self, like=None, database=None):
def list_tables(
self, like: str | None = None, database: str | None = None
) -> list[str]:
tables = [
t.name
for t in self._catalog.listTables(dbName=database or self.current_database)
Expand Down
25 changes: 17 additions & 8 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,17 +430,26 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
typ = parse(FIELD_ID_TO_NAME[type_code]).copy(nullable=is_nullable)
yield name, typ

def list_databases(self, like=None) -> list[str]:
d = sa.table(
"databases",
sa.column("database_name", sa.TEXT()),
schema="information_schema",
)
query = sa.select(d.c.database_name).order_by(d.c.database_name)
def list_databases(self, like: str | None = None) -> list[str]:
with self.begin() as con:
databases = list(con.execute(query).scalars())
databases = [
row["name"] for row in con.exec_driver_sql("SHOW DATABASES").mappings()
]
return self._filter_with_like(databases, like)

def list_schemas(
self, like: str | None = None, database: str | None = None
) -> list[str]:
query = "SHOW SCHEMAS"

if database is not None:
query += f" IN {self._quote(database)}"

with self.begin() as con:
schemata = [row["name"] for row in con.exec_driver_sql(query).mappings()]

return self._filter_with_like(schemata, like)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
import pyarrow.parquet as pq

Expand Down
1 change: 1 addition & 0 deletions ibis/backends/snowflake/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def temp_db(con):
def temp_schema(con, temp_db):
schema = gen_name("tmp_schema")
con.raw_sql(f"CREATE SCHEMA {temp_db}.{schema}")
assert schema in con.list_schemas(database=temp_db)
yield schema
con.raw_sql(f"DROP SCHEMA {temp_db}.{schema}")

Expand Down
3 changes: 1 addition & 2 deletions ibis/backends/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ def test_version(backend):

# 1. `current_database` returns '.', but isn't listed in list_databases()
@pytest.mark.never(
["polars", "dask", "pandas"],
["polars", "dask", "pandas", "druid", "oracle"],
reason="backend does not support databases",
raises=AttributeError,
)
@pytest.mark.notimpl(["oracle"], raises=AssertionError)
@pytest.mark.notimpl(
["datafusion"],
raises=NotImplementedError,
Expand Down
21 changes: 21 additions & 0 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,11 @@ def test_insert_from_memtable(alchemy_con, alchemy_temp_table):
assert alchemy_con.tables[table_name].schema() == ibis.schema({"x": "int64"})


@pytest.mark.notyet(
["oracle"],
raises=AttributeError,
reason="oracle doesn't support the common notion of a database",
)
def test_list_databases(alchemy_con):
# Every backend has its own databases
test_databases = {
Expand Down Expand Up @@ -1356,3 +1361,19 @@ def test_create_database_schema(con_create_database_schema):
con_create_database_schema.drop_schema(schema, database=database)
finally:
con_create_database_schema.drop_database(database)


@pytest.mark.notyet(["datafusion"], reason="cannot list or drop databases")
def test_list_databases_schemas(con_create_database_schema):
database = gen_name("test_create_database")
con_create_database_schema.create_database(database)
try:
schema = gen_name("test_create_database_schema")
con_create_database_schema.create_schema(schema, database=database)

try:
assert schema in con_create_database_schema.list_schemas(database=database)
finally:
con_create_database_schema.drop_schema(schema, database=database)
finally:
con_create_database_schema.drop_database(database)

0 comments on commit 22ceba7

Please sign in to comment.