Skip to content

Commit

Permalink
feat(duckdb/postgres/mysql/pyspark): implement .sql on tables for mixing sql and expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Apr 4, 2022
1 parent a366d9c commit 00e8087
Show file tree
Hide file tree
Showing 18 changed files with 441 additions and 86 deletions.
6 changes: 6 additions & 0 deletions ibis/backends/base/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,9 @@ def has_operation(cls, operation: type[ops.ValueOp]) -> bool:
translator = cls.compiler.translator_class
op_classes = translator._registry.keys() | translator._rewrites.keys()
return operation in op_classes

def _create_temp_view(self, view, definition):
raise NotImplementedError(
f"The {self.name} backend does not implement temporary view "
"creation"
)
31 changes: 31 additions & 0 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def do_connect(self, con: sa.engine.Engine) -> None:
self._inspector = sa.inspect(self.con)
self.meta = sa.MetaData(bind=self.con)
self._schemas: dict[str, sch.Schema] = {}
self._temp_views: set[str] = set()

@property
def version(self):
Expand Down Expand Up @@ -478,3 +479,33 @@ def insert(
"is not a pandas DataFrame or is not a ibis TableExpr."
f"The given obj is of type {type(obj).__name__} ."
)

def _get_temp_view_definition(
self,
name: str,
definition: sa.sql.compiler.Compiled,
) -> str:
raise NotImplementedError(
f"The {self.name} backend does not implement temporary view "
"creation"
)

def _register_temp_view_cleanup(self, name: str, raw_name: str) -> None:
pass

def _create_temp_view(
    self,
    view: sa.Table,
    definition: sa.sql.Selectable,
) -> None:
    """Create a temporary view named after ``view`` from ``definition``.

    Parameters
    ----------
    view
        Table object whose ``name`` is used as the view name.
    definition
        Selectable that is compiled to produce the view's defining query.

    Raises
    ------
    ValueError
        If the name is reported by ``list_tables()`` but was not recorded
        in ``self._temp_views`` by an earlier call.
    """
    raw_name = view.name
    # Names already recorded in ``_temp_views`` may be redefined; any other
    # existing table or view with this name is left untouched.
    if raw_name not in self._temp_views and raw_name in self.list_tables():
        raise ValueError(f"{raw_name} already exists as a table or view")

    name = self.con.dialect.identifier_preparer.quote_identifier(raw_name)
    compiled = definition.compile()
    defn = self._get_temp_view_definition(name, definition=compiled)
    # The compiled statement's parameter values are bound onto the text
    # clause here, so no separate parameter set is handed to ``execute``.
    # (Passing the selectable itself there is not a valid parameter mapping.)
    query = sa.text(defn).bindparams(**compiled.params)
    self.con.execute(query)
    self._temp_views.add(raw_name)
    self._register_temp_view_cleanup(name, raw_name)
20 changes: 19 additions & 1 deletion ibis/backends/base/sql/alchemy/query_builder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import functools

import sqlalchemy as sa
import sqlalchemy.sql as sql

import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.base.sql.compiler import (
Compiler,
Expand All @@ -18,6 +21,10 @@
from .translator import AlchemyContext, AlchemyExprTranslator


def _schema_to_sqlalchemy_columns(schema: sch.Schema) -> list[sa.Column]:
return [sa.column(n, to_sqla_type(t)) for n, t in schema.items()]


class _AlchemyTableSetFormatter(TableSetFormatter):
def get_result(self):
# Got to unravel the join stack; the nesting order could be
Expand Down Expand Up @@ -85,8 +92,19 @@ def _format_table(self, expr):
schema = ref_op.schema
result = sa.table(
ref_op.name,
*(sa.column(n, to_sqla_type(t)) for n, t in schema.items()),
*_schema_to_sqlalchemy_columns(schema),
)
elif isinstance(ref_op, ops.SQLStringView):
columns = _schema_to_sqlalchemy_columns(ref_op.schema)
result = sa.text(ref_op.query).columns(*columns).cte(ref_op.name)
elif isinstance(ref_op, ops.View):
definition = ref_op.child.compile()
result = sa.table(
ref_op.name,
*_schema_to_sqlalchemy_columns(ref_op.schema),
)
backend = ref_op.child._find_backend()
backend._create_temp_view(view=result, definition=definition)
else:
# A subquery
if ctx.is_extracted(ref_expr):
Expand Down
12 changes: 8 additions & 4 deletions ibis/backends/base/sql/compiler/extract_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,21 @@ def visit_Difference(self, expr):
self.visit(op.right)
self.observe(expr)

def visit_MaterializedJoin(self, expr):
    """Visit the operation's ``join``, then observe *expr* itself."""
    join = expr.op().join
    self.visit(join)
    self.observe(expr)

def visit_Selection(self, expr):
    """Visit the operation's ``table``, then observe *expr* itself."""
    source_table = expr.op().table
    self.visit(source_table)
    self.observe(expr)

def visit_SQLQueryResult(self, expr):
    """Observe *expr* without visiting anything beneath it."""
    observe = self.observe
    observe(expr)

def visit_View(self, expr):
    """Visit the operation's ``child``, then observe *expr* itself."""
    child = expr.op().child
    self.visit(child)
    self.observe(expr)

def visit_SQLStringView(self, expr):
    """Visit the operation's ``child``, then observe *expr* itself."""
    child = expr.op().child
    self.visit(child)
    self.observe(expr)

def visit_TableColumn(self, expr):
table = expr.op().table
if not self.seen(table):
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/compiler/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_compiled_expr(self, expr):
pass

op = expr.op()
if isinstance(op, ops.SQLQueryResult):
if isinstance(op, (ops.SQLQueryResult, ops.SQLStringView)):
result = op.query
else:
result = self._compile_subquery(expr)
Expand Down
24 changes: 20 additions & 4 deletions ibis/backends/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,11 @@ def pytest_runtest_call(item):
def backend(request, data_directory):
"""Return an instance of BackendTest."""
cls = _get_backend_conf(request.param)
return cls(data_directory)
result = cls(data_directory)
try:
yield result
finally:
result.cleanup()


@pytest.fixture(scope='session')
Expand All @@ -286,7 +290,11 @@ def ddl_backend(request, data_directory):
(sqlite, postgres, mysql, datafusion, clickhouse, pyspark, impala)
"""
cls = _get_backend_conf(request.param)
return cls(data_directory)
result = cls(data_directory)
try:
yield result
finally:
result.cleanup()


@pytest.fixture(scope='session')
Expand Down Expand Up @@ -315,7 +323,11 @@ def alchemy_backend(request, data_directory):
)
else:
cls = _get_backend_conf(request.param)
return cls(data_directory)
result = cls(data_directory)
try:
yield result
finally:
result.cleanup()


@pytest.fixture(scope='session')
Expand All @@ -335,7 +347,11 @@ def udf_backend(request, data_directory):
Runs the UDF-supporting backends
"""
cls = _get_backend_conf(request.param)
return cls(data_directory)
result = cls(data_directory)
try:
yield result
finally:
result.cleanup()


@pytest.fixture(scope='session')
Expand Down
15 changes: 14 additions & 1 deletion ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend

from .compiler import DuckDBSQLCompiler
from .datatypes import parse_type


class Backend(BaseAlchemyBackend):
Expand Down Expand Up @@ -69,4 +70,16 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
"""Return an ibis Schema from a SQL string."""
with self.con.connect() as con:
rel = con.connection.c.query(query)
return sch.infer(rel)
return sch.Schema.from_dict(
{
name: parse_type(type)
for name, type in zip(rel.columns, rel.types)
}
)

def _get_temp_view_definition(
self,
name: str,
definition: sa.sql.compiler.Compiled,
) -> str:
return f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}"
17 changes: 17 additions & 0 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import atexit
import contextlib
import warnings
from typing import Literal
Expand Down Expand Up @@ -133,6 +134,22 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
]
return sch.Schema.from_tuples(fields)

def _get_temp_view_definition(
self,
name: str,
definition: sa.sql.compiler.Compiled,
) -> str:
return f"CREATE OR REPLACE VIEW {name} AS {definition}"

def _register_temp_view_cleanup(self, name: str, raw_name: str) -> None:
query = f"DROP VIEW IF EXISTS {name}"

def drop(self, raw_name: str, query: str):
self.con.execute(query)
self._temp_views.discard(raw_name)

atexit.register(drop, self, raw_name, query)


# TODO(kszucs): unsigned integers

Expand Down
77 changes: 7 additions & 70 deletions ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@

import sqlalchemy as sa

import ibis.backends.duckdb.datatypes as ddb
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis import util
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend

from .compiler import PostgreSQLCompiler
from .datatypes import _get_type
from .udf import udf


Expand Down Expand Up @@ -205,71 +204,9 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
tuples = [(col, _get_type(typestr)) for col, typestr in type_info]
return sch.Schema.from_tuples(tuples)


def _get_type(typestr: str) -> dt.DataType:
    """Translate a type string into an ibis datatype.

    An exact match in ``_type_mapping`` is returned directly; any other
    string is handed to ``ddb.parse_type``.
    """
    try:
        known = _type_mapping[typestr]
    except KeyError:
        return ddb.parse_type(typestr)
    return known


# Exact type strings mapped to ibis datatypes.  ``_get_type`` consults this
# table first and only falls back to ``ddb.parse_type`` for strings that are
# not listed here.
_type_mapping = {
    "boolean": dt.bool,
    "boolean[]": dt.Array(dt.bool),
    "bytea": dt.binary,
    "bytea[]": dt.Array(dt.binary),
    "character(1)": dt.string,
    "character(1)[]": dt.Array(dt.string),
    "bigint": dt.int64,
    "bigint[]": dt.Array(dt.int64),
    "smallint": dt.int16,
    "smallint[]": dt.Array(dt.int16),
    "integer": dt.int32,
    "integer[]": dt.Array(dt.int32),
    "text": dt.string,
    "text[]": dt.Array(dt.string),
    "json": dt.json,
    "json[]": dt.Array(dt.json),
    "point": dt.point,
    "point[]": dt.Array(dt.point),
    "polygon": dt.polygon,
    "polygon[]": dt.Array(dt.polygon),
    "line": dt.linestring,
    "line[]": dt.Array(dt.linestring),
    "real": dt.float32,
    "real[]": dt.Array(dt.float32),
    "double precision": dt.float64,
    "double precision[]": dt.Array(dt.float64),
    "macaddr8": dt.macaddr,
    "macaddr8[]": dt.Array(dt.macaddr),
    "macaddr": dt.macaddr,
    "macaddr[]": dt.Array(dt.macaddr),
    "inet": dt.inet,
    "inet[]": dt.Array(dt.inet),
    "character": dt.string,
    "character[]": dt.Array(dt.string),
    "character varying": dt.string,
    "character varying[]": dt.Array(dt.string),
    "date": dt.date,
    "date[]": dt.Array(dt.date),
    "time without time zone": dt.time,
    "time without time zone[]": dt.Array(dt.time),
    "timestamp without time zone": dt.timestamp,
    "timestamp without time zone[]": dt.Array(dt.timestamp),
    "timestamp with time zone": dt.Timestamp("UTC"),
    "timestamp with time zone[]": dt.Array(dt.Timestamp("UTC")),
    "interval": dt.interval,
    "interval[]": dt.Array(dt.interval),
    # NB: this isn't correct, but we try not to fail.  The value is a
    # datatype object like every other entry, so ``_get_type`` always
    # returns a ``dt.DataType`` for mapped strings.
    "time with time zone": dt.time,
    "numeric": dt.decimal,
    "numeric[]": dt.Array(dt.decimal),
    "uuid": dt.uuid,
    "uuid[]": dt.Array(dt.uuid),
    "jsonb": dt.jsonb,
    "jsonb[]": dt.Array(dt.jsonb),
    "geometry": dt.geometry,
    "geometry[]": dt.Array(dt.geometry),
    "geography": dt.geography,
    "geography[]": dt.Array(dt.geography),
}
def _get_temp_view_definition(
self,
name: str,
definition: sa.sql.compiler.Compiled,
) -> str:
return f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}"
Loading

0 comments on commit 00e8087

Please sign in to comment.