Skip to content

Commit

Permalink
fix(snowflake): make sure ephemeral tables following backend quoting …
Browse files Browse the repository at this point in the history
…rules
  • Loading branch information
cpcloud authored and kszucs committed Apr 5, 2023
1 parent 4f1d9fe commit 9a845df
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 52 deletions.
24 changes: 16 additions & 8 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class BaseAlchemyBackend(BaseSQLBackend):
database_class = AlchemyDatabase
table_class = AlchemyTable
compiler = AlchemyCompiler
quote_table_names = None

def _build_alchemy_url(self, url, host, port, user, password, database, driver):
if url is not None:
Expand Down Expand Up @@ -278,7 +277,7 @@ def _columns_from_schema(self, name: str, schema: sch.Schema) -> list[sa.Column]
colname,
to_sqla_type(dialect, dtype),
nullable=dtype.nullable,
quote=self.compiler.translator_class._always_quote_columns,
quote=self.compiler.translator_class._quote_column_names,
)
for colname, dtype in zip(schema.names, schema.types)
]
Expand All @@ -295,7 +294,7 @@ def _table_from_schema(
sa.MetaData(),
*columns,
prefixes=prefixes,
quote=self.quote_table_names,
quote=self.compiler.translator_class._quote_table_names,
)

def drop_table(
Expand Down Expand Up @@ -425,7 +424,7 @@ def _get_sqla_table(
sa.MetaData(),
schema=schema,
autoload_with=self.con if autoload else None,
quote=self.quote_table_names,
quote=self.compiler.translator_class._quote_table_names,
)
nulltype_cols = frozenset(
col.name for col in table.c if isinstance(col.type, sa.types.NullType)
Expand Down Expand Up @@ -453,7 +452,7 @@ def _handle_failed_column_type_inference(
colname,
to_sqla_type(dialect, type),
nullable=type.nullable,
quote=self.compiler.translator_class._always_quote_columns,
quote=self.compiler.translator_class._quote_column_names,
),
replace_existing=True,
)
Expand Down Expand Up @@ -620,7 +619,10 @@ def insert(

def _quote(self, name: str) -> str:
"""Quote an identifier."""
return self.con.dialect.identifier_preparer.quote(name)
preparer = self.con.dialect.identifier_preparer
if self.compiler.translator_class._quote_table_names:
return preparer.quote_identifier(name)
return preparer.quote(name)

def _get_temp_view_definition(
self, name: str, definition: sa.sql.compiler.Compiled
Expand Down Expand Up @@ -692,7 +694,10 @@ def create_view(
source = self.compile(obj)
view = sav.CreateView(
sa.Table(
name, sa.MetaData(), schema=database, quote=self.quote_table_names
name,
sa.MetaData(),
schema=database,
quote=self.compiler.translator_class._quote_table_names,
),
source,
or_replace=overwrite,
Expand All @@ -708,7 +713,10 @@ def drop_view(

view = sav.DropView(
sa.Table(
name, sa.MetaData(), schema=database, quote=self.quote_table_names
name,
sa.MetaData(),
schema=database,
quote=self.compiler.translator_class._quote_table_names,
),
if_exists=not force,
)
Expand Down
21 changes: 16 additions & 5 deletions ibis/backends/base/sql/alchemy/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,11 @@ def _format_table(self, op):
result = ref_op.sqla_table
elif isinstance(ref_op, ops.UnboundTable):
# use SQLAlchemy's TableClause for unbound tables
result = sa.table(
ref_op.name, *translator._schema_to_sqlalchemy_columns(ref_op.schema)
result = sa.Table(
ref_op.name,
sa.MetaData(),
*translator._schema_to_sqlalchemy_columns(ref_op.schema),
quote=translator._quote_table_names,
)
elif isinstance(ref_op, ops.SQLQueryResult):
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
Expand All @@ -109,8 +112,11 @@ def _format_table(self, op):
# TODO(kszucs): avoid converting to expression
child_expr = ref_op.child.to_expr()
definition = child_expr.compile()
result = sa.table(
ref_op.name, *translator._schema_to_sqlalchemy_columns(ref_op.schema)
result = sa.Table(
ref_op.name,
sa.MetaData(),
*translator._schema_to_sqlalchemy_columns(ref_op.schema),
quote=translator._quote_table_names,
)
backend = child_expr._find_backend()
backend._create_temp_view(view=result, definition=definition)
Expand Down Expand Up @@ -141,7 +147,12 @@ def _format_table(self, op):
def _format_in_memory_table(self, op, ref_op, translator):
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
if self.context.compiler.cheap_in_memory_tables:
result = sa.table(ref_op.name, *columns)
result = sa.Table(
ref_op.name,
sa.MetaData(),
*columns,
quote=translator._quote_table_names,
)
elif self.context.compiler.support_values_syntax_in_select:
rows = list(ref_op.data.to_frame().itertuples(index=False))
result = sa.values(*columns, name=ref_op.name).data(rows)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _table_column(t, op):
sa_table = get_sqla_table(ctx, table)

out_expr = get_col(sa_table, op)
out_expr.quote = t._always_quote_columns
out_expr.quote = t._quote_column_names

# If the column does not originate from the table set in the current SELECT
# context, we should format as a subquery
Expand Down
12 changes: 8 additions & 4 deletions ibis/backends/base/sql/alchemy/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def integer_to_timestamp(self, arg, tz: str | None = None):
)

native_json_type = True
_always_quote_columns = None # let the dialect decide how to quote
_quote_column_names = None # let the dialect decide how to quote
_quote_table_names = None

_require_order_by = (
ops.DenseRank,
Expand All @@ -77,11 +78,14 @@ def dialect(self) -> sa.engine.interfaces.Dialect:

def _schema_to_sqlalchemy_columns(self, schema):
return [
sa.column(name, self.get_sqla_type(dtype)) for name, dtype in schema.items()
sa.Column(name, self.get_sqla_type(dtype), quote=self._quote_column_names)
for name, dtype in schema.items()
]

def name(self, translated, name, force=True):
return translated.label(name)
def name(self, translated, name, force=False):
return translated.label(
sa.sql.quoted_name(name, quote=force or self._quote_column_names)
)

def get_sqla_type(self, data_type):
return to_sqla_type(self.dialect, data_type)
Expand Down
9 changes: 5 additions & 4 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ class SnowflakeExprTranslator(AlchemyExprTranslator):
)
_require_order_by = (*AlchemyExprTranslator._require_order_by, ops.Reduction)
_dialect_name = "snowflake"
_always_quote_columns = True
_quote_column_names = True
_quote_table_names = True
supports_unnest_in_select = False


Expand Down Expand Up @@ -144,7 +145,6 @@ def _make_udf(name, defn, *, quote) -> str:
class Backend(BaseAlchemyBackend):
name = "snowflake"
compiler = SnowflakeCompiler
quote_table_names = True

@property
def _current_schema(self) -> str:
Expand Down Expand Up @@ -368,14 +368,15 @@ def list_databases(self, like=None) -> list[str]:
return self._filter_with_like(databases, like)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
df = op.data.to_frame()
with self.begin() as con:
write_pandas(
conn=con.connection.connection,
df=op.data.to_frame(),
df=df,
table_name=op.name,
table_type="temp",
auto_create_table=True,
quote_identifiers=False,
quote_identifiers=True,
)

def _get_temp_view_definition(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
WITH t0 AS (
SELECT
t5.street AS street,
ROW_NUMBER() OVER (ORDER BY t5.street) - 1 AS key
FROM data AS t5
t5."street" AS "street",
ROW_NUMBER() OVER (ORDER BY t5."street") - 1 AS "key"
FROM "data" AS t5
), t1 AS (
SELECT
t0.key AS key
t0."key" AS "key"
FROM t0
), t2 AS (
SELECT
t0.street AS street,
t0.key AS key
t0."street" AS "street",
t0."key" AS "key"
FROM t0
JOIN t1
ON t0.key = t1.key
ON t0."key" = t1."key"
), t3 AS (
SELECT
t2.street AS street,
ROW_NUMBER() OVER (ORDER BY t2.street) - 1 AS key
t2."street" AS "street",
ROW_NUMBER() OVER (ORDER BY t2."street") - 1 AS "key"
FROM t2
), t4 AS (
SELECT
t3.key AS key
t3."key" AS "key"
FROM t3
)
SELECT
t3.street,
t3.key
t3."street",
t3."key"
FROM t3
JOIN t4
ON t3.key = t4.key
ON t3."key" = t4."key"
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
WITH t0 AS (
SELECT
t4.key AS key
FROM leaf AS t4
t4."key" AS "key"
FROM "leaf" AS t4
WHERE
TRUE
), t1 AS (
SELECT
t0.key AS key
t0."key" AS "key"
FROM t0
), t2 AS (
SELECT
t0.key AS key
t0."key" AS "key"
FROM t0
JOIN t1
ON t0.key = t1.key
ON t0."key" = t1."key"
)
SELECT
t2.key
t2."key"
FROM t2
JOIN t2 AS t3
ON t2.key = t3.key
ON t2."key" = t3."key"
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SELECT
CASE t0.continent
CASE t0."continent"
WHEN 'NA'
THEN 'North America'
WHEN 'SA'
Expand All @@ -15,8 +15,8 @@ SELECT
WHEN 'AN'
THEN 'Antarctica'
ELSE 'Unknown continent'
END AS cont,
SUM(t0.population) AS total_pop
FROM countries AS t0
END AS "cont",
SUM(t0."population") AS "total_pop"
FROM "countries" AS t0
GROUP BY
1
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
SELECT
t0.x IN (
t0."x" IN (
SELECT
t1.x
t1."x"
FROM (
SELECT
t0.x AS x
FROM t AS t0
t0."x" AS "x"
FROM "t" AS t0
WHERE
t0.x > 2
t0."x" > 2
) AS t1
) AS "Contains(x, x)"
FROM t AS t0
FROM "t" AS t0

0 comments on commit 9a845df

Please sign in to comment.