Skip to content

Commit

Permalink
fix(trino): ensure that a schema is not required upon connection when…
Browse files Browse the repository at this point in the history
… accessing tables with explicit schema
  • Loading branch information
cpcloud committed Oct 17, 2023
1 parent fb5d56d commit 8bde3e0
Show file tree
Hide file tree
Showing 27 changed files with 146 additions and 152 deletions.
78 changes: 41 additions & 37 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import TYPE_CHECKING, Any

import sqlalchemy as sa
import sqlglot as sg
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import quoted_name
from sqlalchemy.sql.expression import ClauseElement, Executable
Expand Down Expand Up @@ -38,6 +39,7 @@
AlchemyContext,
AlchemyExprTranslator,
)
from ibis.backends.base.sqlglot import STAR
from ibis.formats.pandas import PandasData

if TYPE_CHECKING:
Expand Down Expand Up @@ -902,28 +904,6 @@ class AlchemyCrossSchemaBackend(BaseAlchemyBackend):
currently active one.
"""

@property
@abc.abstractmethod
def use_stmt_prefix(self) -> str:
"""The prefix to use for switching schemas.
Common examples are `USE` and `USE SCHEMA`.
"""

@contextlib.contextmanager
def _use_schema(self, ident: str, current_db: str, current_schema: str) -> None:
use_prefix = self.use_stmt_prefix

try:
with self.begin() as c:
c.exec_driver_sql(f"{use_prefix} {ident}")
yield
finally:
with self.begin() as c:
c.exec_driver_sql(
f"{use_prefix} {self._quote(current_db)}.{self._quote(current_schema)}"
)

def _get_sqla_table(
self,
name: str,
Expand All @@ -937,27 +917,51 @@ def _get_sqla_table(
schema = current_schema
*db, schema = schema.split(".")
db = "".join(db) or database or current_db
ident = ".".join(map(self._quote, filter(None, (db, schema))))

pairs = self._metadata(f"SELECT * FROM {ident}.{self._quote(name)} LIMIT 0")
table = sg.table(
name,
db=schema,
catalog=db,
quoted=self.compiler.translator_class._quote_table_names,
)
metadata_query = sg.select(STAR).from_(table).limit(0).sql(dialect=self.name)
pairs = self._metadata(metadata_query)
ibis_schema = ibis.schema(pairs)

with self._use_schema(ident, current_db, current_schema):
result = self._table_from_schema(name, schema=ibis_schema)
result.schema = self._get_schema_for_table(qualname=ident, schema=schema)
columns = self._columns_from_schema(name, ibis_schema)
result = sa.Table(
name,
sa.MetaData(),
*columns,
quote=self.compiler.translator_class._quote_table_names,
)
result.fullname = table.sql(dialect=self.name)
return result

@abc.abstractmethod
def _get_schema_for_table(self, *, qualname: str, schema: str) -> str:
"""Choose whether to prefix a table with its fully qualified path or schema."""

def drop_table(
self, name: str, database: str | None = None, force: bool = False
) -> None:
name = self._quote(name)
# TODO: handle database quoting
if database is not None:
name = f"{database}.{name}"
drop_stmt = "DROP TABLE" + (" IF EXISTS" * force) + f" {name}"
table = sg.table(name, db=database)
drop_table = sg.exp.Drop(kind="TABLE", exists=force, this=table)
drop_table_sql = drop_table.sql(dialect=self.name)
with self.begin() as con:
con.exec_driver_sql(drop_stmt)
con.exec_driver_sql(drop_table_sql)


@compiles(sa.Table, "trino")
def compile_trino_table(element, compiler, **kw):
return element.fullname


@compiles(sa.Table, "snowflake")
def compile_snowflake_table(element, compiler, **kw):
dialect = compiler.dialect.name
return (
sg.parse_one(element.fullname, into=sg.exp.Table, read=dialect)
.transform(
lambda node: node.__class__(this=node.this, quoted=True)
if isinstance(node, sg.exp.Identifier)
else node
)
.sql(dialect)
)
18 changes: 0 additions & 18 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ class Backend(AlchemyCrossSchemaBackend, CanCreateDatabase, AlchemyCanCreateSche
compiler = SnowflakeCompiler
supports_create_or_replace = True
supports_python_udfs = True
use_stmt_prefix = "USE SCHEMA"

_latest_udf_python_version = (3, 10)

Expand Down Expand Up @@ -813,9 +812,6 @@ def read_json(

return self.table(table)

def _get_schema_for_table(self, *, qualname: str, schema: str) -> str:
return qualname

def read_parquet(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
) -> ir.Table:
Expand Down Expand Up @@ -888,20 +884,6 @@ def read_parquet(
return self.table(table)


@compiles(sa.Table, "snowflake")
def compile_table(element, compiler, **kw):
"""Override compilation of leaf tables.
The override is necessary because snowflake-sqlalchemy does not handle
quoting databases and schemas correctly.
"""
schema = element.schema
name = compiler.preparer.quote_identifier(element.name)
if schema is not None:
return f"{schema}.{name}"
return name


@compiles(sa.sql.Join, "snowflake")
def compile_join(element, compiler, **kw):
"""Override compilation of LATERAL joins.
Expand Down
5 changes: 3 additions & 2 deletions ibis/backends/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,14 +335,15 @@ def test_to_pyarrow_decimal(backend, dtype, pyarrow_dtype):
"postgres",
"snowflake",
"sqlite",
"bigquery",
"dask",
"trino",
],
raises=AttributeError,
raises=NotImplementedError,
reason="read_delta not yet implemented",
)
@pytest.mark.notyet(["clickhouse"], raises=Exception)
@pytest.mark.notyet(["mssql", "pandas"], raises=PyDeltaTableError)
@pytest.mark.notyet(["bigquery", "dask"], raises=NotImplementedError)
@pytest.mark.notyet(
["druid"],
raises=pa.lib.ArrowTypeError,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ FROM (
AVG(t1.l_extendedprice) AS avg_price,
AVG(t1.l_discount) AS avg_disc,
COUNT(*) AS count_order
FROM "hive".ibis_sf1.lineitem AS t1
FROM hive.ibis_sf1.lineitem AS t1
WHERE
t1.l_shipdate <= FROM_ISO8601_DATE('1998-09-02')
GROUP BY
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ WITH t0 AS (
t6.r_regionkey AS r_regionkey,
t6.r_name AS r_name,
t6.r_comment AS r_comment
FROM "hive".ibis_sf1.part AS t2
JOIN "hive".ibis_sf1.partsupp AS t3
FROM hive.ibis_sf1.part AS t2
JOIN hive.ibis_sf1.partsupp AS t3
ON t2.p_partkey = t3.ps_partkey
JOIN "hive".ibis_sf1.supplier AS t4
JOIN hive.ibis_sf1.supplier AS t4
ON t4.s_suppkey = t3.ps_suppkey
JOIN "hive".ibis_sf1.nation AS t5
JOIN hive.ibis_sf1.nation AS t5
ON t4.s_nationkey = t5.n_nationkey
JOIN "hive".ibis_sf1.region AS t6
JOIN hive.ibis_sf1.region AS t6
ON t5.n_regionkey = t6.r_regionkey
WHERE
t2.p_size = 15
Expand All @@ -44,12 +44,12 @@ WITH t0 AS (
AND t3.ps_supplycost = (
SELECT
MIN(t3.ps_supplycost) AS "Min(ps_supplycost)"
FROM "hive".ibis_sf1.partsupp AS t3
JOIN "hive".ibis_sf1.supplier AS t4
FROM hive.ibis_sf1.partsupp AS t3
JOIN hive.ibis_sf1.supplier AS t4
ON t4.s_suppkey = t3.ps_suppkey
JOIN "hive".ibis_sf1.nation AS t5
JOIN hive.ibis_sf1.nation AS t5
ON t4.s_nationkey = t5.n_nationkey
JOIN "hive".ibis_sf1.region AS t6
JOIN hive.ibis_sf1.region AS t6
ON t5.n_regionkey = t6.r_regionkey
WHERE
t6.r_name = 'EUROPE' AND t2.p_partkey = t3.ps_partkey
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ WITH t0 AS (
SUM(t4.l_extendedprice * (
1 - t4.l_discount
)) AS revenue
FROM "hive".ibis_sf1.customer AS t2
JOIN "hive".ibis_sf1.orders AS t3
FROM hive.ibis_sf1.customer AS t2
JOIN hive.ibis_sf1.orders AS t3
ON t2.c_custkey = t3.o_custkey
JOIN "hive".ibis_sf1.lineitem AS t4
JOIN hive.ibis_sf1.lineitem AS t4
ON t4.l_orderkey = t3.o_orderkey
WHERE
t2.c_mktsegment = 'BUILDING'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
SELECT
t0.o_orderpriority,
COUNT(*) AS order_count
FROM "hive".ibis_sf1.orders AS t0
FROM hive.ibis_sf1.orders AS t0
WHERE
(
EXISTS(
SELECT
1 AS anon_1
FROM "hive".ibis_sf1.lineitem AS t1
FROM hive.ibis_sf1.lineitem AS t1
WHERE
t1.l_orderkey = t0.o_orderkey AND t1.l_commitdate < t1.l_receiptdate
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ FROM (
SUM(t3.l_extendedprice * (
1 - t3.l_discount
)) AS revenue
FROM "hive".ibis_sf1.customer AS t1
JOIN "hive".ibis_sf1.orders AS t2
FROM hive.ibis_sf1.customer AS t1
JOIN hive.ibis_sf1.orders AS t2
ON t1.c_custkey = t2.o_custkey
JOIN "hive".ibis_sf1.lineitem AS t3
JOIN hive.ibis_sf1.lineitem AS t3
ON t3.l_orderkey = t2.o_orderkey
JOIN "hive".ibis_sf1.supplier AS t4
JOIN hive.ibis_sf1.supplier AS t4
ON t3.l_suppkey = t4.s_suppkey
JOIN "hive".ibis_sf1.nation AS t5
JOIN hive.ibis_sf1.nation AS t5
ON t1.c_nationkey = t4.s_nationkey AND t4.s_nationkey = t5.n_nationkey
JOIN "hive".ibis_sf1.region AS t6
JOIN hive.ibis_sf1.region AS t6
ON t5.n_regionkey = t6.r_regionkey
WHERE
t6.r_name = 'ASIA'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
SELECT
SUM(t0.l_extendedprice * t0.l_discount) AS revenue
FROM "hive".ibis_sf1.lineitem AS t0
FROM hive.ibis_sf1.lineitem AS t0
WHERE
t0.l_shipdate >= FROM_ISO8601_DATE('1994-01-01')
AND t0.l_shipdate < FROM_ISO8601_DATE('1995-01-01')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ WITH t0 AS (
t3.l_extendedprice * (
1 - t3.l_discount
) AS volume
FROM "hive".ibis_sf1.supplier AS t2
JOIN "hive".ibis_sf1.lineitem AS t3
FROM hive.ibis_sf1.supplier AS t2
JOIN hive.ibis_sf1.lineitem AS t3
ON t2.s_suppkey = t3.l_suppkey
JOIN "hive".ibis_sf1.orders AS t4
JOIN hive.ibis_sf1.orders AS t4
ON t4.o_orderkey = t3.l_orderkey
JOIN "hive".ibis_sf1.customer AS t5
JOIN hive.ibis_sf1.customer AS t5
ON t5.c_custkey = t4.o_custkey
JOIN "hive".ibis_sf1.nation AS t6
JOIN hive.ibis_sf1.nation AS t6
ON t2.s_nationkey = t6.n_nationkey
JOIN "hive".ibis_sf1.nation AS t7
JOIN hive.ibis_sf1.nation AS t7
ON t5.c_nationkey = t7.n_nationkey
)
SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@ WITH t0 AS (
t10.r_name AS r_name,
t7.o_orderdate AS o_orderdate,
t4.p_type AS p_type
FROM "hive".ibis_sf1.part AS t4
JOIN "hive".ibis_sf1.lineitem AS t5
FROM hive.ibis_sf1.part AS t4
JOIN hive.ibis_sf1.lineitem AS t5
ON t4.p_partkey = t5.l_partkey
JOIN "hive".ibis_sf1.supplier AS t6
JOIN hive.ibis_sf1.supplier AS t6
ON t6.s_suppkey = t5.l_suppkey
JOIN "hive".ibis_sf1.orders AS t7
JOIN hive.ibis_sf1.orders AS t7
ON t5.l_orderkey = t7.o_orderkey
JOIN "hive".ibis_sf1.customer AS t8
JOIN hive.ibis_sf1.customer AS t8
ON t7.o_custkey = t8.c_custkey
JOIN "hive".ibis_sf1.nation AS t9
JOIN hive.ibis_sf1.nation AS t9
ON t8.c_nationkey = t9.n_nationkey
JOIN "hive".ibis_sf1.region AS t10
JOIN hive.ibis_sf1.region AS t10
ON t9.n_regionkey = t10.r_regionkey
JOIN "hive".ibis_sf1.nation AS t11
JOIN hive.ibis_sf1.nation AS t11
ON t6.s_nationkey = t11.n_nationkey
), t1 AS (
SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ WITH t0 AS (
CAST(EXTRACT(year FROM t6.o_orderdate) AS SMALLINT) AS o_year,
t7.n_name AS nation,
t5.p_name AS p_name
FROM "hive".ibis_sf1.lineitem AS t2
JOIN "hive".ibis_sf1.supplier AS t3
FROM hive.ibis_sf1.lineitem AS t2
JOIN hive.ibis_sf1.supplier AS t3
ON t3.s_suppkey = t2.l_suppkey
JOIN "hive".ibis_sf1.partsupp AS t4
JOIN hive.ibis_sf1.partsupp AS t4
ON t4.ps_suppkey = t2.l_suppkey AND t4.ps_partkey = t2.l_partkey
JOIN "hive".ibis_sf1.part AS t5
JOIN hive.ibis_sf1.part AS t5
ON t5.p_partkey = t2.l_partkey
JOIN "hive".ibis_sf1.orders AS t6
JOIN hive.ibis_sf1.orders AS t6
ON t6.o_orderkey = t2.l_orderkey
JOIN "hive".ibis_sf1.nation AS t7
JOIN hive.ibis_sf1.nation AS t7
ON t3.s_nationkey = t7.n_nationkey
WHERE
t5.p_name LIKE '%green%'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ WITH t0 AS (
SUM(t4.l_extendedprice * (
1 - t4.l_discount
)) AS revenue
FROM "hive".ibis_sf1.customer AS t2
JOIN "hive".ibis_sf1.orders AS t3
FROM hive.ibis_sf1.customer AS t2
JOIN hive.ibis_sf1.orders AS t3
ON t2.c_custkey = t3.o_custkey
JOIN "hive".ibis_sf1.lineitem AS t4
JOIN hive.ibis_sf1.lineitem AS t4
ON t4.l_orderkey = t3.o_orderkey
JOIN "hive".ibis_sf1.nation AS t5
JOIN hive.ibis_sf1.nation AS t5
ON t2.c_nationkey = t5.n_nationkey
WHERE
t3.o_orderdate >= FROM_ISO8601_DATE('1993-10-01')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ WITH t0 AS (
SELECT
t2.ps_partkey AS ps_partkey,
SUM(t2.ps_supplycost * t2.ps_availqty) AS value
FROM "hive".ibis_sf1.partsupp AS t2
JOIN "hive".ibis_sf1.supplier AS t3
FROM hive.ibis_sf1.partsupp AS t2
JOIN hive.ibis_sf1.supplier AS t3
ON t2.ps_suppkey = t3.s_suppkey
JOIN "hive".ibis_sf1.nation AS t4
JOIN hive.ibis_sf1.nation AS t4
ON t4.n_nationkey = t3.s_nationkey
WHERE
t4.n_name = 'GERMANY'
Expand All @@ -27,10 +27,10 @@ FROM (
FROM (
SELECT
SUM(t2.ps_supplycost * t2.ps_availqty) AS total
FROM "hive".ibis_sf1.partsupp AS t2
JOIN "hive".ibis_sf1.supplier AS t3
FROM hive.ibis_sf1.partsupp AS t2
JOIN hive.ibis_sf1.supplier AS t3
ON t2.ps_suppkey = t3.s_suppkey
JOIN "hive".ibis_sf1.nation AS t4
JOIN hive.ibis_sf1.nation AS t4
ON t4.n_nationkey = t3.s_nationkey
WHERE
t4.n_name = 'GERMANY'
Expand Down
Loading

0 comments on commit 8bde3e0

Please sign in to comment.