refactor(duckdb): initial cut of sqlglot DuckDB compiler
it's alive!

tests run (and fail)

chore(duckdb): naive port of clickhouse compiler

fix(duckdb): hacky fix for output shape

feat(duckdb): bitwise ops (most of them)

feat(duckdb): handle pandas dtype mapping in execute

feat(duckdb): handle decimal types

feat(duckdb): add Euler's number

test(duckdb): remove duckdb from alchemycon

feat(duckdb): get _most_ of string ops working

still some failures in re_extract

feat(duckdb): add hash

feat(duckdb): add CAST

feat(duckdb): add cot and strright

chore(duckdb): mark all the targets that still need attention (at least)

feat(duckdb): combine binary bitwise ops

chore(datestuff): some datetime ops

feat(duckdb): add levenshtein, use op.dtype instead of output_dtype

feat(duckdb): add blank list_schemas, use old current_database for now

feat(duckdb): basic interval ops

feat(duckdb): timestamp and temporal ops

feat(duckdb): use pyarrow for fetching execute results
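
For illustration, a minimal sketch (not the commit's code) of the kind of fetch path this entry refers to; DuckDB's Python API can hand results back as a pyarrow Table directly:

```python
# Sketch: fetching a DuckDB result through pyarrow (requires pyarrow installed).
import duckdb

con = duckdb.connect()
tbl = con.sql("SELECT 1 AS x, 'a' AS y").arrow()  # returns a pyarrow.Table
print(tbl.to_pandas())
```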

feat(duckdb): handle interval casts, broken for columns

feat(duckdb): shove literal handling up top

feat(duckdb): more timestamp ops

feat(duckdb): back to pandas output in execute

feat(duckdb): timezone handling in cast

feat(duckdb): ms and us epoch timestamp support
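
For illustration, a hedged sketch of the millisecond/microsecond builtins this likely targets (DuckDB's `epoch_ms` and `make_timestamp`; exact usage in the commit may differ):

```python
# Sketch: converting epoch values to timestamps in DuckDB.
import duckdb

# milliseconds since the epoch -> TIMESTAMP
print(duckdb.sql("SELECT epoch_ms(1700000000000)").fetchall())
# microseconds since the epoch -> TIMESTAMP
print(duckdb.sql("SELECT make_timestamp(1700000000000000)").fetchall())
```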

chore(duckdb): misc cleanup

feat(duckdb): initial create table

feat(duckdb): add _from_url

feat(duckdb): add read_parquet

feat(duckdb): add persistent cache

fix(duckdb): actually insert data if present in create_table

feat(duckdb): use duckdb API read_parquet

feat(duckdb): add read_csv

This, frustratingly, cannot use the Python API for `read_csv`, since that
API does not support a list of files, for some reason.
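
For illustration, a minimal sketch of the SQL-level fallback (file paths are hypothetical): DuckDB's `read_csv_auto` table function does accept a list of files:

```python
# Sketch: passing a list of CSV paths via SQL instead of the Python helper.
import duckdb

con = duckdb.connect()
files = ["data/part-0.csv", "data/part-1.csv"]  # hypothetical paths
# A Python list repr happens to be valid DuckDB list-literal syntax here.
rel = con.sql(f"SELECT * FROM read_csv_auto({files!r})")
```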

fix(duckdb): don't fully qualify the table names

chore(duckdb): cleanup

chore(duckdb): mark broken test broken

fix(duckdb): fix read_parquet so it works

feat(duckdb): add to_pyarrow, to_pyarrow_batches, sql()

feat(duckdb): null checking

feat(duckdb): translate uints

fix(duckdb): fix file outputs and torch output

fix(duckdb): add rest of integer types

fix(duckdb): ops.InValues

feat(duckdb): use sqlglot expressions (maybe a big mistake)

fix(duckdb): don't stringify strings

feat(duckdb): use sqlglot expr instead of strings for count
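
As a rough sketch of what the two sqlglot-expression entries above mean in practice: build a syntax tree and let sqlglot render dialect-correct SQL, instead of pasting strings together:

```python
# Sketch: COUNT(*) over a table as a sqlglot expression tree.
import sqlglot as sg

query = sg.select(sg.func("count", sg.exp.Star())).from_(sg.table("t"))
print(query.sql(dialect="duckdb"))  # SELECT COUNT(*) FROM t
```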

fix(duckdb): fix isin

fix(duckdb): fix some agg variance functions

fix(duckdb): for logical equals, use sqlglot not operator

fix(duckdb): struct not tuple for struct type
gforsyth authored and kszucs committed Feb 12, 2024
1 parent 6f7f190 commit ca95204
Showing 326 changed files with 10,943 additions and 6,146 deletions.
4 changes: 2 additions & 2 deletions ibis/backends/base/__init__.py
@@ -1217,9 +1217,9 @@ def _cached(self, expr: ir.Table):
         if (result := self._query_cache.get(op)) is None:
             self._query_cache.store(expr)
             result = self._query_cache[op]
-        return ir.CachedTable(result)
+        return ir.CachedTableExpr(result)

-    def _release_cached(self, expr: ir.CachedTable) -> None:
+    def _release_cached(self, expr: ir.CachedTableExpr) -> None:
         """Releases the provided cached expression.

         Parameters
20 changes: 10 additions & 10 deletions ibis/backends/base/df/timecontext.py
@@ -265,19 +265,19 @@ def adjust_context_alias(
     return adjust_context(op.arg, scope, timecontext)


-@adjust_context.register(ops.AsOfJoin)
-def adjust_context_asof_join(
-    op: ops.AsOfJoin, scope: Scope, timecontext: TimeContext
-) -> TimeContext:
-    begin, end = timecontext
+# @adjust_context.register(ops.AsOfJoin)
+# def adjust_context_asof_join(
+#     op: ops.AsOfJoin, scope: Scope, timecontext: TimeContext
+# ) -> TimeContext:
+#     begin, end = timecontext

-    if op.tolerance is not None:
-        from ibis.backends.pandas.execution import execute
+#     if op.tolerance is not None:
+#         from ibis.backends.pandas.execution import execute

-        timedelta = execute(op.tolerance)
-        return (begin - timedelta, end)
+#         timedelta = execute(op.tolerance)
+#         return (begin - timedelta, end)

-    return timecontext
+#     return timecontext


 @adjust_context.register(ops.WindowFunction)
3 changes: 1 addition & 2 deletions ibis/backends/base/sql/alchemy/registry.py
@@ -629,8 +629,7 @@ class array_filter(FunctionElement):
     ops.Literal: _literal,
     ops.SimpleCase: _simple_case,
     ops.SearchedCase: _searched_case,
-    ops.TableColumn: _table_column,
-    ops.TableArrayView: _table_array_view,
+    ops.Field: _table_column,
     ops.ExistsSubquery: _exists_subquery,
     # miscellaneous varargs
     ops.Least: varargs(sa.func.least),
3 changes: 1 addition & 2 deletions ibis/backends/base/sql/registry/main.py
@@ -358,8 +358,7 @@ def _floor(t, op):
     ops.InColumn: binary_infix.in_column,
     ops.SimpleCase: case.simple_case,
     ops.SearchedCase: case.searched_case,
-    ops.TableColumn: table_column,
-    ops.TableArrayView: table_array_view,
+    ops.Field: table_column,
     ops.DateAdd: timestamp.timestamp_op("date_add"),
     ops.DateSub: timestamp.timestamp_op("date_sub"),
     ops.DateDiff: timestamp.timestamp_op("datediff"),
328 changes: 238 additions & 90 deletions ibis/backends/base/sqlglot/__init__.py
@@ -1,97 +1,245 @@
 from __future__ import annotations

-from functools import partial
-from typing import TYPE_CHECKING, Any, Callable
+import abc
+from typing import TYPE_CHECKING, Any, ClassVar

 import sqlglot as sg
+import sqlglot.expressions as sge

-if TYPE_CHECKING:
-    import ibis.expr.datatypes as dt
-    from ibis.backends.base.sqlglot.datatypes import SqlglotType
-
-
-class AggGen:
-    __slots__ = ("aggfunc",)
-
-    def __init__(self, *, aggfunc: Callable) -> None:
-        self.aggfunc = aggfunc
-
-    def __getattr__(self, name: str) -> partial:
-        return partial(self.aggfunc, name)
-
-    def __getitem__(self, key: str) -> partial:
-        return getattr(self, key)
-
-
-def _func(name: str, *args: Any, **kwargs: Any):
-    return sg.func(name, *map(sg.exp.convert, args), **kwargs)
-
-
-class FuncGen:
-    __slots__ = ()
-
-    def __getattr__(self, name: str) -> partial:
-        return partial(_func, name)
-
-    def __getitem__(self, key: str) -> partial:
-        return getattr(self, key)
-
-    def array(self, *args):
-        return sg.exp.Array.from_arg_list(list(map(sg.exp.convert, args)))
-
-    def tuple(self, *args):
-        return sg.func("tuple", *map(sg.exp.convert, args))
-
-    def exists(self, query):
-        return sg.exp.Exists(this=query)
-
-    def concat(self, *args):
-        return sg.exp.Concat(expressions=list(map(sg.exp.convert, args)))
+import ibis
+import ibis.expr.operations as ops
+import ibis.expr.schema as sch
+from ibis.backends.base import BaseBackend
+from ibis.backends.base.sqlglot.compiler import STAR

-    def map(self, keys, values):
-        return sg.exp.Map(keys=keys, values=values)
-
-
-class ColGen:
-    __slots__ = ()
-
-    def __getattr__(self, name: str) -> sg.exp.Column:
-        return sg.column(name)
-
-    def __getitem__(self, key: str) -> sg.exp.Column:
-        return sg.column(key)
-
-
-def paren(expr):
-    """Wrap a sqlglot expression in parentheses."""
-    return sg.exp.Paren(this=expr)
-
-
-def parenthesize(op, arg):
-    import ibis.expr.operations as ops
-
-    if isinstance(op, (ops.Binary, ops.Unary)):
-        return paren(arg)
-    # function calls don't need parens
-    return arg
-
-
-def interval(value, *, unit):
-    return sg.exp.Interval(this=sg.exp.convert(value), unit=sg.exp.var(unit))
-
-
-C = ColGen()
-F = FuncGen()
-NULL = sg.exp.Null()
-FALSE = sg.exp.false()
-TRUE = sg.exp.true()
-STAR = sg.exp.Star()
-
-
-def make_cast(
-    converter: SqlglotType,
-) -> Callable[[sg.exp.Expression, dt.DataType], sg.exp.Cast]:
-    def cast(arg: sg.exp.Expression, to: dt.DataType) -> sg.exp.Cast:
-        return sg.cast(arg, to=converter.from_ibis(to))
-
-    return cast
+if TYPE_CHECKING:
+    from collections.abc import Iterator
+
+    import ibis.expr.datatypes as dt
+    import ibis.expr.types as ir
+    from ibis.backends.base.sqlglot.compiler import SQLGlotCompiler
+    from ibis.common.typing import SupportsSchema
+
+
+class SQLGlotBackend(BaseBackend):
+    compiler: ClassVar[SQLGlotCompiler]
+    name: ClassVar[str]
+
+    @classmethod
+    def has_operation(cls, operation: type[ops.Value]) -> bool:
+        # singledispatchmethod overrides `__get__` so we can't directly access
+        # the dispatcher
+        dispatcher = cls.compiler.visit_node.register.__self__.dispatcher
+        return dispatcher.dispatch(operation) is not dispatcher.dispatch(object)
+
+    def _transform(
+        self, sql: sge.Expression, table_expr: ir.TableExpr
+    ) -> sge.Expression:
+        return sql
+
+    def table(
+        self, name: str, schema: str | None = None, database: str | None = None
+    ) -> ir.Table:
+        """Construct a table expression.
+
+        Parameters
+        ----------
+        name
+            Table name
+        schema
+            Schema name
+        database
+            Database name
+
+        Returns
+        -------
+        Table
+            Table expression
+        """
+        table_schema = self.get_schema(name, schema=schema, database=database)
+        return ops.DatabaseTable(
+            name,
+            schema=table_schema,
+            source=self,
+            namespace=ops.Namespace(database=database, schema=schema),
+        ).to_expr()
+
+    def _to_sqlglot(
+        self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any
+    ):
+        """Compile an Ibis expression to a sqlglot object."""
+        table_expr = expr.as_table()
+
+        if limit == "default":
+            limit = ibis.options.sql.default_limit
+        if limit is not None:
+            table_expr = table_expr.limit(limit)
+
+        if params is None:
+            params = {}
+
+        sql = self.compiler.translate(table_expr.op(), params=params)
+        assert not isinstance(sql, sge.Subquery)
+
+        if isinstance(sql, sge.Table):
+            sql = sg.select(STAR).from_(sql)
+
+        assert not isinstance(sql, sge.Subquery)
+        return [self._transform(sql, table_expr)]
+
+    def compile(
+        self, expr: ir.Expr, limit: str | None = None, params=None, **kwargs: Any
+    ):
+        """Compile an Ibis expression to a SQL string in the backend's dialect."""
+        queries = self._to_sqlglot(expr, limit=limit, params=params, **kwargs)
+
+        return ";\n\n".join(
+            query.sql(dialect=self.name, pretty=True) for query in queries
+        )
+
+    def _to_sql(self, expr: ir.Expr, **kwargs) -> str:
+        return self.compile(expr, **kwargs)
+
+    def _log(self, sql: str) -> None:
+        """Log `sql`.
+
+        This method can be implemented by subclasses. Logging occurs when
+        `ibis.options.verbose` is `True`.
+        """
+        from ibis import util
+
+        util.log(sql)
+
+    def sql(
+        self,
+        query: str,
+        schema: SupportsSchema | None = None,
+        dialect: str | None = None,
+    ) -> ir.Table:
+        query = self._transpile_sql(query, dialect=dialect)
+        if schema is None:
+            schema = self._get_schema_using_query(query)
+        return ops.SQLQueryResult(query, ibis.schema(schema), self).to_expr()
+
+    @abc.abstractmethod
+    def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]:
+        """Return the metadata of a SQL query."""
+
+    def _get_schema_using_query(self, query: str) -> sch.Schema:
+        """Return an ibis Schema from a backend-specific SQL string."""
+        return sch.Schema.from_tuples(self._metadata(query))
+
+    def create_view(
+        self,
+        name: str,
+        obj: ir.Table,
+        *,
+        database: str | None = None,
+        schema: str | None = None,
+        overwrite: bool = False,
+    ) -> ir.Table:
+        src = sge.Create(
+            this=sg.table(
+                name, db=schema, catalog=database, quoted=self.compiler.quoted
+            ),
+            kind="VIEW",
+            replace=overwrite,
+            expression=self.compile(obj),
+        )
+        self._register_in_memory_tables(obj)
+        with self._safe_raw_sql(src):
+            pass
+        return self.table(name, database=database)
+
+    def _register_in_memory_tables(self, expr: ir.Expr) -> None:
+        for memtable in expr.op().find(ops.InMemoryTable):
+            self._register_in_memory_table(memtable)
+
+    def drop_view(
+        self,
+        name: str,
+        *,
+        database: str | None = None,
+        schema: str | None = None,
+        force: bool = False,
+    ) -> None:
+        src = sge.Drop(
+            this=sg.table(
+                name, db=schema, catalog=database, quoted=self.compiler.quoted
+            ),
+            kind="VIEW",
+            exists=force,
+        )
+        with self._safe_raw_sql(src):
+            pass
+
+    def _get_temp_view_definition(self, name: str, definition: str) -> sge.Create:
+        return sge.Create(
+            this=sg.to_identifier(name, quoted=self.compiler.quoted),
+            kind="VIEW",
+            expression=definition,
+            replace=True,
+            properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
+        )
+
+    def _create_temp_view(self, table_name, source):
+        if table_name not in self._temp_views and table_name in self.list_tables():
+            raise ValueError(
+                f"{table_name} already exists as a non-temporary table or view"
+            )
+
+        with self._safe_raw_sql(self._get_temp_view_definition(table_name, source)):
+            pass
+
+        self._temp_views.add(table_name)
+        self._register_temp_view_cleanup(table_name)
+
+    def _register_temp_view_cleanup(self, name: str) -> None:
+        """Register a clean up function for a temporary view.
+
+        No-op by default.
+
+        Parameters
+        ----------
+        name
+            The temporary view to register for clean up.
+        """
+
+    def _load_into_cache(self, name, expr):
+        self.create_table(name, expr, schema=expr.schema(), temp=True)
+
+    def _clean_up_cached_table(self, op):
+        self.drop_table(op.name)
+
+    def execute(
+        self, expr: ir.Expr, limit: str | None = "default", **kwargs: Any
+    ) -> Any:
+        """Execute an expression."""
+        self._run_pre_execute_hooks(expr)
+        table = expr.as_table()
+        sql = self.compile(table, limit=limit, **kwargs)
+
+        schema = table.schema()
+        self._log(sql)
+
+        with self._safe_raw_sql(sql) as cur:
+            result = self.fetch_from_cursor(cur, schema)
+        return expr.__pandas_result__(result)
+
+    def drop_table(
+        self,
+        name: str,
+        database: str | None = None,
+        schema: str | None = None,
+        force: bool = False,
+    ) -> None:
+        drop_stmt = sg.exp.Drop(
+            kind="TABLE",
+            this=sg.table(
+                name, db=schema, catalog=database, quoted=self.compiler.quoted
+            ),
+            exists=force,
+        )
+        with self._safe_raw_sql(drop_stmt):
+            pass
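
A note on `has_operation` in the diff above: `functools.singledispatchmethod` overrides `__get__`, so the underlying dispatch table has to be reached through the bound `register` method. A self-contained sketch of that trick (the `Compiler`/`visit_node` names here are stand-ins, not ibis code):

```python
# Sketch: reaching a singledispatchmethod's dispatcher to ask whether a
# type has a registered implementation more specific than `object`.
from functools import singledispatchmethod


class Compiler:
    @singledispatchmethod
    def visit_node(self, op, **kwargs):
        raise NotImplementedError(type(op).__name__)

    @visit_node.register(int)
    def _visit_int(self, op, **kwargs):
        return "int op"


# `register` is bound to the singledispatchmethod descriptor, whose
# `dispatcher` attribute is the underlying functools.singledispatch registry.
dispatcher = Compiler.visit_node.register.__self__.dispatcher
print(dispatcher.dispatch(int) is not dispatcher.dispatch(object))  # True
print(dispatcher.dispatch(str) is not dispatcher.dispatch(object))  # False
```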
