Skip to content

Commit

Permalink
feat(polars): add limited support for table dot sql (#8528)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored Mar 20, 2024
1 parent 1a8eec8 commit b2a4fbb
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 15 deletions.
2 changes: 1 addition & 1 deletion ci/schema/exasol.sql
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ CREATE OR REPLACE TABLE EXASOL."awards_players"
"yearID" BIGINT,
"lgID" VARCHAR(256),
"tie" VARCHAR(256),
"notest" VARCHAR(256)
"notes" VARCHAR(256)
);

CREATE OR REPLACE TABLE EXASOL."functional_alltypes"
Expand Down
12 changes: 11 additions & 1 deletion ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,17 @@ def compile(
return translate(node, ctx=self._context)

def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema:
raise NotImplementedError("table.sql() not yet supported in polars")
import sqlglot as sg

cte = sg.parse_one(str(ibis.to_sql(table, dialect="postgres")), read="postgres")
parsed = sg.parse_one(query, read=self.dialect)
parsed.args["with"] = cte.args.pop("with", [])
parsed = parsed.with_(
sg.to_identifier(name, quoted=True), as_=cte, dialect=self.dialect
)

sql = parsed.sql(self.dialect)
return self._get_schema_using_query(sql)

def _get_schema_using_query(self, query: str) -> sch.Schema:
    """Return the ibis schema of the result produced by ``query``.

    The query string is executed through ``self._context`` and the
    ``schema`` attribute of what comes back is converted with
    ``PolarsSchema.to_ibis``.
    """
    executed = self._context.execute(query)
    return PolarsSchema.to_ibis(executed.schema)
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,8 +1225,6 @@ def execute_arg_min(op, **kw):

@translate.register(ops.SQLStringView)
def execute_sql_string_view(op, *, ctx: pl.SQLContext, **kw):
child = translate(op.child, ctx=ctx, **kw)
ctx.register(op.name, child)
return ctx.execute(op.query)


Expand Down
22 changes: 15 additions & 7 deletions ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import parse_qs, urlparse

import psycopg2
import sqlglot as sg
import sqlglot.expressions as sge
from psycopg2 import extras

import ibis
import ibis.common.exceptions as com
Expand Down Expand Up @@ -102,6 +100,8 @@ def _from_url(self, url: str, **kwargs):
return self.connect(**kwargs)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
from psycopg2.extras import execute_batch

schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
raise exc.IbisTypeError(
Expand Down Expand Up @@ -148,9 +148,10 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
specs = ", ".join(repeat("%s", len(columns)))
table = sg.table(name, quoted=quoted)
sql = f"INSERT INTO {table.sql(self.dialect)} ({cols}) VALUES ({specs})"

with self.begin() as cur:
cur.execute(create_stmt_sql)
extras.execute_batch(cur, sql, data, 128)
execute_batch(cur, sql, data, 128)

@contextlib.contextmanager
def begin(self):
Expand Down Expand Up @@ -258,7 +259,11 @@ def do_connect(
month : int32
"""
import psycopg2
import psycopg2.extras

psycopg2.extras.register_default_json(loads=lambda x: x)

self.con = psycopg2.connect(
host=host,
port=port,
Expand Down Expand Up @@ -700,6 +705,9 @@ def _safe_raw_sql(self, *args, **kwargs):
yield result

def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
import psycopg2
import psycopg2.extras

with contextlib.suppress(AttributeError):
query = query.sql(dialect=self.dialect)

Expand All @@ -709,11 +717,11 @@ def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
try:
# try to load hstore, uuid and ipaddress extensions
with contextlib.suppress(psycopg2.ProgrammingError):
extras.register_hstore(cursor)
psycopg2.extras.register_hstore(cursor)
with contextlib.suppress(psycopg2.ProgrammingError):
extras.register_uuid(conn_or_curs=cursor)
psycopg2.extras.register_uuid(conn_or_curs=cursor)
with contextlib.suppress(psycopg2.ProgrammingError):
extras.register_ipaddress(cursor)
psycopg2.extras.register_ipaddress(cursor)
except Exception:
cursor.close()
raise
Expand All @@ -740,4 +748,4 @@ def _to_sqlglot(

if conversions:
table_expr = table_expr.mutate(**conversions)
return super()._to_sqlglot(table_expr, limit=limit, params=params)
return super()._to_sqlglot(table_expr, limit=limit, params=params, **kwargs)
28 changes: 25 additions & 3 deletions ibis/backends/tests/test_dot_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ def test_con_dot_sql(backend, con, schema):
backend.assert_series_equal(result.astype(expected.dtype), expected)


@pytest.mark.notyet(["polars"], raises=PolarsComputeError)
@pytest.mark.notyet(
["polars"],
raises=PolarsComputeError,
reason="polars doesn't support quoted identifiers referencing CTEs",
)
@pytest.mark.notyet(
["bigquery"], raises=GoogleBadRequest, reason="requires a qualified name"
)
Expand Down Expand Up @@ -119,7 +123,11 @@ def test_table_dot_sql(backend):
assert pytest.approx(result) == expected


@pytest.mark.notyet(["polars"], raises=PolarsComputeError)
@pytest.mark.notyet(
["polars"],
raises=PolarsComputeError,
reason="polars doesn't support quoted identifiers referencing CTEs",
)
@pytest.mark.notyet(
["bigquery"], raises=GoogleBadRequest, reason="requires a qualified name"
)
Expand Down Expand Up @@ -315,7 +323,11 @@ def mem_t(con):


@dot_sql_never
@pytest.mark.notyet(["polars"], raises=NotImplementedError)
@pytest.mark.notyet(
["polars"],
raises=PolarsComputeError,
reason="polars doesn't support selecting from quoted identifiers referencing CTEs",
)
@pytest.mark.notyet(
["druid"],
raises=KeyError,
Expand All @@ -337,3 +349,13 @@ def test_cte(alltypes, df):
)

tm.assert_frame_equal(result, expected)


@dot_sql_never
def test_bare_minimum(alltypes, df):
    """Test that a backend that supports dot sql can do the most basic thing.

    The most basic thing here is a ``COUNT(*)`` issued through
    ``alltypes.sql`` using the duckdb dialect; the single value returned
    must equal ``len(df)``.
    """
    query = 'SELECT COUNT(*) AS "n" FROM "functional_alltypes"'
    counted = alltypes.sql(query, dialect="duckdb")

    # The result has one row and one column, so read it positionally.
    n_rows = counted.to_pandas().iat[0, 0]
    assert n_rows == len(df)
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_scalar_to_pyarrow_scalar(limit, awards_players):
assert isinstance(scalar, pa.Scalar)


@pytest.mark.notimpl(["druid", "exasol"])
@pytest.mark.notimpl(["druid"])
def test_table_to_pyarrow_table_schema(awards_players):
table = awards_players.to_pyarrow()
assert isinstance(table, pa.Table)
Expand Down

0 comments on commit b2a4fbb

Please sign in to comment.