From 9a845df3e0c473c129c94c45b3683aaf68863410 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Wed, 5 Apr 2023 08:02:59 -0400 Subject: [PATCH] fix(snowflake): make sure ephemeral tables follow backend quoting rules --- ibis/backends/base/sql/alchemy/__init__.py | 24 +++++++++++------ .../base/sql/alchemy/query_builder.py | 21 +++++++++++---- ibis/backends/base/sql/alchemy/registry.py | 2 +- ibis/backends/base/sql/alchemy/translator.py | 12 ++++++--- ibis/backends/snowflake/__init__.py | 9 ++++--- .../test_many_subqueries/snowflake/out.sql | 26 +++++++++---------- .../snowflake/out.sql | 14 +++++----- .../test_group_by_has_index/snowflake/out.sql | 8 +++--- .../test_sql/test_isin_bug/snowflake/out.sql | 12 ++++----- 9 files changed, 76 insertions(+), 52 deletions(-) diff --git a/ibis/backends/base/sql/alchemy/__init__.py b/ibis/backends/base/sql/alchemy/__init__.py index e3f67f1c084a..385e3ea11048 100644 --- a/ibis/backends/base/sql/alchemy/__init__.py +++ b/ibis/backends/base/sql/alchemy/__init__.py @@ -69,7 +69,6 @@ class BaseAlchemyBackend(BaseSQLBackend): database_class = AlchemyDatabase table_class = AlchemyTable compiler = AlchemyCompiler - quote_table_names = None def _build_alchemy_url(self, url, host, port, user, password, database, driver): if url is not None: @@ -278,7 +277,7 @@ def _columns_from_schema(self, name: str, schema: sch.Schema) -> list[sa.Column] colname, to_sqla_type(dialect, dtype), nullable=dtype.nullable, - quote=self.compiler.translator_class._always_quote_columns, + quote=self.compiler.translator_class._quote_column_names, ) for colname, dtype in zip(schema.names, schema.types) ] @@ -295,7 +294,7 @@ def _table_from_schema( sa.MetaData(), *columns, prefixes=prefixes, - quote=self.quote_table_names, + quote=self.compiler.translator_class._quote_table_names, ) def drop_table( @@ -425,7 +424,7 @@ def _get_sqla_table( sa.MetaData(), schema=schema, autoload_with=self.con if autoload else None, 
quote=self.quote_table_names, + quote=self.compiler.translator_class._quote_table_names, ) nulltype_cols = frozenset( col.name for col in table.c if isinstance(col.type, sa.types.NullType) @@ -453,7 +452,7 @@ def _handle_failed_column_type_inference( colname, to_sqla_type(dialect, type), nullable=type.nullable, - quote=self.compiler.translator_class._always_quote_columns, + quote=self.compiler.translator_class._quote_column_names, ), replace_existing=True, ) @@ -620,7 +619,10 @@ def insert( def _quote(self, name: str) -> str: """Quote an identifier.""" - return self.con.dialect.identifier_preparer.quote(name) + preparer = self.con.dialect.identifier_preparer + if self.compiler.translator_class._quote_table_names: + return preparer.quote_identifier(name) + return preparer.quote(name) def _get_temp_view_definition( self, name: str, definition: sa.sql.compiler.Compiled @@ -692,7 +694,10 @@ def create_view( source = self.compile(obj) view = sav.CreateView( sa.Table( - name, sa.MetaData(), schema=database, quote=self.quote_table_names + name, + sa.MetaData(), + schema=database, + quote=self.compiler.translator_class._quote_table_names, ), source, or_replace=overwrite, @@ -708,7 +713,10 @@ def drop_view( view = sav.DropView( sa.Table( - name, sa.MetaData(), schema=database, quote=self.quote_table_names + name, + sa.MetaData(), + schema=database, + quote=self.compiler.translator_class._quote_table_names, ), if_exists=not force, ) diff --git a/ibis/backends/base/sql/alchemy/query_builder.py b/ibis/backends/base/sql/alchemy/query_builder.py index 1fe31146ec9f..b4b4c0092250 100644 --- a/ibis/backends/base/sql/alchemy/query_builder.py +++ b/ibis/backends/base/sql/alchemy/query_builder.py @@ -96,8 +96,11 @@ def _format_table(self, op): result = ref_op.sqla_table elif isinstance(ref_op, ops.UnboundTable): # use SQLAlchemy's TableClause for unbound tables - result = sa.table( - ref_op.name, *translator._schema_to_sqlalchemy_columns(ref_op.schema) + result = sa.Table( + 
ref_op.name, + sa.MetaData(), + *translator._schema_to_sqlalchemy_columns(ref_op.schema), + quote=translator._quote_table_names, ) elif isinstance(ref_op, ops.SQLQueryResult): columns = translator._schema_to_sqlalchemy_columns(ref_op.schema) @@ -109,8 +112,11 @@ def _format_table(self, op): # TODO(kszucs): avoid converting to expression child_expr = ref_op.child.to_expr() definition = child_expr.compile() - result = sa.table( - ref_op.name, *translator._schema_to_sqlalchemy_columns(ref_op.schema) + result = sa.Table( + ref_op.name, + sa.MetaData(), + *translator._schema_to_sqlalchemy_columns(ref_op.schema), + quote=translator._quote_table_names, ) backend = child_expr._find_backend() backend._create_temp_view(view=result, definition=definition) @@ -141,7 +147,12 @@ def _format_table(self, op): def _format_in_memory_table(self, op, ref_op, translator): columns = translator._schema_to_sqlalchemy_columns(ref_op.schema) if self.context.compiler.cheap_in_memory_tables: - result = sa.table(ref_op.name, *columns) + result = sa.Table( + ref_op.name, + sa.MetaData(), + *columns, + quote=translator._quote_table_names, + ) elif self.context.compiler.support_values_syntax_in_select: rows = list(ref_op.data.to_frame().itertuples(index=False)) result = sa.values(*columns, name=ref_op.name).data(rows) diff --git a/ibis/backends/base/sql/alchemy/registry.py b/ibis/backends/base/sql/alchemy/registry.py index 174b1838ddd8..00a1a73a4d57 100644 --- a/ibis/backends/base/sql/alchemy/registry.py +++ b/ibis/backends/base/sql/alchemy/registry.py @@ -117,7 +117,7 @@ def _table_column(t, op): sa_table = get_sqla_table(ctx, table) out_expr = get_col(sa_table, op) - out_expr.quote = t._always_quote_columns + out_expr.quote = t._quote_column_names # If the column does not originate from the table set in the current SELECT # context, we should format as a subquery diff --git a/ibis/backends/base/sql/alchemy/translator.py b/ibis/backends/base/sql/alchemy/translator.py index 
af013ba379d1..f761272c54de 100644 --- a/ibis/backends/base/sql/alchemy/translator.py +++ b/ibis/backends/base/sql/alchemy/translator.py @@ -54,7 +54,8 @@ def integer_to_timestamp(self, arg, tz: str | None = None): ) native_json_type = True - _always_quote_columns = None # let the dialect decide how to quote + _quote_column_names = None # let the dialect decide how to quote + _quote_table_names = None _require_order_by = ( ops.DenseRank, @@ -77,11 +78,14 @@ def dialect(self) -> sa.engine.interfaces.Dialect: def _schema_to_sqlalchemy_columns(self, schema): return [ - sa.column(name, self.get_sqla_type(dtype)) for name, dtype in schema.items() + sa.Column(name, self.get_sqla_type(dtype), quote=self._quote_column_names) + for name, dtype in schema.items() ] - def name(self, translated, name, force=True): - return translated.label(name) + def name(self, translated, name, force=False): + return translated.label( + sa.sql.quoted_name(name, quote=force or self._quote_column_names) + ) def get_sqla_type(self, data_type): return to_sqla_type(self.dialect, data_type) diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index 114235043abd..d8f498481b83 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -79,7 +79,8 @@ class SnowflakeExprTranslator(AlchemyExprTranslator): ) _require_order_by = (*AlchemyExprTranslator._require_order_by, ops.Reduction) _dialect_name = "snowflake" - _always_quote_columns = True + _quote_column_names = True + _quote_table_names = True supports_unnest_in_select = False @@ -144,7 +145,6 @@ def _make_udf(name, defn, *, quote) -> str: class Backend(BaseAlchemyBackend): name = "snowflake" compiler = SnowflakeCompiler - quote_table_names = True @property def _current_schema(self) -> str: @@ -368,14 +368,15 @@ def list_databases(self, like=None) -> list[str]: return self._filter_with_like(databases, like) def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: + df = 
op.data.to_frame() with self.begin() as con: write_pandas( conn=con.connection.connection, - df=op.data.to_frame(), + df=df, table_name=op.name, table_type="temp", auto_create_table=True, - quote_identifiers=False, + quote_identifiers=True, ) def _get_temp_view_definition( diff --git a/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/snowflake/out.sql b/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/snowflake/out.sql index 1367bba17f3e..bef472c14527 100644 --- a/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/snowflake/out.sql +++ b/ibis/backends/tests/snapshots/test_generic/test_many_subqueries/snowflake/out.sql @@ -1,32 +1,32 @@ WITH t0 AS ( SELECT - t5.street AS street, - ROW_NUMBER() OVER (ORDER BY t5.street) - 1 AS key - FROM data AS t5 + t5."street" AS "street", + ROW_NUMBER() OVER (ORDER BY t5."street") - 1 AS "key" + FROM "data" AS t5 ), t1 AS ( SELECT - t0.key AS key + t0."key" AS "key" FROM t0 ), t2 AS ( SELECT - t0.street AS street, - t0.key AS key + t0."street" AS "street", + t0."key" AS "key" FROM t0 JOIN t1 - ON t0.key = t1.key + ON t0."key" = t1."key" ), t3 AS ( SELECT - t2.street AS street, - ROW_NUMBER() OVER (ORDER BY t2.street) - 1 AS key + t2."street" AS "street", + ROW_NUMBER() OVER (ORDER BY t2."street") - 1 AS "key" FROM t2 ), t4 AS ( SELECT - t3.key AS key + t3."key" AS "key" FROM t3 ) SELECT - t3.street, - t3.key + t3."street", + t3."key" FROM t3 JOIN t4 - ON t3.key = t4.key \ No newline at end of file + ON t3."key" = t4."key" \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/snowflake/out.sql b/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/snowflake/out.sql index 8d5d47b6920b..60738db25e2d 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/snowflake/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_cte_refs_in_topo_order/snowflake/out.sql @@ -1,22 +1,22 @@ WITH t0 AS ( SELECT - 
t4.key AS key - FROM leaf AS t4 + t4."key" AS "key" + FROM "leaf" AS t4 WHERE TRUE ), t1 AS ( SELECT - t0.key AS key + t0."key" AS "key" FROM t0 ), t2 AS ( SELECT - t0.key AS key + t0."key" AS "key" FROM t0 JOIN t1 - ON t0.key = t1.key + ON t0."key" = t1."key" ) SELECT - t2.key + t2."key" FROM t2 JOIN t2 AS t3 - ON t2.key = t3.key \ No newline at end of file + ON t2."key" = t3."key" \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/snowflake/out.sql b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/snowflake/out.sql index fc16f2428d16..922316952999 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/snowflake/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/snowflake/out.sql @@ -1,5 +1,5 @@ SELECT - CASE t0.continent + CASE t0."continent" WHEN 'NA' THEN 'North America' WHEN 'SA' @@ -15,8 +15,8 @@ SELECT WHEN 'AN' THEN 'Antarctica' ELSE 'Unknown continent' - END AS cont, - SUM(t0.population) AS total_pop -FROM countries AS t0 + END AS "cont", + SUM(t0."population") AS "total_pop" +FROM "countries" AS t0 GROUP BY 1 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_isin_bug/snowflake/out.sql b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/snowflake/out.sql index d5e7138fcb23..4f7ecb5df691 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_isin_bug/snowflake/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/snowflake/out.sql @@ -1,13 +1,13 @@ SELECT - t0.x IN ( + t0."x" IN ( SELECT - t1.x + t1."x" FROM ( SELECT - t0.x AS x - FROM t AS t0 + t0."x" AS "x" + FROM "t" AS t0 WHERE - t0.x > 2 + t0."x" > 2 ) AS t1 ) AS "Contains(x, x)" -FROM t AS t0 \ No newline at end of file +FROM "t" AS t0 \ No newline at end of file