From 75af658df785529e65bf2fdeca472d3a71d1ff09 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Fri, 26 Aug 2022 14:24:02 -0400
Subject: [PATCH] fix(compiler): use `repr` for SQL string `VALUES` data

---
 .../base/sql/compiler/query_builder.py | 19 +++++++++-----
 ibis/backends/pyspark/__init__.py      | 26 +++++++++++++++----
 2 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/ibis/backends/base/sql/compiler/query_builder.py b/ibis/backends/base/sql/compiler/query_builder.py
index 29b04d8be661..dc26656a1e11 100644
--- a/ibis/backends/base/sql/compiler/query_builder.py
+++ b/ibis/backends/base/sql/compiler/query_builder.py
@@ -100,6 +100,18 @@ def _get_join_type(self, op):
     def _quote_identifier(self, name):
         return quote_identifier(name)
 
+    def _format_in_memory_table(self, op):
+        names = op.schema.names
+        raw_rows = (
+            ", ".join(
+                f"{val!r} AS {self._quote_identifier(name)}"
+                for val, name in zip(row, names)
+            )
+            for row in op.data.itertuples(index=False)
+        )
+        rows = ", ".join(f"({raw_row})" for raw_row in raw_rows)
+        return f"(VALUES {rows})"
+
     def _format_table(self, expr):
         # TODO: This could probably go in a class and be significantly nicer
         ctx = self.context
@@ -117,12 +129,7 @@ def _format_table(self, expr):
             result = self._quote_identifier(name)
             is_subquery = False
         elif isinstance(ref_op, ops.InMemoryTable):
-            names = ref_op.schema.names
-            rows = ", ".join(
-                f"({', '.join(map('{} AS {}'.format, col, names))})"
-                for col in ref_op.data.itertuples(index=False)
-            )
-            result = f"(VALUES {rows})"
+            result = self._format_in_memory_table(ref_op)
             is_subquery = True
         else:
             # A subquery
diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py
index 91fb0b9b9367..8cdc593e51f4 100644
--- a/ibis/backends/pyspark/__init__.py
+++ b/ibis/backends/pyspark/__init__.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Mapping
+from typing import Any, Mapping
 
 import pandas as pd
 import pyspark
@@ -13,10 +13,13 @@
 import ibis.common.exceptions as com
 import ibis.config
+import ibis.expr.operations as ops
 import ibis.expr.schema as sch
 import ibis.expr.types as types
+import ibis.expr.types as ir
 import ibis.util as util
 from ibis.backends.base.sql import BaseSQLBackend
+from ibis.backends.base.sql.compiler import Compiler, TableSetFormatter
 from ibis.backends.base.sql.ddl import (
     CreateDatabase,
     DropTable,
@@ -33,10 +36,6 @@
 from ibis.expr.scope import Scope
 from ibis.expr.timecontext import canonicalize_context, localize_context
 
-if TYPE_CHECKING:
-    import ibis.expr.operations as ops
-    import ibis.expr.types as ir
-
 _read_csv_defaults = {
     'header': True,
     'multiLine': True,
@@ -87,7 +86,24 @@ def __exit__(self, exc_type, exc_value, traceback):
         """No-op for compatibility."""
 
 
+class PySparkTableSetFormatter(TableSetFormatter):
+    def _format_in_memory_table(self, op):
+        names = op.schema.names
+        rows = ", ".join(
+            f"({', '.join(map(repr, row))})"
+            for row in op.data.itertuples(index=False)
+        )
+        signature = ", ".join(map(self._quote_identifier, names))
+        name = self._quote_identifier(op.name or "_")
+        return f"(VALUES {rows} AS {name} ({signature}))"
+
+
+class PySparkCompiler(Compiler):
+    table_set_formatter_class = PySparkTableSetFormatter
+
+
 class Backend(BaseSQLBackend):
+    compiler = PySparkCompiler
     name = 'pyspark'
     table_class = PySparkDatabaseTable
     table_expr_class = PySparkTable