-
Notifications
You must be signed in to change notification settings - Fork 609
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(duckdb): initial cut of sqlglot DuckDB compiler
it's alive! tests run (and fail) chore(duckdb): naive port of clickhouse compiler fix(duckdb): hacky fix for output shape feat(duckdb): bitwise ops (most of them) feat(duckdb): handle pandas dtype mapping in execute feat(duckdb): handle decimal types feat(duckdb): add euler's number test(duckdb): remove duckdb from alchemycon feat(duckdb): get _most_ of string ops working still some failures in re_extract feat(duckdb): add hash feat(duckdb): add CAST feat(duckdb): add cot and strright chore(duckdb): mark all the targets that still need attention (at least) feat(duckdb): combine binary bitwise ops chore(datestuff): some datetime ops feat(duckdb): add levenshtein, use op.dtype instead of output_dtype feat(duckdb): add blank list_schemas, use old current_database for now feat(duckdb): basic interval ops feat(duckdb): timestamp and temporal ops feat(duckdb): use pyarrow for fetching execute results feat(duckdb): handle interval casts, broken for columns feat(duckdb): shove literal handling up top feat(duckdb): more timestamp ops feat(duckdb): back to pandas output in execute feat(duckdb): timezone handling in cast feat(duckdb): ms and us epoch timestamp support chore(duckdb): misc cleanup feat(duckdb): initial create table feat(duckdb): add _from_url feat(duckdb): add read_parquet feat(duckdb): add persistent cache fix(duckdb): actually insert data if present in create_table feat(duckdb): use duckdb API read_parquet feat(duckdb): add read_csv This, frustratingly, cannot use the Python API for `read_csv` since that does not support a list of files, for some reason.
fix(duckdb): don't fully qualify the table names chore(duckdb): cleanup chore(duckdb): mark broken test broken fix(duckdb): fix read_parquet so it works feat(duckdb): add to_pyarrow, to_pyarrow_batches, sql() feat(duckdb): null checking feat(duckdb): translate uints fix(duckdb): fix file outputs and torch output fix(duckdb): add rest of integer types fix(duckdb): ops.InValues feat(duckdb): use sqlglot expressions (maybe a big mistake) fix(duckdb): don't stringify strings feat(duckdb): use sqlglot expr instead of strings for count fix(duckdb): fix isin fix(duckdb): fix some agg variance functions fix(duckdb): for logical equals, use sqlglot not operator fix(duckdb): struct not tuple for struct type
- Loading branch information
Showing
326 changed files
with
10,943 additions
and
6,146 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,97 +1,245 @@ | ||
from __future__ import annotations | ||
|
||
from functools import partial | ||
from typing import TYPE_CHECKING, Any, Callable | ||
import abc | ||
from typing import TYPE_CHECKING, Any, ClassVar | ||
|
||
import sqlglot as sg | ||
import sqlglot.expressions as sge | ||
|
||
if TYPE_CHECKING: | ||
import ibis.expr.datatypes as dt | ||
from ibis.backends.base.sqlglot.datatypes import SqlglotType | ||
|
||
|
||
class AggGen:
    """Attribute-style factory for aggregate-function builders.

    Looking up a name on an instance, either as ``gen.sum`` or ``gen["sum"]``,
    yields a ``functools.partial`` of the wrapped ``aggfunc`` with that name
    bound as its first positional argument.
    """

    __slots__ = ("aggfunc",)

    def __init__(self, *, aggfunc: Callable) -> None:
        # Keyword-only: the callable that will receive the looked-up name.
        self.aggfunc = aggfunc

    def __getattr__(self, name: str) -> partial:
        # Python only calls this for names that normal lookup did not find,
        # so the ``aggfunc`` slot itself is never routed through here.
        target = self.aggfunc
        return partial(target, name)

    def __getitem__(self, key: str) -> partial:
        # Subscription defers to regular attribute lookup (and therefore to
        # ``__getattr__`` for any name that is not a real attribute).
        return getattr(self, key)
|
||
|
||
def _func(name: str, *args: Any, **kwargs: Any):
    """Build a sqlglot function call named ``name``.

    Each positional argument is passed through ``sg.exp.convert`` before the
    call is built; keyword arguments are forwarded to ``sg.func`` untouched.
    """
    converted = [sg.exp.convert(arg) for arg in args]
    return sg.func(name, *converted, **kwargs)
|
||
|
||
class FuncGen: | ||
__slots__ = () | ||
|
||
def __getattr__(self, name: str) -> partial: | ||
return partial(_func, name) | ||
|
||
def __getitem__(self, key: str) -> partial: | ||
return getattr(self, key) | ||
|
||
def array(self, *args): | ||
return sg.exp.Array.from_arg_list(list(map(sg.exp.convert, args))) | ||
|
||
def tuple(self, *args): | ||
return sg.func("tuple", *map(sg.exp.convert, args)) | ||
|
||
def exists(self, query): | ||
return sg.exp.Exists(this=query) | ||
|
||
def concat(self, *args): | ||
return sg.exp.Concat(expressions=list(map(sg.exp.convert, args))) | ||
import ibis | ||
import ibis.expr.operations as ops | ||
import ibis.expr.schema as sch | ||
from ibis.backends.base import BaseBackend | ||
from ibis.backends.base.sqlglot.compiler import STAR | ||
|
||
def map(self, keys, values): | ||
return sg.exp.Map(keys=keys, values=values) | ||
|
||
|
||
class ColGen:
    """Column factory.

    ``gen.name`` and ``gen["name"]`` both return ``sg.column("name")``.
    """

    __slots__ = ()

    def __getitem__(self, key: str) -> sg.exp.Column:
        # Subscript form.
        return sg.column(key)

    def __getattr__(self, name: str) -> sg.exp.Column:
        # Attribute form; only reached for names normal lookup cannot resolve.
        return sg.column(name)
|
||
|
||
def paren(expr):
    """Return ``expr`` wrapped in a sqlglot ``Paren`` node."""
    wrapped = sg.exp.Paren(this=expr)
    return wrapped
|
||
|
||
def parenthesize(op, arg):
    """Parenthesize ``arg`` when ``op`` is a unary or binary operation.

    Parameters
    ----------
    op
        The ibis operation that ``arg`` was compiled from.
    arg
        The compiled expression.

    Returns
    -------
    The result of ``paren(arg)`` for ``ops.Binary`` / ``ops.Unary``
    operations, otherwise ``arg`` unchanged.
    """
    # Imported inside the function, as in the original definition.
    import ibis.expr.operations as ops

    if not isinstance(op, (ops.Binary, ops.Unary)):
        # function calls don't need parens
        return arg
    return paren(arg)
|
||
|
||
def interval(value, *, unit):
    """Build a sqlglot ``Interval`` from ``value`` with the given ``unit``.

    ``value`` goes through ``sg.exp.convert``; ``unit`` (keyword-only) goes
    through ``sg.exp.var``.
    """
    amount = sg.exp.convert(value)
    unit_var = sg.exp.var(unit)
    return sg.exp.Interval(this=amount, unit=unit_var)
|
||
|
||
C = ColGen() | ||
F = FuncGen() | ||
NULL = sg.exp.Null() | ||
FALSE = sg.exp.false() | ||
TRUE = sg.exp.true() | ||
STAR = sg.exp.Star() | ||
|
||
|
||
def make_cast( | ||
converter: SqlglotType, | ||
) -> Callable[[sg.exp.Expression, dt.DataType], sg.exp.Cast]: | ||
def cast(arg: sg.exp.Expression, to: dt.DataType) -> sg.exp.Cast: | ||
return sg.cast(arg, to=converter.from_ibis(to)) | ||
if TYPE_CHECKING: | ||
from collections.abc import Iterator | ||
|
||
return cast | ||
import ibis.expr.datatypes as dt | ||
import ibis.expr.types as ir | ||
from ibis.backends.base.sqlglot.compiler import SQLGlotCompiler | ||
from ibis.common.typing import SupportsSchema | ||
|
||
|
||
class SQLGlotBackend(BaseBackend):
    """Base class for backends that compile expressions through sqlglot.

    Subclasses supply the ``compiler`` and ``name`` class attributes and
    implement ``_metadata``. Several helpers used below (``get_schema``,
    ``_safe_raw_sql``, ``fetch_from_cursor``, ``create_table``,
    ``list_tables``, ``_temp_views``, ``_register_in_memory_table`` ...) are
    not defined in this class and are expected to come from ``BaseBackend``
    or the concrete backend.
    """

    # Compiler whose ``translate`` turns operations into sqlglot expressions.
    compiler: ClassVar[SQLGlotCompiler]
    # Backend name; also handed to sqlglot as the dialect in ``compile``.
    name: ClassVar[str]

    @classmethod
    def has_operation(cls, operation: type[ops.Value]) -> bool:
        """Return whether the compiler has a rule registered for `operation`."""
        # singledispatchmethod overrides `__get__` so we can't directly access
        # the dispatcher
        dispatcher = cls.compiler.visit_node.register.__self__.dispatcher
        # Supported means: dispatching on the operation type resolves to
        # something other than the fallback registered for `object`.
        return dispatcher.dispatch(operation) is not dispatcher.dispatch(object)

    def _transform(
        self, sql: sge.Expression, table_expr: ir.TableExpr
    ) -> sge.Expression:
        """Post-process a compiled query; returns `sql` unchanged here."""
        return sql

    def table(
        self, name: str, schema: str | None = None, database: str | None = None
    ) -> ir.Table:
        """Construct a table expression.

        Parameters
        ----------
        name
            Table name
        schema
            Schema name
        database
            Database name

        Returns
        -------
        Table
            Table expression
        """
        table_schema = self.get_schema(name, schema=schema, database=database)
        return ops.DatabaseTable(
            name,
            schema=table_schema,
            source=self,
            namespace=ops.Namespace(database=database, schema=schema),
        ).to_expr()

    def _to_sqlglot(
        self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any
    ):
        """Compile an Ibis expression to a sqlglot object.

        Parameters
        ----------
        expr
            Expression to compile; it is converted with ``as_table()`` first.
        limit
            ``"default"`` selects ``ibis.options.sql.default_limit``; any
            other non-``None`` value is passed to ``Table.limit``.
        params
            Passed to the compiler's ``translate``; ``None`` becomes ``{}``.
        **_
            Extra keyword arguments are accepted and ignored.

        Returns
        -------
        list
            A single-element list holding the transformed sqlglot query.
        """
        table_expr = expr.as_table()

        if limit == "default":
            limit = ibis.options.sql.default_limit
        if limit is not None:
            table_expr = table_expr.limit(limit)

        if params is None:
            params = {}

        sql = self.compiler.translate(table_expr.op(), params=params)
        assert not isinstance(sql, sge.Subquery)

        # A bare table reference is not a query on its own: select from it.
        if isinstance(sql, sge.Table):
            sql = sg.select(STAR).from_(sql)

        assert not isinstance(sql, sge.Subquery)
        return [self._transform(sql, table_expr)]

    def compile(
        self, expr: ir.Expr, limit: str | None = None, params=None, **kwargs: Any
    ) -> str:
        """Compile an Ibis expression to a SQL string.

        The queries produced by ``_to_sqlglot`` are rendered with
        ``dialect=self.name`` and ``pretty=True`` and joined with ``";\\n\\n"``.
        """
        queries = self._to_sqlglot(expr, limit=limit, params=params, **kwargs)

        return ";\n\n".join(
            query.sql(dialect=self.name, pretty=True) for query in queries
        )

    def _to_sql(self, expr: ir.Expr, **kwargs) -> str:
        """Return the compiled SQL for `expr`; thin wrapper around `compile`."""
        return self.compile(expr, **kwargs)

    def _log(self, sql: str) -> None:
        """Log `sql`.

        This method can be implemented by subclasses. Logging occurs when
        `ibis.options.verbose` is `True`.
        """
        from ibis import util

        util.log(sql)

    def sql(
        self,
        query: str,
        schema: SupportsSchema | None = None,
        dialect: str | None = None,
    ) -> ir.Table:
        """Create a table expression from a raw SQL query.

        Parameters
        ----------
        query
            SQL text; it is first passed through ``_transpile_sql`` together
            with `dialect`.
        schema
            Schema of the result. When ``None`` it is obtained from
            ``_get_schema_using_query``.
        dialect
            Forwarded to ``_transpile_sql``.

        Returns
        -------
        Table
            Table expression wrapping the query
        """
        query = self._transpile_sql(query, dialect=dialect)
        if schema is None:
            schema = self._get_schema_using_query(query)
        return ops.SQLQueryResult(query, ibis.schema(schema), self).to_expr()

    @abc.abstractmethod
    def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]:
        """Return the metadata of a SQL query."""

    def _get_schema_using_query(self, query: str) -> sch.Schema:
        """Return an ibis Schema from a backend-specific SQL string."""
        return sch.Schema.from_tuples(self._metadata(query))

    def create_view(
        self,
        name: str,
        obj: ir.Table,
        *,
        database: str | None = None,
        schema: str | None = None,
        overwrite: bool = False,
    ) -> ir.Table:
        """Create a view from a table expression.

        Parameters
        ----------
        name
            View name
        obj
            Table expression whose compiled SQL defines the view
        database
            Used as the sqlglot ``catalog`` of the view
        schema
            Used as the sqlglot ``db`` of the view
        overwrite
            Sets ``replace`` on the ``CREATE`` statement

        Returns
        -------
        Table
            Table expression for the new view
        """
        src = sge.Create(
            this=sg.table(
                name, db=schema, catalog=database, quoted=self.compiler.quoted
            ),
            kind="VIEW",
            replace=overwrite,
            expression=self.compile(obj),
        )
        # In-memory tables referenced by `obj` are registered before the
        # statement runs.
        self._register_in_memory_tables(obj)
        with self._safe_raw_sql(src):
            pass
        # Look the view up in the same schema it was created in; omitting
        # `schema` would resolve the name against the default schema instead.
        return self.table(name, schema=schema, database=database)

    def _register_in_memory_tables(self, expr: ir.Expr) -> None:
        """Register every ``ops.InMemoryTable`` found in `expr`."""
        for memtable in expr.op().find(ops.InMemoryTable):
            self._register_in_memory_table(memtable)

    def drop_view(
        self,
        name: str,
        *,
        database: str | None = None,
        schema: str | None = None,
        force: bool = False,
    ) -> None:
        """Drop a view.

        Parameters
        ----------
        name
            View name
        database
            Used as the sqlglot ``catalog`` of the view
        schema
            Used as the sqlglot ``db`` of the view
        force
            Sets ``exists`` on the ``DROP`` statement
        """
        src = sge.Drop(
            this=sg.table(
                name, db=schema, catalog=database, quoted=self.compiler.quoted
            ),
            kind="VIEW",
            exists=force,
        )
        with self._safe_raw_sql(src):
            pass

    def _get_temp_view_definition(self, name: str, definition: str) -> sge.Create:
        """Build the ``CREATE`` expression for a temporary view.

        The statement is built with ``replace=True`` and a
        ``TemporaryProperty``; `definition` is used as its body.
        """
        return sge.Create(
            this=sg.to_identifier(name, quoted=self.compiler.quoted),
            kind="VIEW",
            expression=definition,
            replace=True,
            properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
        )

    def _create_temp_view(self, table_name, source):
        """Create a temporary view and track it in ``self._temp_views``.

        Raises
        ------
        ValueError
            If `table_name` is listed by ``list_tables()`` but is not one of
            the temporary views already tracked by this backend.
        """
        if table_name not in self._temp_views and table_name in self.list_tables():
            raise ValueError(
                f"{table_name} already exists as a non-temporary table or view"
            )

        with self._safe_raw_sql(self._get_temp_view_definition(table_name, source)):
            pass

        self._temp_views.add(table_name)
        self._register_temp_view_cleanup(table_name)

    def _register_temp_view_cleanup(self, name: str) -> None:
        """Register a clean up function for a temporary view.

        No-op by default.

        Parameters
        ----------
        name
            The temporary view to register for clean up.
        """

    def _load_into_cache(self, name, expr):
        """Materialize `expr` as a temporary table called `name`."""
        self.create_table(name, expr, schema=expr.schema(), temp=True)

    def _clean_up_cached_table(self, op):
        """Drop the table named ``op.name``."""
        self.drop_table(op.name)

    def execute(
        self, expr: ir.Expr, limit: str | None = "default", **kwargs: Any
    ) -> Any:
        """Execute an expression.

        Parameters
        ----------
        expr
            Expression to execute
        limit
            Forwarded to ``compile``; defaults to ``"default"``.
        **kwargs
            Forwarded to ``compile``.

        Returns
        -------
        Any
            The value of ``expr.__pandas_result__`` applied to the fetched
            result.
        """

        self._run_pre_execute_hooks(expr)
        table = expr.as_table()
        sql = self.compile(table, limit=limit, **kwargs)

        schema = table.schema()
        self._log(sql)

        with self._safe_raw_sql(sql) as cur:
            result = self.fetch_from_cursor(cur, schema)
        # Shape the result using the original expression, not the table form.
        return expr.__pandas_result__(result)

    def drop_table(
        self,
        name: str,
        database: str | None = None,
        schema: str | None = None,
        force: bool = False,
    ) -> None:
        """Drop a table.

        Parameters
        ----------
        name
            Table name
        database
            Used as the sqlglot ``catalog`` of the table
        schema
            Used as the sqlglot ``db`` of the table
        force
            Sets ``exists`` on the ``DROP`` statement
        """
        drop_stmt = sge.Drop(
            kind="TABLE",
            this=sg.table(
                name, db=schema, catalog=database, quoted=self.compiler.quoted
            ),
            exists=force,
        )
        with self._safe_raw_sql(drop_stmt):
            pass
Oops, something went wrong.