Skip to content

Commit

Permalink
feat(backends): add current_schema API
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and gforsyth committed Aug 2, 2023
1 parent 7423eb9 commit 955a9d0
Show file tree
Hide file tree
Showing 21 changed files with 190 additions and 112 deletions.
64 changes: 35 additions & 29 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,31 @@ def to_delta(
write_deltalake(path, batch_reader, **kwargs)


class CanCreateDatabase(abc.ABC):
class CanListDatabases(abc.ABC):
    """Interface for backends that can enumerate databases.

    Implementers provide both a listing method and a property naming the
    database currently in use.
    """

    @abc.abstractmethod
    def list_databases(self, like: str | None = None) -> list[str]:
        """List existing databases in the current connection.

        Parameters
        ----------
        like
            A pattern in Python's regex format to filter returned database
            names.

        Returns
        -------
        list[str]
            The database names that exist in the current connection, that
            match the `like` pattern if provided.
        """

    @property
    @abc.abstractmethod
    def current_database(self) -> str:
        """The current database in use."""


class CanCreateDatabase(CanListDatabases):
@abc.abstractmethod
def create_database(self, name: str, force: bool = False) -> None:
"""Create a new database.
Expand All @@ -534,23 +558,14 @@ def create_database(self, name: str, force: bool = False) -> None:

@abc.abstractmethod
def drop_database(self, name: str, force: bool = False) -> None:
"""Drop a database with name `name`."""

@abc.abstractmethod
def list_databases(self, like: str | None = None) -> list[str]:
"""List existing databases in the current connection.
"""Drop a database with name `name`.
Parameters
----------
like
A pattern in Python's regex format to filter returned database
names.
Returns
-------
list[str]
The database names that exist in the current connection, that match
the `like` pattern if provided.
name
Database to drop.
force
If `False`, an exception is raised if the database does not exist.
"""


Expand Down Expand Up @@ -606,6 +621,11 @@ def list_schemas(self, like: str | None = None) -> list[str]:
the `like` pattern if provided.
"""

@property
@abc.abstractmethod
def current_schema(self) -> str:
"""Return the current schema."""


class BaseBackend(abc.ABC, _FileIOHandler):
"""Base backend class.
Expand Down Expand Up @@ -724,20 +744,6 @@ def database(self, name: str | None = None) -> Database:
"""
return Database(name=name or self.current_database, client=self)

@property
@abc.abstractmethod
def current_database(self) -> str | None:
"""Return the name of the current database.
Backends that don't support different databases will return None.
Returns
-------
str | None
Name of the current database or `None` if the backend supports the
concept of a database but doesn't support multiple databases.
"""

@staticmethod
def _filter_with_like(
values: Iterable[str],
Expand Down
10 changes: 5 additions & 5 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ class BaseAlchemyBackend(BaseSQLBackend):
supports_temporary_tables = True
_temporary_prefix = "TEMPORARY"

def _scalar_query(self, query):
    """Run `query` inside `self.begin()` and return its scalar result.

    Plain strings are sent through the connection's `exec_driver_sql`;
    any other query object is passed to `execute`. The value returned is
    whatever `.scalar()` yields on the resulting cursor.
    """
    with self.begin() as con:
        if isinstance(query, str):
            cursor = con.exec_driver_sql(query)
        else:
            cursor = con.execute(query)
        return cursor.scalar()

def _compile_type(self, dtype) -> str:
dialect = self.con.dialect
return sa.types.to_instance(
Expand Down Expand Up @@ -451,11 +456,6 @@ def schema(self, name: str) -> sch.Schema:
"""
return self.database().schema(name)

@property
def current_database(self) -> str | None:
"""The name of the current database this client is connected to."""
return self.database_name

def _log(self, sql):
try:
query_str = str(sql)
Expand Down
19 changes: 15 additions & 4 deletions ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import contextlib
import warnings
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping
from urllib.parse import parse_qs, urlparse

Expand All @@ -17,7 +18,7 @@
import ibis.common.exceptions as com
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.base import CanCreateSchema, Database
from ibis.backends.base import CanCreateSchema, CanListDatabases, Database
from ibis.backends.base.sql import BaseSQLBackend
from ibis.backends.bigquery.client import (
BigQueryCursor,
Expand Down Expand Up @@ -71,7 +72,7 @@ def _create_client_info_gapic(application_name):
return ClientInfo(user_agent=_create_user_agent(application_name))


class Backend(BaseSQLBackend, CanCreateSchema):
class Backend(BaseSQLBackend, CanCreateSchema, CanListDatabases):
name = "bigquery"
compiler = BigQueryCompiler
supports_in_memory_tables = False
Expand Down Expand Up @@ -272,7 +273,7 @@ def drop_schema(

def table(self, name: str, database: str | None = None) -> ir.TableExpr:
if database is None:
database = f"{self.data_project}.{self.current_database}"
database = f"{self.data_project}.{self.current_schema}"
table_id = self._fully_qualified_name(name, database)
t = super().table(table_id)
bq_table = self.client.get_table(table_id)
Expand Down Expand Up @@ -326,7 +327,17 @@ def raw_sql(self, query: str, results=False, params=None):
return self._execute(query, results=results, query_parameters=query_parameters)

@property
def current_database(self) -> str | None:
def current_database(self) -> str:
    """Return `self.dataset`, emitting a `FutureWarning` on every access.

    The warning announces that in ibis 7.0.0 this will return the current
    data project instead, and points callers at `current_schema` for the
    current BigQuery dataset.

    Returns
    -------
    str
        The value of `self.dataset`.
    """
    warnings.warn(
        "current_database will return the current *data project* in ibis 7.0.0; "
        "use current_schema for the current BigQuery dataset",
        category=FutureWarning,
        # Attribute the warning to the code that accessed this attribute,
        # not to this module, so users can locate the call to migrate.
        stacklevel=2,
    )
    # TODO: return self.data_project in ibis 7.0.0
    return self.dataset

@property
def current_schema(self) -> str | None:
    """Return the BigQuery dataset stored on this backend (`self.dataset`)."""
    return self.dataset

def database(self, name=None):
Expand Down
9 changes: 5 additions & 4 deletions ibis/backends/bigquery/tests/system/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def test_list_tables(con):


def test_current_database(con, dataset_id):
db = con.current_database
with pytest.warns(FutureWarning, match="data project"):
db = con.current_database
assert db == dataset_id
assert db == con.dataset_id
assert con.list_tables(database=db, like="alltypes") == con.list_tables(
Expand Down Expand Up @@ -352,7 +353,7 @@ def test_approx_median(alltypes):
def test_create_table_bignumeric(con, temp_table):
schema = ibis.schema({'col1': dt.Decimal(76, 38)})
temporary_table = con.create_table(temp_table, schema=schema)
con.raw_sql(f"INSERT {con.current_database}.{temp_table} (col1) VALUES (10.2)")
con.raw_sql(f"INSERT {con.current_schema}.{temp_table} (col1) VALUES (10.2)")
df = temporary_table.execute()
assert df.shape == (1, 1)

Expand All @@ -361,7 +362,7 @@ def test_geography_table(con, temp_table):
schema = ibis.schema({'col1': dt.GeoSpatial(geotype="geography", srid=4326)})
temporary_table = con.create_table(temp_table, schema=schema)
con.raw_sql(
f"INSERT {con.current_database}.{temp_table} (col1) VALUES (ST_GEOGPOINT(1,3))"
f"INSERT {con.current_schema}.{temp_table} (col1) VALUES (ST_GEOGPOINT(1,3))"
)
df = temporary_table.execute()
assert df.shape == (1, 1)
Expand All @@ -377,7 +378,7 @@ def test_timestamp_table(con, temp_table):
)
temporary_table = con.create_table(temp_table, schema=schema)
con.raw_sql(
f"INSERT {con.current_database}.{temp_table} (datetime_col, timestamp_col) VALUES (CURRENT_DATETIME(), CURRENT_TIMESTAMP())"
f"INSERT {con.current_schema}.{temp_table} (datetime_col, timestamp_col) VALUES (CURRENT_DATETIME(), CURRENT_TIMESTAMP())"
)
df = temporary_table.execute()
assert df.shape == (1, 2)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def version(self) -> str:
return self.con.server_version

@property
def current_database(self) -> str | None:
def current_database(self) -> str:
with closing(self.raw_sql("SELECT currentDatabase()")) as result:
[(db,)] = result.result_rows
return db
Expand Down
6 changes: 5 additions & 1 deletion ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,13 @@ def do_connect(
self.register(path, table_name=name)

@property
def current_database(self) -> str | None:
def current_database(self) -> str:
raise NotImplementedError()

@property
def current_schema(self) -> str:
    """Current schema; not implemented for this backend.

    Raises
    ------
    NotImplementedError
        Always. The exception is raised (not returned) so that callers
        never receive an exception instance in place of a schema name.
    """
    raise NotImplementedError()

def list_databases(self, like: str | None = None) -> list[str]:
code = "SELECT DISTINCT table_catalog FROM information_schema.tables"
if like:
Expand Down
9 changes: 6 additions & 3 deletions ibis/backends/druid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ class Backend(BaseAlchemyBackend):
compiler = DruidCompiler
supports_create_or_replace = False

@property
def current_database(self) -> str:
    """Return the current database name, which is always the literal ``"druid"``."""
    # https://druid.apache.org/docs/latest/querying/sql-metadata-tables.html#schemata-table
    return "druid"

def do_connect(
self,
host: str = "localhost",
Expand Down Expand Up @@ -80,9 +85,7 @@ def _safe_raw_sql(self, query, *args, **kwargs):
yield con.execute(query, *args, **kwargs)

def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
query = f"EXPLAIN PLAN FOR {query}"
with self.begin() as con:
result = con.exec_driver_sql(query).scalar()
result = self._scalar_query(f"EXPLAIN PLAN FOR {query}")

(plan,) = json.loads(result)
for column in plan["signature"]:
Expand Down
34 changes: 24 additions & 10 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,24 @@ class Backend(BaseAlchemyBackend, CanCreateSchema):
supports_create_or_replace = True

@property
def current_database(self) -> str | None:
query = sa.select(sa.func.current_database())
def current_database(self) -> str:
return self._scalar_query(sa.select(sa.func.current_database()))

def list_databases(self, like: str | None = None) -> list[str]:
s = sa.table(
"schemata",
sa.column("catalog_name", sa.TEXT()),
schema="information_schema",
)

query = sa.select(sa.distinct(s.c.catalog_name)).order_by(s.c.catalog_name)
with self.begin() as con:
return con.execute(query).scalar()
results = list(con.execute(query).scalars())
return self._filter_with_like(results, like=like)

@property
def current_schema(self) -> str:
    """Return the scalar result of selecting `current_schema()`."""
    query = sa.select(sa.func.current_schema())
    return self._scalar_query(query)

def list_schemas(self, like: str | None = None) -> list[str]:
s = sa.table(
Expand All @@ -91,14 +105,14 @@ def list_schemas(self, like: str | None = None) -> list[str]:
schema="information_schema",
)

where = s.c.catalog_name == sa.func.current_database()

if like is not None:
where &= s.c.schema_name.like(like)

query = sa.select(s.c.schema_name).select_from(s).where(where)
query = (
sa.select(s.c.schema_name)
.where(s.c.catalog_name == sa.func.current_database())
.order_by(s.c.schema_name)
)
with self.begin() as con:
return list(con.execute(query).scalars())
results = list(con.execute(query).scalars())
return self._filter_with_like(results, like=like)

@staticmethod
def _convert_kwargs(kwargs: MutableMapping) -> None:
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/impala/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def _get_list(self, cur):
return list(map(operator.itemgetter(0), tuples))

@property
def current_database(self) -> str | None:
def current_database(self) -> str:
# XXX The parent `Client` has a generic method that calls this same
# method in the backend. But for whatever reason calling this code from
# that method doesn't seem to work. Maybe `con` is a copy?
Expand Down
17 changes: 16 additions & 1 deletion ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def do_connect(
database=database,
driver=f'mssql+{driver}',
)
self.database_name = alchemy_url.database

engine = sa.create_engine(alchemy_url, poolclass=sa.pool.StaticPool)

Expand All @@ -67,6 +66,22 @@ def _metadata(self, query):
for column in bind.execute(query).mappings():
yield column["name"], _type_from_result_set_info(column)

@property
def current_database(self) -> str:
    """Return the scalar result of selecting `db_name()`."""
    query = sa.select(sa.func.db_name())
    return self._scalar_query(query)

def list_databases(self, like: str | None = None) -> list[str]:
    """List the distinct names found in `sys.databases`, ordered by name.

    Parameters
    ----------
    like
        Optional pattern forwarded to `self._filter_with_like` to filter
        the returned names.

    Returns
    -------
    list[str]
        The names after filtering with `like`.
    """
    databases = sa.table("databases", sa.column("name", sa.VARCHAR()), schema="sys")
    name_col = databases.c.name
    query = sa.select(sa.distinct(name_col)).select_from(databases).order_by(name_col)

    with self.begin() as con:
        names = list(con.execute(query).scalars())
    return self._filter_with_like(names, like=like)

@property
def current_schema(self) -> str:
    """Return the scalar result of selecting `schema_name()`."""
    query = sa.select(sa.func.schema_name())
    return self._scalar_query(query)

def _get_temp_view_definition(
self, name: str, definition: sa.sql.compiler.Compiled
) -> str:
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ def do_connect(
driver=f'mysql+{driver}',
)

self.database_name = alchemy_url.database

engine = sa.create_engine(
alchemy_url, poolclass=sa.pool.StaticPool, connect_args=kwargs
)
Expand All @@ -122,6 +120,10 @@ def connect(dbapi_connection, connection_record):

super().do_connect(engine)

@property
def current_database(self) -> str:
    """Return the scalar result of selecting `database()`."""
    query = sa.select(sa.func.database())
    return self._scalar_query(query)

@staticmethod
def _new_sa_metadata():
meta = sa.MetaData()
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/oracle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ def do_connect(
# see user tables
# select table_name from user_tables

self.database_name = database # not sure what should go here

# Note: for the moment, we need to pass the `database` in to the `make_url` call
# AND specify it here as the `service_name`. I don't know why.
engine = sa.create_engine(
Expand Down Expand Up @@ -160,6 +158,10 @@ def normalize_name(name):

self.con.dialect.normalize_name = normalize_name

@property
def current_database(self) -> str:
    """Return the scalar result of querying `global_name`."""
    query = "SELECT * FROM global_name"
    return self._scalar_query(query)

def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
query = f"SELECT * FROM ({query.strip(';')}) FETCH FIRST 0 ROWS ONLY"
with self.begin() as con, con.connection.cursor() as cur:
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ def from_dataframe(
def version(self) -> str:
return pd.__version__

@property
def current_database(self) -> str | None:
raise NotImplementedError('pandas backend does not support databases')

def list_tables(self, like=None, database=None):
return self._filter_with_like(list(self.dictionary.keys()), like)

Expand Down
Loading

0 comments on commit 955a9d0

Please sign in to comment.