Skip to content

Commit

Permalink
fix(trino): enable passing the database argument when accessing tables
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and gforsyth committed Oct 19, 2023
1 parent 4bd021f commit e7ce43e
Show file tree
Hide file tree
Showing 15 changed files with 101 additions and 44 deletions.
12 changes: 12 additions & 0 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@
}


# Translation table from a SQLAlchemy dialect name to the sqlglot dialect name
# to use when rendering SQL with sqlglot. Only the names that differ are
# listed; consumers are expected to fall back to the SQLAlchemy name itself for
# missing keys, e.g. `_SQLALCHEMY_TO_SQLGLOT_DIALECT.get(dialect, dialect)`.
# NOTE(review): "default" presumably is the dialect name seen when no
# backend-specific SQLAlchemy dialect applies — confirm against callers.
_SQLALCHEMY_TO_SQLGLOT_DIALECT = {
# sqlalchemy dialects of backends not listed here match the sqlglot dialect
# name
"mssql": "tsql",
"postgresql": "postgres",
"default": "duckdb",
# druid allows double quotes for identifiers, like postgres:
# https://druid.apache.org/docs/latest/querying/sql#identifiers-and-literals
"druid": "postgres",
}


class Database:
"""Generic Database class."""

Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/base/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def table(self, name: str, database: str | None = None) -> ir.Table:
)
qualified_name = self._fully_qualified_name(name, database)
schema = self.get_schema(qualified_name)
node = ops.DatabaseTable(name, schema, self, namespace=database)
node = ops.DatabaseTable(
name, schema, self, namespace=ops.Namespace(database=database)
)
return node.to_expr()

def _fully_qualified_name(self, name, database):
Expand Down
51 changes: 27 additions & 24 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,9 @@ def drop_table(
"Dropping tables from a different database is not yet implemented"
)

t = self._get_sqla_table(name, schema=database, autoload=False)
t = self._get_sqla_table(
name, namespace=ops.Namespace(database=database), autoload=False
)
with self.begin() as bind:
t.drop(bind=bind, checkfirst=force)

Expand All @@ -458,7 +460,7 @@ def drop_table(
del self._schemas[qualified_name]

def truncate_table(self, name: str, database: str | None = None) -> None:
t = self._get_sqla_table(name, schema=database)
t = self._get_sqla_table(name, namespace=ops.Namespace(database=database))
with self.begin() as con:
con.execute(t.delete())

Expand Down Expand Up @@ -490,7 +492,12 @@ def _new_sa_metadata():
return sa.MetaData()

def _get_sqla_table(
self, name: str, schema: str | None = None, autoload: bool = True, **_: Any
self,
name: str,
*,
namespace: ops.Namespace = ops.Namespace(), # noqa: B008
autoload: bool = True,
**_: Any,
) -> sa.Table:
meta = self._new_sa_metadata()
with warnings.catch_warnings():
Expand All @@ -503,7 +510,7 @@ def _get_sqla_table(
table = sa.Table(
name,
meta,
schema=schema,
schema=namespace.schema,
autoload_with=self.con if autoload else None,
quote=self.compiler.translator_class._quote_table_names,
)
Expand Down Expand Up @@ -615,16 +622,9 @@ def table(
Table
Table expression
"""
namespace = schema
if database is not None:
if not isinstance(database, str):
raise com.IbisTypeError(
f"`database` must be a string; got {type(database)}"
)
if database != self.current_database:
return self.database(name=database).table(name=name, schema=schema)
namespace = ops.Namespace(schema=schema, database=database)

sqla_table = self._get_sqla_table(name, schema=schema)
sqla_table = self._get_sqla_table(name, namespace=namespace)

schema = self._schema_from_sqla_table(
sqla_table, schema=self._schemas.get(name)
Expand All @@ -637,9 +637,9 @@ def table(
def _insert_dataframe(
self, table_name: str, df: pd.DataFrame, overwrite: bool
) -> None:
schema = self._current_schema
namespace = ops.Namespace(schema=self._current_schema)

t = self._get_sqla_table(table_name, schema=schema)
t = self._get_sqla_table(table_name, namespace=namespace)
with self.con.begin() as con:
if overwrite:
con.execute(t.delete())
Expand Down Expand Up @@ -701,7 +701,9 @@ def insert(
self.drop_table(table_name, database=database)
self.create_table(table_name, schema=to_table_schema, database=database)

to_table = self._get_sqla_table(table_name, schema=database)
to_table = self._get_sqla_table(
table_name, namespace=ops.Namespace(database=database)
)

from_table_expr = obj

Expand All @@ -713,7 +715,9 @@ def insert(
with self.begin() as bind:
bind.execute(to_table.insert().from_select(columns, compiled))
elif isinstance(obj, (list, dict)):
to_table = self._get_sqla_table(table_name, schema=database)
to_table = self._get_sqla_table(
table_name, namespace=ops.Namespace(database=database)
)

with self.begin() as bind:
if overwrite:
Expand Down Expand Up @@ -904,7 +908,10 @@ class AlchemyCrossSchemaBackend(BaseAlchemyBackend):
currently active one.
"""

def _get_table_identifier(self, *, name, schema, database):
def _get_table_identifier(self, *, name, namespace):
database = namespace.database
schema = namespace.schema

if schema is None:
schema = self.current_schema

Expand Down Expand Up @@ -935,13 +942,9 @@ def _get_table_identifier(self, *, name, schema, database):
return table

def _get_sqla_table(
self,
name: str,
schema: str | None = None,
database: str | None = None,
**_: Any,
self, name: str, namespace: ops.Namespace, **_: Any
) -> sa.Table:
table = self._get_table_identifier(name=name, schema=schema, database=database)
table = self._get_table_identifier(name=name, namespace=namespace)
metadata_query = sg.select(STAR).from_(table).limit(0).sql(dialect=self.name)
pairs = self._metadata(metadata_query)
ibis_schema = ibis.schema(pairs)
Expand Down
13 changes: 11 additions & 2 deletions ibis/backends/base/sql/alchemy/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import functools

import sqlalchemy as sa
import sqlglot as sg
import toolz
from sqlalchemy import sql

import ibis.common.exceptions as com
import ibis.expr.analysis as an
import ibis.expr.operations as ops
from ibis.backends.base import _SQLALCHEMY_TO_SQLGLOT_DIALECT
from ibis.backends.base.sql.alchemy.translator import (
AlchemyContext,
AlchemyExprTranslator,
Expand Down Expand Up @@ -93,15 +95,22 @@ def _format_table(self, op):
translator = ctx.compiler.translator_class(op, ctx)

if isinstance(op, ops.DatabaseTable):
result = op.source._get_sqla_table(op.name, schema=op.namespace)
namespace = op.namespace
result = op.source._get_sqla_table(op.name, namespace=namespace)
elif isinstance(op, ops.UnboundTable):
# use SQLAlchemy's TableClause for unbound tables
name = op.name
namespace = op.namespace
result = sa.Table(
op.name,
name,
sa.MetaData(),
*translator._schema_to_sqlalchemy_columns(op.schema),
quote=translator._quote_table_names,
)
dialect = translator._dialect_name
result.fullname = sg.table(
name, db=namespace.schema, catalog=namespace.database
).sql(dialect=_SQLALCHEMY_TO_SQLGLOT_DIALECT.get(dialect, dialect))
elif isinstance(op, ops.SQLQueryResult):
columns = translator._schema_to_sqlalchemy_columns(op.schema)
result = sa.text(op.query).columns(*columns)
Expand Down
6 changes: 5 additions & 1 deletion ibis/backends/base/sql/compiler/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,13 @@ def _format_table(self, op):
if (name := op.name) is None:
raise com.RelationError(f"Table did not have a name: {op!r}")

namespace = getattr(op, "namespace", None)
catalog = getattr(namespace, "database", None)
db = getattr(namespace, "schema", None)
result = sg.table(
name,
db=getattr(op, "namespace", None),
db=db,
catalog=catalog,
quoted=self.parent.translator_class._quote_identifiers,
).sql(dialect=self.parent.translator_class._dialect_name)
elif ctx.is_extracted(op):
Expand Down
5 changes: 4 additions & 1 deletion ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,10 @@ def table(self, name: str, database: str | None = None) -> ir.Table:
"""
schema = self.get_schema(name, database=database)
op = ops.DatabaseTable(
name=name, schema=schema, source=self, namespace=database
name=name,
schema=schema,
source=self,
namespace=ops.Namespace(database=database),
)
return op.to_expr()

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/compiler/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _physical_table(op: ops.PhysicalTable, **_):

@translate_rel.register
def _database_table(op: ops.DatabaseTable, *, name, namespace, **_):
return sg.table(name, db=namespace)
return sg.table(name, db=namespace.schema, catalog=namespace.database)


def replace_tables_with_star_selection(node, alias=None):
Expand Down
6 changes: 2 additions & 4 deletions ibis/backends/druid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _has_table(self, connection, table_name: str, schema) -> bool:
return bool(connection.execute(query).scalar())

def _get_sqla_table(
self, name: str, schema: str | None = None, autoload: bool = True, **kwargs: Any
self, name: str, autoload: bool = True, **kwargs: Any
) -> sa.Table:
with warnings.catch_warnings():
warnings.filterwarnings(
Expand All @@ -136,6 +136,4 @@ def _get_sqla_table(
),
category=sa.exc.SAWarning,
)
return super()._get_sqla_table(
name, schema=schema, autoload=autoload, **kwargs
)
return super()._get_sqla_table(name, autoload=autoload, **kwargs)
5 changes: 4 additions & 1 deletion ibis/backends/flink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,10 @@ def table(
_, quoted, unquoted = fully_qualified_re.search(qualified_name).groups()
unqualified_name = quoted or unquoted
node = ops.DatabaseTable(
unqualified_name, schema, self, namespace=database
unqualified_name,
schema,
self,
namespace=ops.Namespace(schema=database, database=catalog),
) # TODO(chloeh13q): look into namespacing with catalog + db
return node.to_expr()

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/impala/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class ImpalaTable(ir.Table):
@property
def _qualified_name(self) -> str:
op = self.op()
return sg.table(op.name, db=op.namespace).sql(dialect="hive")
return sg.table(op.name, catalog=op.namespace.database).sql(dialect="hive")

@property
def _unqualified_name(self) -> str:
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,9 @@ def table(self, name: str, database: str | None = None) -> ir.Table:
qualified_name = self._fully_qualified_name(name, database)

schema = self.get_schema(qualified_name)
node = ops.DatabaseTable(name, schema, self, namespace=database)
node = ops.DatabaseTable(
name, schema, self, namespace=ops.Namespace(database=database)
)
return PySparkTable(node)

def create_database(
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/pyspark/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ class PySparkTable(ir.Table):
@property
def _qualified_name(self) -> str:
op = self.op()
return sg.table(op.name, db=op.namespace, quoted=True).sql(dialect="spark")
return sg.table(
op.name, db=op.namespace.schema, catalog=op.namespace.database, quoted=True
).sql(dialect="spark")

@property
def _database(self) -> str:
Expand Down
12 changes: 12 additions & 0 deletions ibis/backends/trino/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import ibis
import ibis.common.exceptions as exc
from ibis import udf, util
from ibis.backends.trino.tests.conftest import (
TRINO_HOST,
Expand Down Expand Up @@ -161,3 +162,14 @@ def test_table_access_from_connection_without_catalog_or_schema():
assert con.current_schema is None

assert t.count().execute()


# Exercises `con.table(...)` when both `schema` and `database` are supplied.
# NOTE(review): `con` is presumably the Trino connection fixture provided by
# the test suite's conftest — confirm.
def test_table_access_database_schema(con):
# Schema and database passed as separate arguments: the table must resolve
# and the executed row count must be truthy (non-zero).
t = con.table("region", schema="sf1", database="tpch")
assert t.count().execute()

# A dotted schema ("tpch.sf1") combined with an explicit `database` must be
# rejected, even when the database equals the dotted prefix ...
with pytest.raises(exc.IbisError, match="Cannot specify both"):
con.table("region", schema="tpch.sf1", database="tpch")

# ... and likewise when the explicit database differs from that prefix.
with pytest.raises(exc.IbisError, match="Cannot specify both"):
con.table("region", schema="tpch.sf1", database="system")
15 changes: 11 additions & 4 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ibis.common.annotations import annotated, attribute
from ibis.common.collections import FrozenDict # noqa: TCH001
from ibis.common.deferred import Deferred
from ibis.common.grounds import Immutable
from ibis.common.grounds import Concrete, Immutable
from ibis.common.patterns import Between, Coercible, Eq
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.operations.core import Column, Named, Node, Scalar, Value
Expand Down Expand Up @@ -69,6 +69,12 @@ def to_expr(self):
TableNode = Relation


# Two-part qualifier for a table: an optional database and an optional schema.
# Both parts default to None, so `Namespace()` means "no qualification".
# In the compilers touched by this change, `database` is passed to
# `sg.table(...)` as `catalog` and `schema` is passed as `db`.
@public
class Namespace(Concrete):
# Outer qualifier; rendered as the sqlglot `catalog` argument.
database: Optional[str] = None
# Inner qualifier; rendered as the sqlglot `db` argument.
schema: Optional[str] = None


@public
class PhysicalTable(Relation, Named):
pass
Expand All @@ -80,19 +86,20 @@ class PhysicalTable(Relation, Named):
class UnboundTable(PhysicalTable):
schema: Schema
name: Optional[str] = None
namespace: Namespace = Namespace()

def __init__(self, schema, name) -> None:
def __init__(self, schema, name, namespace) -> None:
if name is None:
name = genname()
super().__init__(schema=schema, name=name)
super().__init__(schema=schema, name=name, namespace=namespace)


@public
class DatabaseTable(PhysicalTable):
name: str
schema: Schema
source: Any
namespace: Optional[str] = None
namespace: Namespace = Namespace()


@public
Expand Down
4 changes: 2 additions & 2 deletions ibis/tests/expr/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(self):
pytest.importorskip("sqlalchemy")
self.tables = {}

def table(self, name, database=None):
def table(self, name, **_):
schema = self.get_schema(name)
return self._inject_table(name, schema)

Expand All @@ -139,7 +139,7 @@ def _inject_table(self, name, schema):
self.tables[name] = table_from_schema(name, sa.MetaData(), schema)
return ops.DatabaseTable(source=self, name=name, schema=schema).to_expr()

def _get_sqla_table(self, name, schema=None, **kwargs):
def _get_sqla_table(self, name, **_):
return self.tables[name]


Expand Down

0 comments on commit e7ce43e

Please sign in to comment.