diff --git a/.gitignore b/.gitignore index bb1ea1c2ed59..e492e03661a9 100644 --- a/.gitignore +++ b/.gitignore @@ -136,3 +136,6 @@ ibis/examples/descriptions # automatically generated odbc file for ci ci/odbc/odbc.ini *-citibike-tripdata.tar.xz + +# data downloaded by the geospatial tutorial +docs/posts/ibis-duckdb-geospatial/nyc_data.db.wal diff --git a/docs/backends/impala.qmd b/docs/backends/impala.qmd index 80d96d91a5d8..c6e4b80f1acf 100644 --- a/docs/backends/impala.qmd +++ b/docs/backends/impala.qmd @@ -202,7 +202,7 @@ table or database. ```{python} #| echo: false #| output: asis -render_methods(get_object("ibis.backends.base.sql", "BaseSQLBackend"), "table") +render_methods(get_object("ibis.backends.base.sqlglot", "SQLGlotBackend"), "table") ``` The client's `table` method allows you to create an Ibis table diff --git a/ibis/backends/base/sql/__init__.py b/ibis/backends/base/sql/__init__.py deleted file mode 100644 index c592c112c992..000000000000 --- a/ibis/backends/base/sql/__init__.py +++ /dev/null @@ -1,409 +0,0 @@ -from __future__ import annotations - -import abc -import contextlib -from functools import lru_cache -from typing import TYPE_CHECKING, Any, Optional -from urllib.parse import parse_qs, urlparse - -import ibis.common.exceptions as exc -import ibis.expr.operations as ops -import ibis.expr.schema as sch -import ibis.expr.types as ir -from ibis import util -from ibis.backends.base import BaseBackend -from ibis.backends.base.sql.compiler import Compiler - -if TYPE_CHECKING: - from collections.abc import Iterable, Mapping - - import pandas as pd - import pyarrow as pa - -__all__ = ["BaseSQLBackend"] - - -class BaseSQLBackend(BaseBackend): - """Base backend class for backends that compile to SQL.""" - - compiler = Compiler - - def _from_url(self, url: str, **kwargs): - """Connect to a backend using a URL `url`. - - Parameters - ---------- - url - URL with which to connect to a backend. - kwargs - Additional keyword arguments - - Returns - ------- - BaseBackend - A backend instance - - """ - url = urlparse(url) - database = url.path[1:] - query_params = parse_qs(url.query) - kwargs = { - "user": url.username, - "password": url.password or "", - "host": url.hostname, - "database": database or "", - } | kwargs - - for name, value in query_params.items(): - if len(value) > 1: - kwargs[name] = value - elif len(value) == 1: - kwargs[name] = value[0] - else: - raise exc.IbisError(f"Invalid URL parameter: {name}") - - return self.connect(**kwargs) - - def table(self, name: str, database: str | None = None) -> ir.Table: - """Construct a table expression. - - Parameters - ---------- - name - Table name - database - Database name - - Returns - ------- - Table - Table expression - - """ - if database is not None and not isinstance(database, str): - raise exc.IbisTypeError( - f"`database` must be a string; got {type(database)}" - ) - qualified_name = self._fully_qualified_name(name, database) - schema = self.get_schema(qualified_name) - node = ops.DatabaseTable( - name, schema, self, namespace=ops.Namespace(database=database) - ) - return node.to_expr() - - def _fully_qualified_name(self, name, database): - # XXX - return name - - def sql( - self, query: str, schema: sch.Schema | None = None, dialect: str | None = None - ) -> ir.Table: - """Convert a SQL query to an Ibis table expression. - - Parameters - ---------- - query - SQL string - schema - The expected schema for this query. If not provided, will be - inferred automatically if possible. 
- dialect - Optional string indicating the dialect of `query`. The default - value of `None` will use the backend's native dialect. - - Returns - ------- - Table - Table expression - - """ - query = self._transpile_sql(query, dialect=dialect) - if schema is None: - schema = self._get_schema_using_query(query) - else: - schema = sch.schema(schema) - return ops.SQLQueryResult(query, schema, self).to_expr() - - def _get_schema_using_query(self, query): - raise NotImplementedError(f"Backend {self.name} does not support .sql()") - - def raw_sql(self, query: str): - """Execute a query string and return the cursor used for execution. - - ::: {.callout-tip} - ## Consider using [`.sql`](#ibis.backends.base.sql.BaseSQLBackend.sql) instead - - If your query is a `SELECT` statement you can use the - [backend `.sql`](#ibis.backends.base.sql.BaseSQLBackend.sql) method to avoid - having to manually release the cursor returned from this method. - - ::: {.callout-warning} - ## The cursor returned from this method must be **manually released** - - You **do not** need to call `.close()` on the cursor when running DDL - or DML statements like `CREATE`, `INSERT` or `DROP`, only when using - `SELECT` statements. - - To release a cursor, call the `close` method on the returned cursor - object. - - You can close the cursor by explicitly calling its `close` method: - - ```python - cursor = con.raw_sql("SELECT ...") - cursor.close() - ``` - - Or you can use a context manager: - - ```python - with con.raw_sql("SELECT ...") as cursor: - ... - ``` - ::: - - ::: - - Parameters - ---------- - query - SQL query string - - Examples - -------- - >>> con = ibis.connect("duckdb://") - >>> with con.raw_sql("SELECT 1") as cursor: - ... result = cursor.fetchall() - >>> result - [(1,)] - >>> cursor.closed - True - - """ - return self.con.execute(query) - - @contextlib.contextmanager - def _safe_raw_sql(self, *args, **kwargs): - yield self.raw_sql(*args, **kwargs) - - def _cursor_batches( - self, - expr: ir.Expr, - params: Mapping[ir.Scalar, Any] | None = None, - limit: int | str | None = None, - chunk_size: int = 1_000_000, - ) -> Iterable[list]: - self._run_pre_execute_hooks(expr) - query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params) - sql = query_ast.compile() - - with self._safe_raw_sql(sql) as cursor: - while batch := cursor.fetchmany(chunk_size): - yield batch - - @util.experimental - def to_pyarrow_batches( - self, - expr: ir.Expr, - *, - params: Mapping[ir.Scalar, Any] | None = None, - limit: int | str | None = None, - chunk_size: int = 1_000_000, - **_: Any, - ) -> pa.ipc.RecordBatchReader: - """Execute expression and return an iterator of pyarrow record batches. - - This method is eager and will execute the associated expression - immediately. - - Parameters - ---------- - expr - Ibis expression to export to pyarrow - limit - An integer to effect a specific row limit. A value of `None` means - "no limit". The default is in `ibis/config.py`. - params - Mapping of scalar parameter expressions to value. - chunk_size - Maximum number of rows in each returned record batch. - - Returns - ------- - RecordBatchReader - Collection of pyarrow `RecordBatch`s. 
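A short usage sketch of this interface (assuming an in-memory DuckDB connection and a `memtable`; any backend inheriting this base class exposed the same method):

```python
import ibis

con = ibis.connect("duckdb://")
t = ibis.memtable({"x": list(range(10))})

# Stream results as pyarrow.RecordBatch objects, capped at
# `chunk_size` rows per batch.
reader = con.to_pyarrow_batches(t, chunk_size=4)
for batch in reader:
    print(batch.num_rows)  # at most 4 rows per batch
```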
- - """ - pa = self._import_pyarrow() - - schema = expr.as_table().schema() - array_type = schema.as_struct().to_pyarrow() - arrays = ( - pa.array(map(tuple, batch), type=array_type) - for batch in self._cursor_batches( - expr, params=params, limit=limit, chunk_size=chunk_size - ) - ) - batches = map(pa.RecordBatch.from_struct_array, arrays) - - return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), batches) - - def _register_udfs(self, expr: ir.Expr) -> None: - """Return an iterator of DDL strings, once for each UDFs contained within `expr`.""" - if self.supports_python_udfs: - raise NotImplementedError(self.name) - - def _gen_udf_name(self, name: str, schema: Optional[str]) -> str: - return ".".join(filter(None, (schema, name))) - - def _gen_udf_rule(self, op: ops.ScalarUDF): - @self.add_operation(type(op)) - def _(t, op): - func = self._gen_udf_name(op.__func_name__, schema=op.__udf_namespace__) - return f"{func}({', '.join(map(t.translate, op.args))})" - - def _gen_udaf_rule(self, op: ops.AggUDF): - from ibis import NA - - @self.add_operation(type(op)) - def _(t, op): - func = self._gen_udf_name(op.__func_name__, schema=op.__udf_namespace__) - args = ", ".join( - t.translate( - ops.IfElse(where, arg, NA) - if (where := op.where) is not None - else arg - ) - for name, arg in zip(op.argnames, op.args) - if name != "where" - ) - return f"{func}({args})" - - def _define_udf_translation_rules(self, expr): - for udf_node in expr.op().find(ops.ScalarUDF): - udf_node_type = type(udf_node) - - if udf_node_type not in self.compiler.translator_class._registry: - self._gen_udf_rule(udf_node) - - for udf_node in expr.op().find(ops.AggUDF): - udf_node_type = type(udf_node) - - if udf_node_type not in self.compiler.translator_class._registry: - self._gen_udaf_rule(udf_node) - - def execute( - self, - expr: ir.Expr, - params: Mapping[ir.Scalar, Any] | None = None, - limit: str = "default", - **kwargs: Any, - ): - """Compile and execute an Ibis expression. - - Compile and execute Ibis expression using this backend client - interface, returning results in-memory in the appropriate object type - - Parameters - ---------- - expr - Ibis expression - limit - For expressions yielding result sets; retrieve at most this number - of values/rows. Overrides any limit already set on the expression. - params - Named unbound parameters - kwargs - Backend specific arguments. For example, the clickhouse backend - uses this to receive `external_tables` as a dictionary of pandas - DataFrames. - - Returns - ------- - DataFrame | Series | Scalar - * `Table`: pandas.DataFrame - * `Column`: pandas.Series - * `Scalar`: Python scalar value - - """ - # TODO Reconsider having `kwargs` here. It's needed to support - # `external_tables` in clickhouse, but better to deprecate that - # feature than all this magic. 
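# Overall flow of the method body below: run pre-execute hooks (temp
# table / UDF registration), compile the expression to a single SQL
# string with any default LIMIT applied, log it, execute it on a
# cursor, fetch the rows into pandas, then let the expression type
# shape the final result via `__pandas_result__`.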
- # we don't want to pass `timecontext` to `raw_sql` - self._run_pre_execute_hooks(expr) - - kwargs.pop("timecontext", None) - query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params) - sql = query_ast.compile() - self._log(sql) - - schema = expr.as_table().schema() - - with self._safe_raw_sql(sql, **kwargs) as cursor: - result = self.fetch_from_cursor(cursor, schema) - - return expr.__pandas_result__(result) - - def _register_in_memory_table(self, _: ops.InMemoryTable) -> None: - raise NotImplementedError(self.name) - - def _register_in_memory_tables(self, expr: ir.Expr) -> None: - if self.compiler.cheap_in_memory_tables: - for memtable in expr.op().find(ops.InMemoryTable): - self._register_in_memory_table(memtable) - - @abc.abstractmethod - def fetch_from_cursor(self, cursor, schema): - """Fetch data from cursor.""" - - def _log(self, sql: str) -> None: - """Log the SQL, usually to the standard output. - - This method can be implemented by subclasses. The logging - happens when `ibis.options.verbose` is `True`. - """ - util.log(sql) - - def compile( - self, - expr: ir.Expr, - limit: str | None = None, - params: Mapping[ir.Expr, Any] | None = None, - timecontext: tuple[pd.Timestamp, pd.Timestamp] | None = None, - ) -> Any: - """Compile an Ibis expression. - - Parameters - ---------- - expr - Ibis expression - limit - For expressions yielding result sets; retrieve at most this number - of values/rows. Overrides any limit already set on the expression. - params - Named unbound parameters - timecontext - Additional information about data source time boundaries - - Returns - ------- - Any - The output of compilation. The type of this value depends on the - backend. - - """ - self._define_udf_translation_rules(expr) - return self.compiler.to_ast_ensure_limit(expr, limit, params=params).compile() - - def _to_sql(self, expr: ir.Expr, **kwargs) -> str: - return str(self.compile(expr, **kwargs)) - - @classmethod - @lru_cache - def _get_operations(cls): - translator = cls.compiler.translator_class - return translator._registry.keys() | translator._rewrites.keys() - - @classmethod - def has_operation(cls, operation: type[ops.Value]) -> bool: - return operation in cls._get_operations() diff --git a/ibis/backends/base/sql/compiler/__init__.py b/ibis/backends/base/sql/compiler/__init__.py deleted file mode 100644 index 569dfdbb2670..000000000000 --- a/ibis/backends/base/sql/compiler/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from ibis.backends.base.sql.compiler.base import DDL, DML -from ibis.backends.base.sql.compiler.query_builder import ( - Compiler, - Difference, - Intersection, - Select, - SelectBuilder, - TableSetFormatter, - Union, -) -from ibis.backends.base.sql.compiler.translator import ExprTranslator, QueryContext - -__all__ = ( - "Compiler", - "Select", - "SelectBuilder", - "Union", - "Intersection", - "Difference", - "TableSetFormatter", - "ExprTranslator", - "QueryContext", - "DML", - "DDL", -) diff --git a/ibis/backends/base/sql/compiler/base.py b/ibis/backends/base/sql/compiler/base.py deleted file mode 100644 index aa191499ba1b..000000000000 --- a/ibis/backends/base/sql/compiler/base.py +++ /dev/null @@ -1,114 +0,0 @@ -from __future__ import annotations - -from itertools import chain - -import toolz - -import ibis.expr.analysis as an -import ibis.expr.operations as ops -from ibis import util - - -class DML: - def compile(self): - raise NotImplementedError() - - -class DDL: - def compile(self): - raise NotImplementedError() - - 
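For readers skimming the removal: `QueryAST` (below) was the unit that stitched optional setup queries, the main DML, and optional teardown queries into one executable script. A minimal sketch of that behavior, using hypothetical stand-in query objects rather than the real compiler internals:

```python
# Hypothetical stand-ins; the real setup/teardown entries were DDL/DML objects.
class StubQuery:
    def __init__(self, sql: str):
        self.sql = sql

    def compile(self) -> str:
        return self.sql


def collapse(queries: list[str]) -> str:
    # Mirrors the removed QueryContext.collapse: join with blank lines.
    return "\n\n".join(queries)


setup = [StubQuery("CREATE TEMPORARY TABLE tmp AS SELECT 1 AS x")]
main = StubQuery("SELECT x FROM tmp")
teardown = [StubQuery("DROP TABLE tmp")]

script = collapse(
    [q.compile() for q in setup]
    + [main.compile()]
    + [q.compile() for q in teardown]
)
print(script)
```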
-class QueryAST: - __slots__ = "context", "dml", "setup_queries", "teardown_queries" - - def __init__(self, context, dml, setup_queries=None, teardown_queries=None): - self.context = context - self.dml = dml - self.setup_queries = setup_queries - self.teardown_queries = teardown_queries - - @property - def queries(self): - return [self.dml] - - def compile(self): - compiled_setup_queries = [q.compile() for q in self.setup_queries] - compiled_queries = [q.compile() for q in self.queries] - compiled_teardown_queries = [q.compile() for q in self.teardown_queries] - return self.context.collapse( - list( - chain( - compiled_setup_queries, - compiled_queries, - compiled_teardown_queries, - ) - ) - ) - - -class SetOp(DML): - def __init__(self, tables, node, context, distincts): - assert isinstance(node, ops.Node) - assert all(isinstance(table, ops.Node) for table in tables) - self.context = context - self.tables = tables - self.table_set = node - self.distincts = distincts - self.filters = [] - - @classmethod - def keyword(cls, distinct): - return cls._keyword + (not distinct) * " ALL" - - def _get_keyword_list(self): - return map(self.keyword, self.distincts) - - def _extract_subqueries(self): - # extract any subquery to avoid generating incorrect sql when at least - # one of the set operands is invalid outside of being a subquery - # - # for example: SELECT * FROM t ORDER BY x UNION ... - self.subqueries = an.find_subqueries( - [self.table_set, *self.filters], min_dependents=1 - ) - for subquery in self.subqueries: - self.context.set_extracted(subquery) - - def format_subqueries(self): - context = self.context - subqueries = self.subqueries - - return ",\n".join( - "{} AS (\n{}\n)".format( - context.get_ref(expr), - util.indent(context.get_compiled_expr(expr), 2), - ) - for expr in subqueries - ) - - def format_relation(self, expr): - ref = self.context.get_ref(expr) - if ref is not None: - return f"SELECT *\nFROM {ref}" - return self.context.get_compiled_expr(expr) - - def compile(self): - self._extract_subqueries() - - extracted = self.format_subqueries() - - buf = [] - - if extracted: - buf.append(f"WITH {extracted}") - - buf.extend( - toolz.interleave( - ( - map(self.format_relation, self.tables), - self._get_keyword_list(), - ) - ) - ) - return "\n".join(buf) diff --git a/ibis/backends/base/sql/compiler/query_builder.py b/ibis/backends/base/sql/compiler/query_builder.py deleted file mode 100644 index 44f5a728275f..000000000000 --- a/ibis/backends/base/sql/compiler/query_builder.py +++ /dev/null @@ -1,655 +0,0 @@ -from __future__ import annotations - -from io import StringIO -from typing import TYPE_CHECKING - -import sqlglot as sg -import toolz - -import ibis.common.exceptions as com -import ibis.expr.operations as ops -import ibis.expr.types as ir -from ibis import util -from ibis.backends.base.sql.compiler.base import DML, QueryAST, SetOp -from ibis.backends.base.sql.compiler.select_builder import SelectBuilder, _LimitSpec -from ibis.backends.base.sql.compiler.translator import ExprTranslator, QueryContext -from ibis.backends.base.sql.registry import quote_identifier -from ibis.common.grounds import Comparable -from ibis.config import options -from ibis.expr.rewrites import rewrite_dropna, rewrite_fillna - -if TYPE_CHECKING: - from collections.abc import Iterable - - -class TableSetFormatter: - def __init__(self, parent, node, indent=2): - # `parent` is a `Select` instance, not a `TableSetFormatter` - self.parent = parent - self.context = parent.context - self.node = node - self.indent = 
indent - - self.join_tables = [] - self.join_types = [] - self.join_predicates = [] - - def _translate(self, expr): - return self.parent._translate(expr) - - # TODO(kszucs): could use lin.traverse here - def _walk_join_tree(self, op): - if util.all_of([op.left, op.right], ops.Join): - raise NotImplementedError("Do not support joins between joins yet") - - jname = self._get_join_type(op) - - # Read off tables and join predicates left-to-right in - # depth-first order - if isinstance(op.left, ops.Join): - self._walk_join_tree(op.left) - self.join_tables.append(self._format_table(op.right)) - elif isinstance(op.right, ops.Join): - self.join_tables.append(self._format_table(op.left)) - self._walk_join_tree(op.right) - else: - # Both tables - self.join_tables.append(self._format_table(op.left)) - self.join_tables.append(self._format_table(op.right)) - - self.join_types.append(jname) - self.join_predicates.append(op.predicates) - - def _get_join_type(self, op): - return self._join_names[type(op)] - - def _quote_identifier(self, name): - return quote_identifier(name) - - def _format_in_memory_table(self, op): - if self.context.compiler.cheap_in_memory_tables: - return op.name - - names = op.schema.names - raw_rows = [] - for row in op.data.to_frame().itertuples(index=False): - raw_row = [] - for val, name in zip(row, names): - lit = ops.Literal(val, dtype=op.schema[name]) - raw_row.append( - f"{self._translate(lit)} AS {self._quote_identifier(name)}" - ) - raw_rows.append(", ".join(raw_row)) - - if self.context.compiler.support_values_syntax_in_select: - rows = ", ".join(f"({raw_row})" for raw_row in raw_rows) - return f"(VALUES {rows})" - else: - rows = " UNION ALL ".join(f"(SELECT {raw_row})" for raw_row in raw_rows) - return f"({rows})" - - def _format_table(self, op): - # TODO: This could probably go in a class and be significantly nicer - ctx = self.context - - orig_op = op - if isinstance(op, (ops.SelfReference, ops.Sample)): - op = op.table - - alias = ctx.get_ref(orig_op) - - if isinstance(op, ops.InMemoryTable): - result = self._format_in_memory_table(op) - elif isinstance(op, ops.PhysicalTable): - # TODO(kszucs): add a mandatory `name` field to the base - # PhyisicalTable instead of the child classes, this should prevent - # this error scenario - if (name := op.name) is None: - raise com.RelationError(f"Table did not have a name: {op!r}") - - namespace = getattr(op, "namespace", None) - catalog = getattr(namespace, "database", None) - db = getattr(namespace, "schema", None) - result = sg.table( - name, - db=db, - catalog=catalog, - quoted=self.parent.translator_class._quote_identifiers, - ).sql(dialect=self.parent.translator_class._dialect_name) - elif ctx.is_extracted(op): - if isinstance(orig_op, ops.SelfReference): - result = ctx.get_ref(op) - else: - result = alias - else: - subquery = ctx.get_compiled_expr(op) - result = f"(\n{util.indent(subquery, self.indent)}\n)" - - if result != alias: - result = f"{result} {alias}" - - if isinstance(orig_op, ops.Sample): - result = self._format_sample(orig_op, result) - - return result - - def _format_sample(self, op, table): - # Should never be hit in practice, as Sample operations should be rewritten - # before this point for all backends without TABLESAMPLE support - raise com.UnsupportedOperationError("`Table.sample` is not supported") - - def get_result(self): - # Got to unravel the join stack; the nesting order could be - # arbitrary, so we do a depth first search and push the join tokens - # and predicates onto a flat list, then format 
them - op = self.node - - if isinstance(op, ops.Join): - self._walk_join_tree(op) - else: - self.join_tables.append(self._format_table(op)) - - # TODO: Now actually format the things - buf = StringIO() - buf.write(self.join_tables[0]) - for jtype, table, preds in zip( - self.join_types, self.join_tables[1:], self.join_predicates - ): - buf.write("\n") - buf.write(util.indent(f"{jtype} {table}", self.indent)) - - fmt_preds = [] - npreds = len(preds) - for pred in preds: - new_pred = self._translate(pred) - if npreds > 1: - new_pred = f"({new_pred})" - fmt_preds.append(new_pred) - - if len(fmt_preds): - buf.write("\n") - - conj = " AND\n{}".format(" " * 3) - fmt_preds = util.indent("ON " + conj.join(fmt_preds), self.indent * 2) - buf.write(fmt_preds) - - return buf.getvalue() - - -class Select(DML, Comparable): - """A SELECT statement.""" - - def __init__( - self, - table_set, - select_set, - translator_class, - table_set_formatter_class, - context, - subqueries=None, - where=None, - group_by=None, - having=None, - order_by=None, - limit=None, - distinct=False, - indent=2, - parent_op=None, - ): - self.translator_class = translator_class - self.table_set_formatter_class = table_set_formatter_class - self.context = context - - self.select_set = select_set - self.table_set = table_set - self.distinct = distinct - - self.parent_op = parent_op - - self.where = where or [] - - # Group keys and post-predicates for aggregations - self.group_by = group_by or [] - self.having = having or [] - self.order_by = order_by or [] - - self.limit = limit - self.subqueries = subqueries or [] - - self.indent = indent - - def _translate(self, expr, named=False, permit_subquery=False, within_where=False): - translator = self.translator_class( - expr, - context=self.context, - named=named, - permit_subquery=permit_subquery, - within_where=within_where, - ) - return translator.get_result() - - def __equals__(self, other: Select) -> bool: - return self.limit == other.limit and self._all_exprs() == other._all_exprs() - - def _all_exprs(self): - return tuple( - *self.select_set, - self.table_set, - *self.where, - *self.group_by, - *self.having, - *self.order_by, - *self.subqueries, - ) - - def compile(self): - """Compile a query. - - This method isn't yet idempotent; calling multiple times may yield - unexpected results. - """ - # Can't tell if this is a hack or not. 
Revisit later - self.context.set_query(self) - - # If any subqueries, translate them and add to beginning of query as - # part of the WITH section - with_frag = self.format_subqueries() - - # SELECT - select_frag = self.format_select_set() - - # FROM, JOIN, UNION - from_frag = self.format_table_set() - - # WHERE - where_frag = self.format_where() - - # GROUP BY and HAVING - groupby_frag = self.format_group_by() - - # ORDER BY - order_frag = self.format_order_by() - - # LIMIT - limit_frag = self.format_limit() - - # Glue together the query fragments and return - query = "\n".join( - filter( - None, - [ - with_frag, - select_frag, - from_frag, - where_frag, - groupby_frag, - order_frag, - limit_frag, - ], - ) - ) - return query - - def format_subqueries(self): - if not self.subqueries: - return None - - context = self.context - - buf = [] - - for expr in self.subqueries: - formatted = util.indent(context.get_compiled_expr(expr), 2) - alias = context.get_ref(expr) - buf.append(f"{alias} AS (\n{formatted}\n)") - - return "WITH {}".format(",\n".join(buf)) - - def format_select_set(self): - # TODO: - context = self.context - formatted = [] - for node in self.select_set: - if isinstance(node, ops.Value): - expr_str = self._translate(node, named=True, permit_subquery=True) - elif isinstance(node, ops.TableNode): - alias = context.get_ref(node) - expr_str = f"{alias}.*" if alias else "*" - else: - raise TypeError(node) - formatted.append(expr_str) - - buf = StringIO() - line_length = 0 - max_length = 70 - tokens = 0 - for i, val in enumerate(formatted): - # always line-break for multi-line expressions - if val.count("\n"): - if i: - buf.write(",") - buf.write("\n") - indented = util.indent(val, self.indent) - buf.write(indented) - - # set length of last line - line_length = len(indented.split("\n")[-1]) - tokens = 1 - elif tokens > 0 and line_length and len(val) + line_length > max_length: - # There is an expr, and adding this new one will make the line - # too long - buf.write(",\n ") if i else buf.write("\n") - buf.write(val) - line_length = len(val) + 7 - tokens = 1 - else: - if i: - buf.write(",") - buf.write(" ") - buf.write(val) - tokens += 1 - line_length += len(val) + 2 - - if self.distinct: - select_key = "SELECT DISTINCT" - else: - select_key = "SELECT" - - return f"{select_key}{buf.getvalue()}" - - def format_table_set(self): - if self.table_set is None: - return None - - fragment = "FROM " - - helper = self.table_set_formatter_class(self, self.table_set) - fragment += helper.get_result() - - return fragment - - def format_group_by(self): - if not len(self.group_by): - # There is no aggregation, nothing to see here - return None - - lines = [] - if len(self.group_by) > 0: - clause = "GROUP BY {}".format( - ", ".join([str(x + 1) for x in self.group_by]) - ) - lines.append(clause) - - if len(self.having) > 0: - trans_exprs = [] - for expr in self.having: - translated = self._translate(expr) - trans_exprs.append(translated) - lines.append("HAVING {}".format(" AND ".join(trans_exprs))) - - return "\n".join(lines) - - def format_where(self): - if not self.where: - return None - - buf = StringIO() - buf.write("WHERE ") - fmt_preds = [] - npreds = len(self.where) - for pred in self.where: - new_pred = self._translate(pred, permit_subquery=True, within_where=True) - if npreds > 1: - new_pred = f"({new_pred})" - fmt_preds.append(new_pred) - - conj = " AND\n{}".format(" " * 6) - buf.write(conj.join(fmt_preds)) - return buf.getvalue() - - def format_order_by(self): - if not self.order_by: - return 
None - - buf = StringIO() - buf.write("ORDER BY ") - - formatted = [] - for key in self.order_by: - translated = self._translate(key.expr) - suffix = "ASC" if key.ascending else "DESC" - translated += f" {suffix}" - formatted.append(translated) - - buf.write(", ".join(formatted)) - return buf.getvalue() - - def format_limit(self): - if self.limit is None: - return None - - buf = StringIO() - - n = self.limit.n - - if n is None: - n = self.context.compiler.null_limit - elif not isinstance(n, int): - n = f"(SELECT {self._translate(n)} {self.format_table_set()})" - - if n is not None: - buf.write(f"LIMIT {n}") - - offset = self.limit.offset - - if not isinstance(offset, int): - offset = f"(SELECT {self._translate(offset)} {self.format_table_set()})" - - if offset != 0 and n != 0: - buf.write(f" OFFSET {offset}") - - return buf.getvalue() - - -class Union(SetOp): - _keyword = "UNION" - - -class Intersection(SetOp): - _keyword = "INTERSECT" - - -class Difference(SetOp): - _keyword = "EXCEPT" - - -def flatten_set_op(op) -> Iterable[ops.Table | bool]: - """Extract all union queries from `table`. - - Parameters - ---------- - op - Set operation to flatten - - Returns - ------- - Iterable[Table | bool] - Iterable of tables and `bool`s indicating `distinct`. - - """ - - if isinstance(op, ops.SetOp): - # For some reason mypy considers `op.left` and `op.right` - # of `Argument` type, and fails the validation. While in - # `flatten` types are the same, and it works - return toolz.concatv( - flatten_set_op(op.left), # type: ignore - [op.distinct], - flatten_set_op(op.right), # type: ignore - ) - return [op] - - -def flatten(op: ops.TableNode): - """Extract all intersection or difference queries from `table`. - - Parameters - ---------- - op - Table operation to flatten - - Returns - ------- - Iterable[Table | bool] - Iterable of tables and `bool`s indicating `distinct`. - - """ - return list(toolz.concatv(flatten_set_op(op.left), flatten_set_op(op.right))) - - -class Compiler: - translator_class = ExprTranslator - context_class = QueryContext - select_builder_class = SelectBuilder - table_set_formatter_class = TableSetFormatter - select_class = Select - union_class = Union - intersect_class = Intersection - difference_class = Difference - - cheap_in_memory_tables = False - support_values_syntax_in_select = True - null_limit = None - - rewrites = rewrite_fillna | rewrite_dropna - - @classmethod - def make_context(cls, params=None): - params = params or {} - - unaliased_params = {} - for expr, value in params.items(): - op = expr.op() - if isinstance(op, ops.Alias): - op = op.arg - unaliased_params[op] = value - - return cls.context_class(compiler=cls, params=unaliased_params) - - @classmethod - def to_ast(cls, node, context=None): - # TODO(kszucs): consider to support a single type only - if isinstance(node, ir.Expr): - node = node.op() - - if cls.rewrites: - node = node.replace(cls.rewrites) - - if context is None: - context = cls.make_context() - - # collect setup and teardown queries - setup_queries = cls._generate_setup_queries(node, context) - teardown_queries = cls._generate_teardown_queries(node, context) - - # TODO: any setup / teardown DDL statements will need to be done prior - # to building the result set-generating statements. 
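# Dispatch note for the branches below: set operations (union /
# intersect / difference) are flattened and compiled as one
# multi-part statement; every other relation is lowered to a single
# Select via the SelectBuilder.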
- if isinstance(node, ops.Union): - query = cls._make_union(node, context) - elif isinstance(node, ops.Intersection): - query = cls._make_intersect(node, context) - elif isinstance(node, ops.Difference): - query = cls._make_difference(node, context) - else: - query = cls.select_builder_class().to_select( - select_class=cls.select_class, - table_set_formatter_class=cls.table_set_formatter_class, - node=node, - context=context, - translator_class=cls.translator_class, - ) - - return QueryAST( - context, - query, - setup_queries=setup_queries, - teardown_queries=teardown_queries, - ) - - @classmethod - def to_ast_ensure_limit(cls, expr, limit=None, params=None): - context = cls.make_context(params=params) - query_ast = cls.to_ast(expr, context) - - # note: limit can still be None at this point, if the global - # default_limit is None - for query in reversed(query_ast.queries): - if ( - isinstance(query, Select) - and not isinstance(expr, ir.Scalar) - and query.table_set is not None - ): - if query.limit is None: - if limit == "default": - query_limit = options.sql.default_limit - else: - query_limit = limit - if query_limit: - query.limit = _LimitSpec(query_limit, offset=0) - elif limit is not None and limit != "default": - query.limit = _LimitSpec(limit, query.limit.offset) - - return query_ast - - @classmethod - def to_sql(cls, node, context=None, params=None): - # TODO(kszucs): consider to support a single type only - if isinstance(node, ir.Expr): - node = node.op() - - assert isinstance(node, ops.Node) - - if context is None: - context = cls.make_context(params=params) - return cls.to_ast(node, context).queries[0].compile() - - @staticmethod - def _generate_setup_queries(expr, context): - return [] - - @staticmethod - def _generate_teardown_queries(expr, context): - return [] - - @staticmethod - def _make_set_op(cls, op, context): - # flatten unions so that we can codegen them all at once - set_op_info = list(flatten_set_op(op)) - - # since op is a union, we have at least 3 elements in union_info (left - # distinct right) and if there is more than a single union we have an - # additional two elements per union (distinct right) which means the - # total number of elements is at least 3 + (2 * number of unions - 1) - # and is therefore an odd number - npieces = len(set_op_info) - assert npieces >= 3 and npieces % 2 != 0, "Invalid set operation expression" - - # 1. every other object starting from 0 is a Table instance - # 2. 
every other object starting from 1 is a bool indicating the type - # of $set_op (distinct or not distinct) - table_exprs, distincts = set_op_info[::2], set_op_info[1::2] - return cls(table_exprs, op, distincts=distincts, context=context) - - @classmethod - def _make_union(cls, op, context): - return cls._make_set_op(cls.union_class, op, context) - - @classmethod - def _make_intersect(cls, op, context): - # flatten intersections so that we can codegen them all at once - return cls._make_set_op(cls.intersect_class, op, context) - - @classmethod - def _make_difference(cls, op, context): - # flatten differences so that we can codegen them all at once - return cls._make_set_op(cls.difference_class, op, context) diff --git a/ibis/backends/base/sql/compiler/select_builder.py b/ibis/backends/base/sql/compiler/select_builder.py deleted file mode 100644 index f1f9b73cc984..000000000000 --- a/ibis/backends/base/sql/compiler/select_builder.py +++ /dev/null @@ -1,266 +0,0 @@ -from __future__ import annotations - -from typing import NamedTuple - -import ibis.expr.analysis as an -import ibis.expr.operations as ops - - -class _LimitSpec(NamedTuple): - n: ops.Value | int | None - offset: ops.Value | int = 0 - - -class SelectBuilder: - """Transforms expression IR to a query pipeline. - - There will typically be a primary SELECT query, perhaps with some - subqueries and other DDL to ingest and tear down intermediate data sources. - - Walks the expression tree and catalogues distinct query units, - builds select statements (and other DDL types, where necessary), and - records relevant query unit aliases to be used when actually - generating SQL. - """ - - def to_select( - self, - select_class, - table_set_formatter_class, - node, - context, - translator_class, - ): - self.select_class = select_class - self.table_set_formatter_class = table_set_formatter_class - self.context = context - self.translator_class = translator_class - - self.op = node.to_expr().as_table().op() - assert isinstance(self.op, ops.Node), type(self.op) - - self.table_set = None - self.select_set = None - self.group_by = None - self.having = None - self.filters = [] - self.limit = None - self.order_by = [] - self.subqueries = [] - self.distinct = False - - select_query = self._build_result_query() - - self.queries = [select_query] - - return select_query - - def _build_result_query(self): - self._collect_elements() - self._analyze_subqueries() - self._populate_context() - - return self.select_class( - self.table_set, - list(self.select_set), - translator_class=self.translator_class, - table_set_formatter_class=self.table_set_formatter_class, - context=self.context, - subqueries=self.subqueries, - where=self.filters, - group_by=self.group_by, - having=self.having, - limit=self.limit, - order_by=self.order_by, - distinct=self.distinct, - parent_op=self.op, - ) - - def _populate_context(self): - # Populate aliases for the distinct relations used to output this - # select statement. - if self.table_set is not None: - self._make_table_aliases(self.table_set) - - # TODO(kszucs): should be rewritten using lin.traverse() - def _make_table_aliases(self, node): - ctx = self.context - - if isinstance(node, ops.JoinChain): - for arg in node.args: - if isinstance(arg, ops.Relation): - self._make_table_aliases(arg) - elif not ctx.is_extracted(node): - ctx.make_alias(node) - else: - # The compiler will apply a prefix only if the current context - # contains two or more table references. 
So, if we've extracted - # a subquery into a CTE, we need to propagate that reference - # down to child contexts so that they aren't missing any refs. - ctx.set_ref(node, ctx.top_context.get_ref(node)) - - # --------------------------------------------------------------------- - # Analysis of table set - - def _collect_elements(self): - # If expr is a Value, we must seek out the Tables that it - # references, build their ASTs, and mark them in our QueryContext - - # For now, we need to make the simplifying assumption that a value - # expression that is being translated only depends on a single table - # expression. - - if isinstance(self.op, ops.DummyTable): - self.select_set = list(self.op.values) - elif isinstance(self.op, ops.Relation): - self._collect(self.op, toplevel=True) - else: - self.select_set = [self.op] - - def _collect(self, op, toplevel=False): - method = f"_collect_{type(op).__name__}" - - if hasattr(self, method): - f = getattr(self, method) - f(op, toplevel=toplevel) - elif isinstance(op, (ops.PhysicalTable, ops.SQLQueryResult)): - self._collect_PhysicalTable(op, toplevel=toplevel) - elif isinstance(op, ops.JoinChain): - self._collect_Join(op, toplevel=toplevel) - elif isinstance(op, ops.WindowingTVF): - self._collect_WindowingTVF(op, toplevel=toplevel) - else: - raise NotImplementedError(type(op)) - - def _collect_Distinct(self, op, toplevel=False): - if toplevel: - self.distinct = True - - self._collect(op.table, toplevel=toplevel) - - def _collect_Limit(self, op, toplevel=False): - if toplevel: - if isinstance(table := op.parent, ops.Limit): - self.table_set = table - self.select_set = [table] - else: - self._collect(table, toplevel=toplevel) - - assert self.limit is None - self.limit = _LimitSpec(op.n, op.offset) - - def _collect_Sample(self, op, toplevel=False): - if toplevel: - self.table_set = op - self.select_set = [op] - - def _collect_Union(self, op, toplevel=False): - if toplevel: - self.table_set = op - self.select_set = [op] - - def _collect_Difference(self, op, toplevel=False): - if toplevel: - self.table_set = op - self.select_set = [op] - - def _collect_Intersection(self, op, toplevel=False): - if toplevel: - self.table_set = op - self.select_set = [op] - - def _collect_Aggregation(self, op, toplevel=False): - # The select set includes the grouping keys (if any), and these are - # duplicated in the group_by set. SQL translator can decide how to - # format these depending on the database. Most likely the - # GROUP BY 1, 2, ... 
style - if toplevel: - self.group_by = self._convert_group_by(op.by) - self.having = op.having - self.select_set = op.by + op.metrics - self.table_set = op.table - self.filters = op.predicates - self.order_by = op.sort_keys - - self._collect(op.table) - - def _collect_Project(self, op, toplevel=False): - table = op.parent - - if toplevel: - if isinstance(table, ops.JoinChain): - self._collect_Join(table) - else: - self._collect(table) - - selections = op.values - self.select_set = list(selections.values()) - self.table_set = table - - def _collect_InMemoryTable(self, node, toplevel=False): - if toplevel: - self.select_set = [node] - self.table_set = node - - def _convert_group_by(self, nodes): - return list(range(len(nodes))) - - def _collect_Join(self, op, toplevel=False): - if toplevel: - self.table_set = op - self.select_set = [op] - - def _collect_PhysicalTable(self, op, toplevel=False): - if toplevel: - self.select_set = [op] - self.table_set = op - - def _collect_DummyTable(self, op, toplevel=False): - if toplevel: - self.select_set = list(op.values) - self.table_set = None - - def _collect_SelfReference(self, op, toplevel=False): - if toplevel: - self._collect(op.table, toplevel=toplevel) - - def _collect_WindowingTVF(self, op, toplevel=False): - if toplevel: - self.table_set = op - self.select_set = [op] - - # -------------------------------------------------------------------- - # Subquery analysis / extraction - - def _analyze_subqueries(self): - # Somewhat temporary place for this. A little bit tricky, because - # subqueries can be found in many places - # - With the table set - # - Inside the where clause (these may be able to place directly, some - # cases not) - # - As support queries inside certain expressions (possibly needing to - # be extracted and joined into the table set where they are - # used). More complex transformations should probably not occur here, - # though. - # - # Duplicate subqueries might appear in different parts of the query - # structure, e.g. beneath two aggregates that are joined together, so - # we have to walk the entire query structure. - # - # The default behavior is to only extract into a WITH clause when a - # subquery appears multiple times (for DRY reasons). At some point we - # can implement a more aggressive policy so that subqueries always - # appear in the WITH part of the SELECT statement, if that's what you - # want. - - # Find the subqueries, and record them in the passed query context. - subqueries = an.find_subqueries( - [self.table_set, *self.filters], min_dependents=2 - ) - - self.subqueries = [] - for node in subqueries: - # See #173. Might have been extracted already in a parent context. - if not self.context.is_extracted(node): - self.subqueries.append(node) - self.context.set_extracted(node) diff --git a/ibis/backends/base/sql/compiler/translator.py b/ibis/backends/base/sql/compiler/translator.py deleted file mode 100644 index 2df24b4b93a4..000000000000 --- a/ibis/backends/base/sql/compiler/translator.py +++ /dev/null @@ -1,386 +0,0 @@ -from __future__ import annotations - -import contextlib -import itertools -from typing import TYPE_CHECKING, Callable - -import ibis -import ibis.common.exceptions as com -import ibis.expr.operations as ops -from ibis.backends.base.sql.registry import operation_registry, quote_identifier - -if TYPE_CHECKING: - from collections.abc import Iterable, Iterator - - -class QueryContext: - """Records bits of information used during ibis AST to SQL translation. 
- - Notably, table aliases (for subquery construction) and scalar query - parameters are tracked here. - """ - - def __init__(self, compiler, indent=2, parent=None, params=None): - self.compiler = compiler - self.table_refs = {} - self.extracted_subexprs = set() - self.subquery_memo = {} - self.indent = indent - self.parent = parent - self.always_alias = True - self.query = None - self.params = params if params is not None else {} - self._alias_counter = getattr(parent, "_alias_counter", 0) - - def _compile_subquery(self, op): - sub_ctx = self.subcontext() - return self._to_sql(op, sub_ctx) - - def _to_sql(self, expr, ctx): - return self.compiler.to_sql(expr, ctx) - - def collapse(self, queries: Iterable[str]) -> str: - """Turn an iterable of queries into something executable. - - Parameters - ---------- - queries - Iterable of query strings - - Returns - ------- - query - A single query string - - """ - return "\n\n".join(queries) - - @property - def top_context(self): - if self.parent is None: - return self - else: - return self.parent.top_context - - def set_always_alias(self): - self.always_alias = True - - def get_compiled_expr(self, node): - with contextlib.suppress(KeyError): - return self.top_context.subquery_memo[node] - - if isinstance(node, (ops.SQLQueryResult, ops.SQLStringView)): - result = node.query - else: - result = self._compile_subquery(node) - - self.top_context.subquery_memo[node] = result - return result - - def _next_alias(self) -> str: - alias = self._alias_counter - self._alias_counter += 1 - return f"t{alias:d}" - - def make_alias(self, node): - # check for existing tables that we're referencing from a parent context - for ctx in itertools.islice(self._contexts(), 1, None): - if (alias := ctx.table_refs.get(node)) is not None: - self.set_ref(node, alias) - return - - self.set_ref(node, self._next_alias()) - - def _contexts( - self, - *, - parents: bool = True, - ) -> Iterator[QueryContext]: - ctx = self - yield ctx - while parents and ctx.parent is not None: - ctx = ctx.parent - yield ctx - - def has_ref(self, node, parent_contexts=False): - return any( - node in ctx.table_refs for ctx in self._contexts(parents=parent_contexts) - ) - - def set_ref(self, node, alias): - self.table_refs[node] = alias - - def get_ref(self, node, search_parents=False): - """Return the alias used to refer to an expression.""" - assert isinstance(node, ops.Node), type(node) - - if (ref := self.table_refs.get(node)) is not None: - return ref - - if self.is_extracted(node): - return self.top_context.table_refs.get(node) - - if search_parents and (parent := self.parent) is not None: - return parent.get_ref(node, search_parents=search_parents) - - return None - - def is_extracted(self, node): - return node in self.top_context.extracted_subexprs - - def set_extracted(self, node): - self.extracted_subexprs.add(node) - self.make_alias(node) - - def subcontext(self): - return self.__class__( - compiler=self.compiler, - indent=self.indent, - parent=self, - params=self.params, - ) - - # Maybe temporary hacks for correlated / uncorrelated subqueries - - def set_query(self, query): - self.query = query - - def is_foreign_expr(self, op): - from ibis.expr.analysis import shares_all_roots - - # The expression isn't foreign to us. 
For example, the parent table set - # in a correlated WHERE subquery - if self.has_ref(op, parent_contexts=True): - return False - - parents = [self.query.table_set] + self.query.select_set - return not shares_all_roots(op, parents) - - -class ExprTranslator: - """Translates ibis expressions into a compilation target.""" - - _registry = operation_registry - _rewrites: dict[ops.Node, Callable] = {} - _forbids_frame_clause = ( - ops.DenseRank, - ops.MinRank, - ops.NTile, - ops.PercentRank, - ops.CumeDist, - ops.RowNumber, - ) - _require_order_by = ( - ops.Lag, - ops.Lead, - ops.DenseRank, - ops.MinRank, - ops.FirstValue, - ops.LastValue, - ops.PercentRank, - ops.CumeDist, - ops.NTile, - ) - _unsupported_reductions = ( - ops.ApproxMedian, - ops.GroupConcat, - ops.ApproxCountDistinct, - ) - _dialect_name = "hive" - _quote_identifiers = None - _bool_aggs_need_cast_to_int32 = False - - def __init__( - self, node, context, named=False, permit_subquery=False, within_where=False - ): - self.node = node - self.permit_subquery = permit_subquery - - assert context is not None, f"context is None in {type(self).__name__}" - self.context = context - - # For now, governing whether the result will have a name - self.named = named - - # used to indicate whether the expression being rendered is within a - # WHERE clause. This is used for MSSQL to determine whether to use - # boolean expressions or not. - self.within_where = within_where - - def _needs_name(self, op): - if not self.named: - return False - - if isinstance(op, ops.TableColumn): - # This column has been given an explicitly different name - return False - - return bool(op.name) - - def name(self, translated, name, force=True): - return f"{translated} AS {quote_identifier(name, force=force)}" - - def get_result(self): - """Compile SQL expression into a string.""" - translated = self.translate(self.node) - if self._needs_name(self.node): - # TODO: this could fail in various ways - translated = self.name(translated, self.node.name) - return translated - - @classmethod - def add_operation(cls, operation, translate_function): - """Add an operation to the operation registry. - - In general, operations should be defined directly in the registry, in - `registry.py`. There are couple of exceptions why this is needed. - - Operations defined by Ibis users (not Ibis or backend developers), and - UDFs which are added dynamically. 
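A sketch of what registering such a rule looked like. The operation class here is a hypothetical stand-in; translation rules received the translator and the operation node and returned a SQL fragment:

```python
# Assumes the pre-change import path that this PR deletes.
from ibis.backends.base.sql.compiler.translator import ExprTranslator


class MyReverseString:  # stand-in for a real ops.Value subclass
    pass


def _reverse_string(translator, op):
    # Render the argument, then wrap it in the target SQL function.
    return f"reverse({translator.translate(op.arg)})"


ExprTranslator.add_operation(MyReverseString, _reverse_string)
```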
- """ - cls._registry[operation] = translate_function - - def translate(self, op): - assert isinstance(op, ops.Node), type(op) - - if type(op) in self._rewrites: # even if type(op) is in self._registry - op = self._rewrites[type(op)](op) - - # TODO: use op MRO for subclasses instead of this isinstance spaghetti - if isinstance(op, ops.ScalarParameter): - return self._trans_param(op) - elif isinstance(op, ops.TableNode): - # HACK/TODO: revisit for more complex cases - return "*" - elif type(op) in self._registry: - formatter = self._registry[type(op)] - return formatter(self, op) - else: - raise com.OperationNotDefinedError(f"No translation rule for {type(op)}") - - def _trans_param(self, op): - raw_value = self.context.params[op] - dtype = op.dtype - if dtype.is_struct(): - literal = ibis.struct(raw_value, type=dtype) - elif dtype.is_map(): - literal = ibis.map(list(raw_value.keys()), list(raw_value.values())) - else: - literal = ibis.literal(raw_value, type=dtype) - return self.translate(literal.op()) - - @classmethod - def rewrites(cls, klass): - def decorator(f): - cls._rewrites[klass] = f - return f - - return decorator - - -rewrites = ExprTranslator.rewrites - - -@rewrites(ops.Bucket) -def _bucket(op): - # TODO(kszucs): avoid the expression roundtrip - expr = op.arg.to_expr() - stmt = ibis.case() - - if op.closed == "left": - l_cmp = ops.LessEqual - r_cmp = ops.Less - else: - l_cmp = ops.Less - r_cmp = ops.LessEqual - - user_num_buckets = len(op.buckets) - 1 - - bucket_id = 0 - if op.include_under: - if user_num_buckets > 0: - cmp = ops.Less if op.close_extreme else r_cmp - else: - cmp = ops.LessEqual if op.closed == "right" else ops.Less - stmt = stmt.when(cmp(op.arg, op.buckets[0]).to_expr(), bucket_id) - bucket_id += 1 - - for j, (lower, upper) in enumerate(zip(op.buckets, op.buckets[1:])): - if op.close_extreme and ( - (op.closed == "right" and j == 0) - or (op.closed == "left" and j == (user_num_buckets - 1)) - ): - stmt = stmt.when( - ops.And( - ops.LessEqual(lower, op.arg), ops.LessEqual(op.arg, upper) - ).to_expr(), - bucket_id, - ) - else: - stmt = stmt.when( - ops.And(l_cmp(lower, op.arg), r_cmp(op.arg, upper)).to_expr(), - bucket_id, - ) - bucket_id += 1 - - if op.include_over: - if user_num_buckets > 0: - cmp = ops.Less if op.close_extreme else l_cmp - else: - cmp = ops.Less if op.closed == "right" else ops.LessEqual - - stmt = stmt.when(cmp(op.buckets[-1], op.arg).to_expr(), bucket_id) - bucket_id += 1 - - result = stmt.end() - if expr.has_name(): - result = result.name(expr.get_name()) - - return result.op() - - -@rewrites(ops.Any) -def _any_expand(op): - return ops.Max(op.arg, where=op.where) - - -@rewrites(ops.All) -def _all_expand(op): - return ops.Min(op.arg, where=op.where) - - -@rewrites(ops.Cast) -def _rewrite_cast(op): - # TODO(kszucs): avoid the expression roundtrip - if op.to.is_interval() and op.arg.dtype.is_integer(): - return op.arg.to_expr().to_interval(unit=op.to.unit).op() - return op - - -@rewrites(ops.StringContains) -def _rewrite_string_contains(op): - return ops.GreaterEqual(ops.StringFind(op.haystack, op.needle), 0) - - -@rewrites(ops.Clip) -def _rewrite_clip(op): - dtype = op.dtype - arg = ops.Cast(op.arg, dtype) - - arg_is_null = ops.IsNull(arg) - - if (upper := op.upper) is not None: - clipped_lower = ops.Least((arg, ops.Cast(upper, dtype))) - if dtype.nullable: - arg = ops.IfElse(arg_is_null, arg, clipped_lower) - else: - arg = clipped_lower - - if (lower := op.lower) is not None: - clipped_upper = ops.Greatest((arg, ops.Cast(lower, dtype))) - if 
dtype.nullable: - arg = ops.IfElse(arg_is_null, arg, clipped_upper) - else: - arg = clipped_upper - - return arg diff --git a/ibis/backends/base/sql/ddl.py b/ibis/backends/base/sql/ddl.py deleted file mode 100644 index 0f88a93a1df6..000000000000 --- a/ibis/backends/base/sql/ddl.py +++ /dev/null @@ -1,469 +0,0 @@ -from __future__ import annotations - -import re - -import sqlglot as sg - -import ibis.expr.datatypes as dt -import ibis.expr.schema as sch -from ibis.backends.base.sql.compiler import DDL, DML -from ibis.backends.base.sql.registry import quote_identifier, type_to_sql_string - -fully_qualified_re = re.compile(r"(.*)\.(?:`(.*)`|(.*))") -_format_aliases = {"TEXT": "TEXTFILE"} - - -def _sanitize_format(format): - if format is None: - return None - format = format.upper() - format = _format_aliases.get(format, format) - if format not in ("PARQUET", "AVRO", "TEXTFILE"): - raise ValueError(f"Invalid format: {format!r}") - - return format - - -def is_fully_qualified(x): - return bool(fully_qualified_re.search(x)) - - -def _is_quoted(x): - regex = re.compile(r"(?:`(.*)`|(.*))") - quoted, _ = regex.match(x).groups() - return quoted is not None - - -def format_schema(schema): - elements = [ - _format_schema_element(name, t) for name, t in zip(schema.names, schema.types) - ] - return "({})".format(",\n ".join(elements)) - - -def _format_schema_element(name, t): - return f"{quote_identifier(name, force=True)} {type_to_sql_string(t)}" - - -def _format_partition_kv(k, v, type): - if type == dt.string: - value_formatted = f'"{v}"' - else: - value_formatted = str(v) - - return f"{k}={value_formatted}" - - -def format_partition(partition, partition_schema): - tokens = [] - if isinstance(partition, dict): - for name in partition_schema: - if name in partition: - tok = _format_partition_kv( - name, partition[name], partition_schema[name] - ) - else: - # dynamic partitioning - tok = name - tokens.append(tok) - else: - for name, value in zip(partition_schema, partition): - tok = _format_partition_kv(name, value, partition_schema[name]) - tokens.append(tok) - - return "PARTITION ({})".format(", ".join(tokens)) - - -def _format_properties(props): - tokens = [] - for k, v in sorted(props.items()): - tokens.append(f" '{k}'='{v}'") - - return "(\n{}\n)".format(",\n".join(tokens)) - - -def format_tblproperties(props): - formatted_props = _format_properties(props) - return f"TBLPROPERTIES {formatted_props}" - - -def _serdeproperties(props): - formatted_props = _format_properties(props) - return f"SERDEPROPERTIES {formatted_props}" - - -class _BaseQualifiedSQLStatement: - def _get_scoped_name(self, obj_name, database): - if is_fully_qualified(obj_name): - return obj_name - if _is_quoted(obj_name): - obj_name = obj_name[1:-1] - return sg.table(obj_name, db=database, quoted=True).sql(dialect="hive") - - -class BaseDDL(DDL, _BaseQualifiedSQLStatement): - pass - - -class _BaseDML(DML, _BaseQualifiedSQLStatement): - pass - - -class _CreateDDL(BaseDDL): - def _if_exists(self): - return "IF NOT EXISTS " if self.can_exist else "" - - -class CreateTable(_CreateDDL): - def __init__( - self, - table_name, - database=None, - external=False, - format="parquet", - can_exist=False, - partition=None, - path=None, - tbl_properties=None, - ): - self.table_name = table_name - self.database = database - self.partition = partition - self.path = path - self.external = external - self.can_exist = can_exist - self.format = _sanitize_format(format) - self.tbl_properties = tbl_properties - - @property - def _prefix(self): - if 
self.external: - return "CREATE EXTERNAL TABLE" - else: - return "CREATE TABLE" - - def _create_line(self): - scoped_name = self._get_scoped_name(self.table_name, self.database) - return f"{self._prefix} {self._if_exists()}{scoped_name}" - - def _location(self): - return f"LOCATION '{self.path}'" if self.path else None - - def _storage(self): - # By the time we're here, we have a valid format - return f"STORED AS {self.format}" - - @property - def pieces(self): - yield self._create_line() - yield from filter(None, self._pieces) - - def compile(self): - return "\n".join(self.pieces) - - -class CTAS(CreateTable): - """Create Table As Select.""" - - def __init__( - self, - table_name, - select, - database=None, - external=False, - format="parquet", - can_exist=False, - path=None, - partition=None, - ): - super().__init__( - table_name, - database=database, - external=external, - format=format, - can_exist=can_exist, - path=path, - partition=partition, - ) - self.select = select - - @property - def _pieces(self): - yield self._partitioned_by() - yield self._storage() - yield self._location() - yield "AS" - yield self.select - - def _partitioned_by(self): - if self.partition is not None: - return "PARTITIONED BY ({})".format( - ", ".join(quote_identifier(expr.get_name()) for expr in self.partition) - ) - return None - - -class CreateView(CTAS): - """Create a view.""" - - def __init__(self, table_name, select, database=None, can_exist=False): - super().__init__(table_name, select, database=database, can_exist=can_exist) - - @property - def _pieces(self): - yield "AS" - yield self.select - - @property - def _prefix(self): - return "CREATE VIEW" - - -class CreateTableWithSchema(CreateTable): - def __init__(self, table_name, schema, table_format=None, **kwargs): - super().__init__(table_name, **kwargs) - self.schema = schema - self.table_format = table_format - - @property - def _pieces(self): - if self.partition is not None: - main_schema = self.schema - part_schema = self.partition - if not isinstance(part_schema, sch.Schema): - part_fields = {name: self.schema[name] for name in part_schema} - part_schema = sch.Schema(part_fields) - - to_delete = {name for name in self.partition if name in self.schema} - fields = { - name: dtype - for name, dtype in main_schema.items() - if name not in to_delete - } - main_schema = sch.Schema(fields) - - yield format_schema(main_schema) - yield f"PARTITIONED BY {format_schema(part_schema)}" - else: - yield format_schema(self.schema) - - if self.table_format is not None: - yield "\n".join(self.table_format.to_ddl()) - else: - yield self._storage() - - yield self._location() - - -class CreateDatabase(_CreateDDL): - def __init__(self, name, path=None, can_exist=False): - self.name = name - self.path = path - self.can_exist = can_exist - - def compile(self): - name = quote_identifier(self.name) - - create_decl = "CREATE DATABASE" - create_line = f"{create_decl} {self._if_exists()}{name}" - if self.path is not None: - create_line += f"\nLOCATION '{self.path}'" - - return create_line - - -class DropObject(BaseDDL): - def __init__(self, must_exist=True): - self.must_exist = must_exist - - def compile(self): - if_exists = "" if self.must_exist else "IF EXISTS " - object_name = self._object_name() - return f"DROP {self._object_type} {if_exists}{object_name}" - - -class DropDatabase(DropObject): - _object_type = "DATABASE" - - def __init__(self, name, must_exist=True): - super().__init__(must_exist=must_exist) - self.name = name - - def _object_name(self): - return 
self.name - - -class DropTable(DropObject): - _object_type = "TABLE" - - def __init__(self, table_name, database=None, must_exist=True): - super().__init__(must_exist=must_exist) - self.table_name = table_name - self.database = database - - def _object_name(self): - return self._get_scoped_name(self.table_name, self.database) - - -class DropView(DropTable): - _object_type = "VIEW" - - -class TruncateTable(BaseDDL): - _object_type = "TABLE" - - def __init__(self, table_name, database=None): - self.table_name = table_name - self.database = database - - def compile(self): - name = self._get_scoped_name(self.table_name, self.database) - return f"TRUNCATE TABLE {name}" - - -class InsertSelect(_BaseDML): - def __init__( - self, - table_name, - select_expr, - database=None, - partition=None, - partition_schema=None, - overwrite=False, - ): - self.table_name = table_name - self.database = database - self.select = select_expr - - self.partition = partition - self.partition_schema = partition_schema - - self.overwrite = overwrite - - def compile(self): - if self.overwrite: - cmd = "INSERT OVERWRITE" - else: - cmd = "INSERT INTO" - - if self.partition is not None: - part = format_partition(self.partition, self.partition_schema) - partition = f" {part} " - else: - partition = "" - - select_query = self.select - scoped_name = self._get_scoped_name(self.table_name, self.database) - return f"{cmd} {scoped_name}{partition}\n{select_query}" - - -class AlterTable(BaseDDL): - def __init__( - self, - table, - location=None, - format=None, - tbl_properties=None, - serde_properties=None, - ): - self.table = table - self.location = location - self.format = _sanitize_format(format) - self.tbl_properties = tbl_properties - self.serde_properties = serde_properties - - def _wrap_command(self, cmd): - return f"ALTER TABLE {cmd}" - - def _format_properties(self, prefix=""): - tokens = [] - - if self.location is not None: - tokens.append(f"LOCATION '{self.location}'") - - if self.format is not None: - tokens.append(f"FILEFORMAT {self.format}") - - if self.tbl_properties is not None: - tokens.append(format_tblproperties(self.tbl_properties)) - - if self.serde_properties is not None: - tokens.append(_serdeproperties(self.serde_properties)) - - if len(tokens) > 0: - return "\n{}{}".format(prefix, "\n".join(tokens)) - else: - return "" - - def compile(self): - props = self._format_properties() - action = f"{self.table} SET {props}" - return self._wrap_command(action) - - -class DropFunction(DropObject): - def __init__(self, name, inputs, must_exist=True, aggregate=False, database=None): - super().__init__(must_exist=must_exist) - self.name = name - self.inputs = tuple(map(dt.dtype, inputs)) - self.must_exist = must_exist - self.aggregate = aggregate - self.database = database - - def _object_name(self): - return self.name - - def compile(self): - tokens = ["DROP"] - if self.aggregate: - tokens.append("AGGREGATE") - tokens.append("FUNCTION") - if not self.must_exist: - tokens.append("IF EXISTS") - - tokens.append(self._impala_signature()) - return " ".join(tokens) - - -class RenameTable(AlterTable): - def __init__( - self, - old_name: str, - new_name: str, - old_database: str | None = None, - new_database: str | None = None, - dialect: str = "hive", - ): - self._old = sg.table(old_name, db=old_database, quoted=True).sql( - dialect=dialect - ) - self._new = sg.table(new_name, db=new_database, quoted=True).sql( - dialect=dialect - ) - - def compile(self): - return self._wrap_command(f"{self._old} RENAME TO {self._new}") - - 
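
For context, here is a minimal sketch of what the string-building classes above produced; it is not part of this diff, the table, database, and column names are invented, and the expected output is inferred from `format_schema`, `type_to_sql_string`, and `_sanitize_format` as defined in this file, so treat it as approximate:

```python
# Sketch only: exercises the (now deleted) string-building DDL classes.
# "events", "analytics", and the column names are hypothetical.
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.backends.base.sql.ddl import CreateTableWithSchema

stmt = CreateTableWithSchema(
    "events",
    sch.Schema({"id": dt.int64, "name": dt.string}),
    database="analytics",
    can_exist=True,  # emits IF NOT EXISTS
)
print(stmt.compile())
# Expected, approximately:
# CREATE TABLE IF NOT EXISTS `analytics`.`events`
# (`id` bigint,
#  `name` string)
# STORED AS PARQUET
```

The sqlglot-based replacements introduced later in this diff (`ibis/backends/base/sqlglot/ddl.py` plus the per-backend `ddl.py` modules) keep this compose-the-pieces design but route identifier quoting and type formatting through a per-backend sqlglot dialect.
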
-__all__ = ( - "fully_qualified_re", - "is_fully_qualified", - "format_schema", - "format_partition", - "format_tblproperties", - "BaseDDL", - "CreateTable", - "CTAS", - "CreateView", - "CreateTableWithSchema", - "CreateDatabase", - "DropObject", - "DropDatabase", - "DropTable", - "DropView", - "TruncateTable", - "InsertSelect", - "AlterTable", - "DropFunction", - "RenameTable", -) diff --git a/ibis/backends/base/sql/registry/__init__.py b/ibis/backends/base/sql/registry/__init__.py deleted file mode 100644 index 2a188bf989c8..000000000000 --- a/ibis/backends/base/sql/registry/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -from ibis.backends.base.sql.registry.aggregate import reduction -from ibis.backends.base.sql.registry.helpers import ( - quote_identifier, - sql_type_names, - type_to_sql_string, -) -from ibis.backends.base.sql.registry.literal import literal, literal_formatters -from ibis.backends.base.sql.registry.main import ( - binary_infix_ops, - fixed_arity, - operation_registry, - unary, -) -from ibis.backends.base.sql.registry.window import ( - format_window_frame, - time_range_to_range_window, -) - -__all__ = ( - "quote_identifier", - "operation_registry", - "binary_infix_ops", - "fixed_arity", - "literal", - "literal_formatters", - "sql_type_names", - "type_to_sql_string", - "reduction", - "unary", - "format_window_frame", - "time_range_to_range_window", -) diff --git a/ibis/backends/base/sql/registry/aggregate.py b/ibis/backends/base/sql/registry/aggregate.py deleted file mode 100644 index 23e1ffd96647..000000000000 --- a/ibis/backends/base/sql/registry/aggregate.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -import ibis -import ibis.expr.datatypes as dt -import ibis.expr.operations as ops - - -def _maybe_cast_bool(translator, op, arg): - if ( - translator._bool_aggs_need_cast_to_int32 - and isinstance(op, (ops.Sum, ops.Mean, ops.Min, ops.Max)) - and (dtype := arg.dtype).is_boolean() - ): - return ops.Cast(arg, dt.Int32(nullable=dtype.nullable)) - return arg - - -def _reduction_format(translator, op, func_name, where, *args): - args = ( - _maybe_cast_bool(translator, op, arg) - for arg in args - if isinstance(arg, ops.Node) - ) - if where is not None: - args = (ops.IfElse(where, arg, ibis.NA) for arg in args) - - return "{}({})".format( - func_name, - ", ".join(map(translator.translate, args)), - ) - - -def reduction(func_name): - def formatter(translator, op): - *args, where = op.args - return _reduction_format(translator, op, func_name, where, *args) - - return formatter - - -def variance_like(func_name): - func_names = { - "sample": f"{func_name}_samp", - "pop": f"{func_name}_pop", - } - - def formatter(translator, op): - return _reduction_format(translator, op, func_names[op.how], op.where, op.arg) - - return formatter - - -def count_distinct(translator, op): - if op.where is not None: - arg_formatted = translator.translate(ops.IfElse(op.where, op.arg, None)) - else: - arg_formatted = translator.translate(op.arg) - return f"count(DISTINCT {arg_formatted})" diff --git a/ibis/backends/base/sql/registry/binary_infix.py b/ibis/backends/base/sql/registry/binary_infix.py deleted file mode 100644 index 388141e0b90c..000000000000 --- a/ibis/backends/base/sql/registry/binary_infix.py +++ /dev/null @@ -1,88 +0,0 @@ -from __future__ import annotations - -import ibis.expr.analysis as an -from ibis.backends.base.sql.registry import helpers - - -def binary_infix_op(infix_sym): - def formatter(translator, op): - left, right = 
op.args - - left_arg = translator.translate(left) - right_arg = translator.translate(right) - if helpers.needs_parens(left): - left_arg = helpers.parenthesize(left_arg) - - if helpers.needs_parens(right): - right_arg = helpers.parenthesize(right_arg) - - return f"{left_arg} {infix_sym} {right_arg}" - - return formatter - - -def identical_to(translator, op): - if op.args[0].equals(op.args[1]): - return "TRUE" - - left = translator.translate(op.left) - right = translator.translate(op.right) - - if helpers.needs_parens(op.left): - left = helpers.parenthesize(left) - if helpers.needs_parens(op.right): - right = helpers.parenthesize(right) - return f"{left} IS NOT DISTINCT FROM {right}" - - -def xor(translator, op): - left_arg = translator.translate(op.left) - right_arg = translator.translate(op.right) - - if helpers.needs_parens(op.left): - left_arg = helpers.parenthesize(left_arg) - - if helpers.needs_parens(op.right): - right_arg = helpers.parenthesize(right_arg) - - return f"({left_arg} OR {right_arg}) AND NOT ({left_arg} AND {right_arg})" - - -def in_values(translator, op): - if not op.options: - return "FALSE" - - left = translator.translate(op.value) - if helpers.needs_parens(op.value): - left = helpers.parenthesize(left) - - values = [translator.translate(x) for x in op.options] - right = helpers.parenthesize(", ".join(values)) - - # we explicitly do NOT parenthesize the right side because it doesn't - # make sense to do so for Sequence operations - return f"{left} IN {right}" - - -def in_column(translator, op): - from ibis.backends.base.sql.registry.main import table_array_view - - ctx = translator.context - - left = translator.translate(op.value) - if helpers.needs_parens(op.value): - left = helpers.parenthesize(left) - - right = translator.translate(op.options) - if not any( - ctx.is_foreign_expr(leaf) - for leaf in an.find_immediate_parent_tables(op.options) - ): - array = op.options.to_expr().as_table().to_array().op() - right = table_array_view(translator, array) - else: - right = translator.translate(op.options) - - # we explicitly do NOT parenthesize the right side because it doesn't - # make sense to do so for Sequence operations - return f"{left} IN {right}" diff --git a/ibis/backends/base/sql/registry/case.py b/ibis/backends/base/sql/registry/case.py deleted file mode 100644 index da7bbb5e0b12..000000000000 --- a/ibis/backends/base/sql/registry/case.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from io import StringIO - - -class _CaseFormatter: - def __init__(self, translator, base, cases, results, default): - self.translator = translator - self.base = base - self.cases = cases - self.results = results - self.default = default - - # HACK - self.indent = 2 - self.multiline = len(cases) > 1 - self.buf = StringIO() - - def get_result(self): - self.buf.seek(0) - - self.buf.write("CASE") - if self.base is not None: - base_str = self.translator.translate(self.base) - self.buf.write(f" {base_str}") - - for case, result in zip(self.cases, self.results): - self._next_case() - case_str = self.translator.translate(case) - result_str = self.translator.translate(result) - self.buf.write(f"WHEN {case_str} THEN {result_str}") - - if self.default is not None: - self._next_case() - default_str = self.translator.translate(self.default) - self.buf.write(f"ELSE {default_str}") - - if self.multiline: - self.buf.write("\nEND") - else: - self.buf.write(" END") - - return self.buf.getvalue() - - def _next_case(self): - if self.multiline: - self.buf.write("\n{}".format(" " 
* self.indent)) - else: - self.buf.write(" ") - - -def simple_case(translator, op): - formatter = _CaseFormatter(translator, op.base, op.cases, op.results, op.default) - return formatter.get_result() - - -def searched_case(translator, op): - formatter = _CaseFormatter(translator, None, op.cases, op.results, op.default) - return formatter.get_result() diff --git a/ibis/backends/base/sql/registry/helpers.py b/ibis/backends/base/sql/registry/helpers.py deleted file mode 100644 index ed5b4fa170b9..000000000000 --- a/ibis/backends/base/sql/registry/helpers.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import annotations - -import ibis.common.exceptions as com -import ibis.expr.operations as ops -from ibis.backends.base.sql.registry import identifiers - - -def format_call(translator, func, *args): - formatted_args = [] - for arg in args: - fmt_arg = translator.translate(arg) - formatted_args.append(fmt_arg) - - return "{}({})".format(func, ", ".join(formatted_args)) - - -def quote_identifier(name, quotechar="`", force=False): - """Add quotes to the `name` identifier if needed.""" - if force or name.count(" ") or name in identifiers.base_identifiers: - return f"{quotechar}{name}{quotechar}" - else: - return name - - -_NEEDS_PARENS_OPS = ( - ops.Negate, - ops.IsNull, - ops.NotNull, - ops.Add, - ops.Subtract, - ops.Multiply, - ops.Divide, - ops.Power, - ops.Modulus, - ops.Equals, - ops.NotEquals, - ops.GreaterEqual, - ops.Greater, - ops.LessEqual, - ops.Less, - ops.IdenticalTo, - ops.And, - ops.Or, - ops.Xor, -) - - -def needs_parens(op: ops.Node): - if isinstance(op, ops.Alias): - op = op.arg - return isinstance(op, _NEEDS_PARENS_OPS) - - -parenthesize = "({})".format - - -sql_type_names = { - "int8": "tinyint", - "int16": "smallint", - "int32": "int", - "int64": "bigint", - "float": "float", - "float32": "float", - "double": "double", - "float64": "double", - "string": "string", - "boolean": "boolean", - "timestamp": "timestamp", - "decimal": "decimal", - "date": "date", -} - - -def type_to_sql_string(tval): - if tval.is_decimal(): - return f"decimal({tval.precision}, {tval.scale})" - name = tval.name.lower() - try: - return sql_type_names[name] - except KeyError: - raise com.UnsupportedBackendType(name) diff --git a/ibis/backends/base/sql/registry/identifiers.py b/ibis/backends/base/sql/registry/identifiers.py deleted file mode 100644 index 1129a3176698..000000000000 --- a/ibis/backends/base/sql/registry/identifiers.py +++ /dev/null @@ -1,149 +0,0 @@ -from __future__ import annotations - -base_identifiers = [ - "add", - "aggregate", - "all", - "alter", - "and", - "api_version", - "as", - "asc", - "avro", - "between", - "bigint", - "binary", - "boolean", - "by", - "cached", - "case", - "cast", - "change", - "char", - "class", - "close_fn", - "column", - "columns", - "comment", - "compute", - "create", - "cross", - "data", - "database", - "databases", - "date", - "datetime", - "decimal", - "delimited", - "desc", - "describe", - "distinct", - "div", - "double", - "drop", - "else", - "end", - "escaped", - "exists", - "explain", - "external", - "false", - "fields", - "fileformat", - "finalize_fn", - "first", - "float", - "format", - "formatted", - "from", - "full", - "function", - "functions", - "group", - "having", - "if", - "in", - "incremental", - "init_fn", - "inner", - "inpath", - "insert", - "int", - "integer", - "intermediate", - "interval", - "into", - "invalidate", - "is", - "join", - "last", - "left", - "like", - "limit", - "lines", - "load", - "location", - "merge_fn", - "metadata", - 
"not", - "null", - "nulls", - "offset", - "on", - "or", - "order", - "outer", - "overwrite", - "parquet", - "parquetfile", - "partition", - "partitioned", - "partitions", - "prepare_fn", - "produced", - "rcfile", - "real", - "refresh", - "regexp", - "rename", - "replace", - "returns", - "right", - "rlike", - "row", - "schema", - "schemas", - "select", - "semi", - "sequencefile", - "serdeproperties", - "serialize_fn", - "set", - "show", - "smallint", - "stats", - "stored", - "straight_join", - "string", - "symbol", - "table", - "tables", - "tblproperties", - "terminated", - "textfile", - "then", - "timestamp", - "tinyint", - "to", - "true", - "uncached", - "union", - "update_fn", - "use", - "using", - "values", - "view", - "when", - "where", - "with", -] diff --git a/ibis/backends/base/sql/registry/literal.py b/ibis/backends/base/sql/registry/literal.py deleted file mode 100644 index b31aec6fac45..000000000000 --- a/ibis/backends/base/sql/registry/literal.py +++ /dev/null @@ -1,112 +0,0 @@ -from __future__ import annotations - -import datetime -import math - -import ibis.expr.types as ir - - -def _set_literal_format(translator, expr): - value_type = expr.type().value_type - - formatted = [ - translator.translate(ir.literal(x, type=value_type)) for x in expr.op().value - ] - - return "(" + ", ".join(formatted) + ")" - - -def _boolean_literal_format(translator, op): - return "TRUE" if op.value else "FALSE" - - -def _string_literal_format(translator, op): - return "'{}'".format( - op.value - # Escape \ first so we don't double escape other characters. - .replace("\\", "\\\\") - # Escape ' since we're using those for the string literal. - .replace("'", "\\'") - # ASCII escape sequences that are recognized in Python: - # https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals - .replace("\a", "\\a") # Bell - .replace("\b", "\\b") # Backspace - .replace("\f", "\\f") # Formfeed - .replace("\n", "\\n") # Newline / Linefeed - .replace("\r", "\\r") # Carriage return - .replace("\t", "\\t") # Tab - .replace("\v", "\\v") # Vertical tab - ) - - -def _number_literal_format(translator, op): - if math.isfinite(op.value): - formatted = repr(op.value) - else: - if math.isnan(op.value): - formatted_val = "NaN" - elif math.isinf(op.value): - if op.value > 0: - formatted_val = "Infinity" - else: - formatted_val = "-Infinity" - formatted = f"CAST({formatted_val!r} AS DOUBLE)" - - return formatted - - -def _interval_literal_format(translator, op): - return f"INTERVAL {op.value} {op.dtype.resolution.upper()}" - - -def _date_literal_format(translator, op): - value = op.value - if isinstance(value, datetime.date): - value = value.isoformat() - - return repr(value) - - -def _timestamp_literal_format(translator, op): - value = op.value - if isinstance(value, datetime.datetime): - value = value.isoformat() - - return repr(value) - - -literal_formatters = { - "boolean": _boolean_literal_format, - "number": _number_literal_format, - "string": _string_literal_format, - "interval": _interval_literal_format, - "timestamp": _timestamp_literal_format, - "date": _date_literal_format, - "set": _set_literal_format, -} - - -def literal(translator, op): - """Return the expression as its literal value.""" - - dtype = op.dtype - - if op.value is None: - return "NULL" - - if dtype.is_boolean(): - typeclass = "boolean" - elif dtype.is_string() or dtype.is_inet() or dtype.is_macaddr(): - typeclass = "string" - elif dtype.is_date(): - typeclass = "date" - elif dtype.is_numeric(): - typeclass = "number" - elif 
dtype.is_timestamp(): - typeclass = "timestamp" - elif dtype.is_interval(): - typeclass = "interval" - else: - raise NotImplementedError(f"Unsupported type: {dtype!r}") - - return literal_formatters[typeclass](translator, op) diff --git a/ibis/backends/base/sql/registry/main.py b/ibis/backends/base/sql/registry/main.py deleted file mode 100644 index 96d74fa83013..000000000000 --- a/ibis/backends/base/sql/registry/main.py +++ /dev/null @@ -1,386 +0,0 @@ -from __future__ import annotations - -import ibis.common.exceptions as com -import ibis.expr.datatypes as dt -import ibis.expr.operations as ops -import ibis.expr.types as ir -from ibis import util -from ibis.backends.base.sql.registry import ( - aggregate, - binary_infix, - case, - helpers, - string, - timestamp, - window, -) -from ibis.backends.base.sql.registry.literal import literal - - -def alias(translator, op): - # just compile the underlying argument because the naming is handled - # by the translator for the top level expression - return translator.translate(op.arg) - - -def fixed_arity(func_name, arity): - def formatter(translator, op): - if arity != len(op.args): - raise com.IbisError("incorrect number of args") - return helpers.format_call(translator, func_name, *op.args) - - return formatter - - -def unary(func_name): - return fixed_arity(func_name, 1) - - -def not_null(translator, op): - formatted_arg = translator.translate(op.arg) - return f"{formatted_arg} IS NOT NULL" - - -def is_null(translator, op): - formatted_arg = translator.translate(op.arg) - return f"{formatted_arg} IS NULL" - - -def not_(translator, op): - formatted_arg = translator.translate(op.arg) - if helpers.needs_parens(op.arg): - formatted_arg = helpers.parenthesize(formatted_arg) - return f"NOT {formatted_arg}" - - -def negate(translator, op): - arg = op.args[0] - formatted_arg = translator.translate(arg) - if op.dtype.is_boolean(): - return not_(translator, op) - else: - if helpers.needs_parens(arg): - formatted_arg = helpers.parenthesize(formatted_arg) - return f"-{formatted_arg}" - - -def sign(translator, op): - translated_arg = translator.translate(op.arg) - dtype = op.dtype - translated_type = helpers.type_to_sql_string(dtype) - if not dtype.is_float32(): - return f"CAST(sign({translated_arg}) AS {translated_type})" - return f"sign({translated_arg})" - - -def hashbytes(translator, op): - how = op.how - - arg_formatted = translator.translate(op.arg) - - if how == "md5": - return f"md5({arg_formatted})" - elif how == "sha1": - return f"sha1({arg_formatted})" - elif how == "sha256": - return f"sha256({arg_formatted})" - elif how == "sha512": - return f"sha512({arg_formatted})" - else: - raise NotImplementedError(how) - - -def log(translator, op): - arg_formatted = translator.translate(op.arg) - - if op.base is None: - return f"ln({arg_formatted})" - - base_formatted = translator.translate(op.base) - return f"log({base_formatted}, {arg_formatted})" - - -def cast(translator, op): - arg_formatted = translator.translate(op.arg) - - if op.arg.dtype.is_temporal() and op.to.is_int64(): - return f"1000000 * unix_timestamp({arg_formatted})" - else: - sql_type = helpers.type_to_sql_string(op.to) - return f"CAST({arg_formatted} AS {sql_type})" - - -def varargs(func_name): - def varargs_formatter(translator, op): - return helpers.format_call(translator, func_name, *op.arg) - - return varargs_formatter - - -def between(translator, op): - comp = translator.translate(op.arg) - lower = translator.translate(op.lower_bound) - upper = translator.translate(op.upper_bound) - 
return f"{comp} BETWEEN {lower} AND {upper}" - - -def table_array_view(translator, op): - ctx = translator.context - query = ctx.get_compiled_expr(op.table) - return f"(\n{util.indent(query, ctx.indent)}\n)" - - -def table_column(translator, op): - quoted_name = helpers.quote_identifier(op.name, force=True) - - ctx = translator.context - - # If the column does not originate from the table set in the current SELECT - # context, we should format as a subquery - if translator.permit_subquery and ctx.is_foreign_expr(op.table): - # TODO(kszucs): avoid the expression roundtrip - proj_expr = op.table.to_expr().select([op.name]).to_array().op() - return table_array_view(translator, proj_expr) - - alias = ctx.get_ref(op.table, search_parents=True) - if alias is not None: - quoted_name = f"{alias}.{quoted_name}" - - return quoted_name - - -def exists_subquery(translator, op): - ctx = translator.context - - dummy = ir.literal(1).name("") - node = ops.Selection( - table=op.foreign_table, - selections=[dummy], - predicates=op.predicates, - ) - subquery = ctx.get_compiled_expr(node) - - return f"EXISTS (\n{util.indent(subquery, ctx.indent)}\n)" - - -# XXX this is not added to operation_registry, but looks like impala is -# using it in the tests, and it works, even if it's not imported anywhere -def _round(translator, op): - arg, digits = op.args - - arg_formatted = translator.translate(arg) - - if digits is not None: - digits_formatted = translator.translate(digits) - return f"round({arg_formatted}, {digits_formatted})" - rounded = f"round({arg_formatted})" - if op.dtype.is_integer(): - return f"cast({rounded} AS BIGINT)" - return round - - -def concat(translator, op): - joined_args = ", ".join(map(translator.translate, op.arg)) - return f"concat({joined_args})" - - -def sort_key(translator, op): - suffix = "ASC" if op.ascending else "DESC" - return f"{translator.translate(op.expr)} {suffix}" - - -def count_star(translator, op): - return aggregate._reduction_format( - translator, - op, - "count", - op.where, - ops.Literal(value=1, dtype=dt.int64), - ) - - -def _ceil(t, op): - ceil = f"ceil({t.translate(op.arg)})" - if op.dtype.is_integer(): - return f"cast({ceil} AS BIGINT)" - return ceil - - -def _floor(t, op): - floor = f"floor({t.translate(op.arg)})" - if op.dtype.is_integer(): - return f"cast({floor} AS BIGINT)" - return floor - - -binary_infix_ops = { - # Binary operations - ops.Add: binary_infix.binary_infix_op("+"), - ops.Subtract: binary_infix.binary_infix_op("-"), - ops.Multiply: binary_infix.binary_infix_op("*"), - ops.Divide: binary_infix.binary_infix_op("/"), - ops.Power: fixed_arity("pow", 2), - ops.Modulus: binary_infix.binary_infix_op("%"), - # Comparisons - ops.Equals: binary_infix.binary_infix_op("="), - ops.NotEquals: binary_infix.binary_infix_op("!="), - ops.GreaterEqual: binary_infix.binary_infix_op(">="), - ops.Greater: binary_infix.binary_infix_op(">"), - ops.LessEqual: binary_infix.binary_infix_op("<="), - ops.Less: binary_infix.binary_infix_op("<"), - ops.IdenticalTo: binary_infix.identical_to, - # Boolean comparisons - ops.And: binary_infix.binary_infix_op("AND"), - ops.Or: binary_infix.binary_infix_op("OR"), - ops.Xor: binary_infix.xor, - # Bitwise operations - ops.BitwiseAnd: fixed_arity("bitand", 2), - ops.BitwiseOr: fixed_arity("bitor", 2), - ops.BitwiseXor: fixed_arity("bitxor", 2), - ops.BitwiseLeftShift: fixed_arity("shiftleft", 2), - ops.BitwiseRightShift: fixed_arity("shiftright", 2), - ops.BitwiseNot: unary("bitnot"), -} - -operation_registry = { - ops.Alias: alias, - 
# Unary operations - ops.NotNull: not_null, - ops.IsNull: is_null, - ops.Negate: negate, - ops.Not: not_, - ops.IsNan: unary("is_nan"), - ops.IsInf: unary("is_inf"), - ops.NullIf: fixed_arity("nullif", 2), - ops.Abs: unary("abs"), - ops.BaseConvert: fixed_arity("conv", 3), - ops.Ceil: _ceil, - ops.Floor: _floor, - ops.Exp: unary("exp"), - ops.Round: _round, - ops.Sign: sign, - ops.Sqrt: unary("sqrt"), - ops.HashBytes: hashbytes, - ops.RandomScalar: lambda *_: "rand(utc_to_unix_micros(utc_timestamp()))", - ops.Log: log, - ops.Ln: unary("ln"), - ops.Log2: unary("log2"), - ops.Log10: unary("log10"), - ops.Acos: unary("acos"), - ops.Asin: unary("asin"), - ops.Atan: unary("atan"), - ops.Atan2: fixed_arity("atan2", 2), - ops.Cos: unary("cos"), - ops.Cot: unary("cot"), - ops.Sin: unary("sin"), - ops.Tan: unary("tan"), - ops.Pi: fixed_arity("pi", 0), - ops.E: fixed_arity("e", 0), - ops.Degrees: lambda t, - op: f"(180 * {t.translate(op.arg)} / {t.translate(ops.Pi())})", - ops.Radians: lambda t, - op: f"({t.translate(ops.Pi())} * {t.translate(op.arg)} / 180)", - # Unary aggregates - ops.ApproxMedian: aggregate.reduction("appx_median"), - ops.ApproxCountDistinct: aggregate.reduction("ndv"), - ops.Mean: aggregate.reduction("avg"), - ops.Sum: aggregate.reduction("sum"), - ops.Max: aggregate.reduction("max"), - ops.Min: aggregate.reduction("min"), - ops.StandardDev: aggregate.variance_like("stddev"), - ops.Variance: aggregate.variance_like("var"), - ops.GroupConcat: aggregate.reduction("group_concat"), - ops.Count: aggregate.reduction("count"), - ops.CountStar: count_star, - ops.CountDistinct: aggregate.count_distinct, - # String operations - ops.StringConcat: concat, - ops.StringLength: unary("length"), - ops.StringAscii: unary("ascii"), - ops.Lowercase: unary("lower"), - ops.Uppercase: unary("upper"), - ops.Reverse: unary("reverse"), - ops.Strip: unary("trim"), - ops.LStrip: unary("ltrim"), - ops.RStrip: unary("rtrim"), - ops.Capitalize: unary("initcap"), - ops.Substring: string.substring, - ops.StrRight: fixed_arity("strright", 2), - ops.Repeat: fixed_arity("repeat", 2), - ops.StringFind: string.string_find, - ops.Translate: fixed_arity("translate", 3), - ops.FindInSet: string.find_in_set, - ops.LPad: fixed_arity("lpad", 3), - ops.RPad: fixed_arity("rpad", 3), - ops.StringJoin: string.string_join, - ops.StringSQLLike: string.string_like, - ops.StringSQLILike: string.string_ilike, - ops.RegexSearch: fixed_arity("regexp_like", 2), - ops.RegexExtract: fixed_arity("regexp_extract", 3), - ops.RegexReplace: fixed_arity("regexp_replace", 3), - ops.ExtractProtocol: string.extract_url_field("PROTOCOL"), - ops.ExtractAuthority: string.extract_url_field("AUTHORITY"), - ops.ExtractUserInfo: string.extract_url_field("USERINFO"), - ops.ExtractHost: string.extract_url_field("HOST"), - ops.ExtractFile: string.extract_url_field("FILE"), - ops.ExtractPath: string.extract_url_field("PATH"), - ops.ExtractQuery: string.extract_url_field("QUERY"), - ops.ExtractFragment: string.extract_url_field("REF"), - ops.StartsWith: string.startswith, - ops.EndsWith: string.endswith, - ops.StringReplace: fixed_arity("replace", 3), - # Timestamp operations - ops.Date: unary("to_date"), - ops.TimestampNow: lambda *args: "now()", - ops.ExtractYear: timestamp.extract_field("year"), - ops.ExtractMonth: timestamp.extract_field("month"), - ops.ExtractDay: timestamp.extract_field("day"), - ops.ExtractQuarter: timestamp.extract_field("quarter"), - ops.ExtractEpochSeconds: timestamp.extract_epoch_seconds, - ops.ExtractWeekOfYear: 
fixed_arity("weekofyear", 1), - ops.ExtractHour: timestamp.extract_field("hour"), - ops.ExtractMinute: timestamp.extract_field("minute"), - ops.ExtractSecond: timestamp.extract_field("second"), - ops.ExtractMicrosecond: timestamp.extract_microsecond, - ops.ExtractMillisecond: timestamp.extract_millisecond, - ops.TimestampTruncate: timestamp.truncate, - ops.DateTruncate: timestamp.truncate, - ops.IntervalFromInteger: timestamp.interval_from_integer, - # Other operations - ops.Literal: literal, - ops.Cast: cast, - ops.Coalesce: varargs("coalesce"), - ops.Greatest: varargs("greatest"), - ops.Least: varargs("least"), - ops.IfElse: fixed_arity("if", 3), - ops.Between: between, - ops.InValues: binary_infix.in_values, - ops.SimpleCase: case.simple_case, - ops.SearchedCase: case.searched_case, - ops.DateAdd: timestamp.timestamp_op("date_add"), - ops.DateSub: timestamp.timestamp_op("date_sub"), - ops.DateDiff: timestamp.timestamp_op("datediff"), - ops.TimestampAdd: timestamp.timestamp_op("date_add"), - ops.TimestampSub: timestamp.timestamp_op("date_sub"), - ops.TimestampDiff: timestamp.timestamp_diff, - ops.TimestampFromUNIX: timestamp.timestamp_from_unix, - ops.ExistsSubquery: exists_subquery, - # RowNumber, and rank functions starts with 0 in Ibis-land - ops.RowNumber: lambda *_: "row_number()", - ops.DenseRank: lambda *_: "dense_rank()", - ops.MinRank: lambda *_: "rank()", - ops.PercentRank: lambda *_: "percent_rank()", - ops.CumeDist: lambda *_: "cume_dist()", - ops.FirstValue: unary("first_value"), - ops.LastValue: unary("last_value"), - ops.Lag: window.shift_like("lag"), - ops.Lead: window.shift_like("lead"), - ops.Window: window.window, - ops.NTile: window.ntile, - ops.DayOfWeekIndex: timestamp.day_of_week_index, - ops.DayOfWeekName: timestamp.day_of_week_name, - ops.Strftime: timestamp.strftime, - ops.SortKey: sort_key, - ops.TypeOf: unary("typeof"), - **binary_infix_ops, -} diff --git a/ibis/backends/base/sql/registry/string.py b/ibis/backends/base/sql/registry/string.py deleted file mode 100644 index 5522f2db50ea..000000000000 --- a/ibis/backends/base/sql/registry/string.py +++ /dev/null @@ -1,95 +0,0 @@ -from __future__ import annotations - -import ibis.expr.operations as ops -from ibis.backends.base.sql.registry import helpers - - -def substring(translator, op): - arg, start, length = op.args - arg_formatted = translator.translate(arg) - start_formatted = translator.translate(start) - - # Impala is 1-indexed - if length is None or isinstance(length, ops.Literal): - if lvalue := getattr(length, "value", None): - return f"substr({arg_formatted}, {start_formatted} + 1, {lvalue})" - else: - return f"substr({arg_formatted}, {start_formatted} + 1)" - else: - length_formatted = translator.translate(length) - return f"substr({arg_formatted}, {start_formatted} + 1, {length_formatted})" - - -def string_find(translator, op): - arg_formatted = translator.translate(op.arg) - substr_formatted = translator.translate(op.substr) - - if (start := op.start) is not None: - if not isinstance(start, ops.Literal): - start_fmt = translator.translate(start) - return f"locate({substr_formatted}, {arg_formatted}, {start_fmt} + 1) - 1" - elif sval := start.value: - return f"locate({substr_formatted}, {arg_formatted}, {sval + 1}) - 1" - else: - raise ValueError(f"invalid `start` value: {sval}") - else: - return f"locate({substr_formatted}, {arg_formatted}) - 1" - - -def find_in_set(translator, op): - arg_formatted = translator.translate(op.needle) - str_formatted = ",".join([x.value for x in op.values]) - return 
f"find_in_set({arg_formatted}, '{str_formatted}') - 1" - - -def string_join(translator, op): - arg, strings = op.args - return helpers.format_call(translator, "concat_ws", arg, *strings) - - -def string_like(translator, op): - arg = translator.translate(op.arg) - pattern = translator.translate(op.pattern) - return f"{arg} LIKE {pattern}" - - -def string_ilike(translator, op): - arg = translator.translate(op.arg) - pattern = translator.translate(op.pattern) - return f"upper({arg}) LIKE upper({pattern})" - - -def extract_url_field(extract): - if extract == "QUERY": - - def _op(translator, op): - arg, key = op.args - arg_formatted = translator.translate(arg) - - if key is None: - return f"parse_url({arg_formatted}, '{extract}')" - else: - key_fmt = translator.translate(key) - return f"parse_url({arg_formatted}, '{extract}', {key_fmt})" - - else: - - def _op(translator, op): - arg_formatted = translator.translate(op.arg) - return f"parse_url({arg_formatted}, '{extract}')" - - return _op - - -def startswith(translator, op): - arg_formatted = translator.translate(op.arg) - start_formatted = translator.translate(op.start) - - return f"{arg_formatted} like concat({start_formatted}, '%')" - - -def endswith(translator, op): - arg_formatted = translator.translate(op.arg) - end_formatted = translator.translate(op.end) - - return f"{arg_formatted} like concat('%', {end_formatted})" diff --git a/ibis/backends/base/sql/registry/timestamp.py b/ibis/backends/base/sql/registry/timestamp.py deleted file mode 100644 index b0eb60d3253a..000000000000 --- a/ibis/backends/base/sql/registry/timestamp.py +++ /dev/null @@ -1,122 +0,0 @@ -from __future__ import annotations - -import ibis.common.exceptions as com -import ibis.expr.datatypes as dt -import ibis.expr.operations as ops -from ibis import util - - -def extract_field(sql_attr): - def extract_field_formatter(translator, op): - arg = translator.translate(op.args[0]) - - # This is pre-2.0 Impala-style, which did not used to support the - # SQL-99 format extract($FIELD from expr) - return f"extract({arg}, '{sql_attr}')" - - return extract_field_formatter - - -def extract_millisecond(translator, op): - arg = translator.translate(op.args[0]) - - # This is pre-2.0 Impala-style, which did not used to support the - # SQL-99 format extract($FIELD from expr) - return f"extract({arg}, 'millisecond') % 1000" - - -def extract_microsecond(translator, op): - arg = translator.translate(op.args[0]) - - # This is pre-2.0 Impala-style, which did not used to support the - # SQL-99 format extract($FIELD from expr) - return f"extract({arg}, 'microsecond') % 1000000" - - -def extract_epoch_seconds(t, op): - return f"unix_timestamp({t.translate(op.arg)})" - - -def truncate(translator, op): - base_unit_names = { - "Y": "Y", - "Q": "Q", - "M": "MONTH", - "W": "W", - "D": "J", - "h": "HH", - "m": "MI", - } - arg, unit = op.args - - arg_formatted = translator.translate(arg) - try: - unit = base_unit_names[unit.short] - except KeyError: - raise com.UnsupportedOperationError( - f"{unit!r} unit is not supported in timestamp truncate" - ) - - return f"trunc({arg_formatted}, '{unit}')" - - -def interval_from_integer(translator, op): - # interval cannot be selected from impala - arg = translator.translate(op.arg) - return f"INTERVAL {arg} {op.dtype.resolution.upper()}" - - -def timestamp_op(func): - def _formatter(translator, op): - formatted_left = translator.translate(op.left) - formatted_right = translator.translate(op.right) - - left_dtype = op.left.dtype - if left_dtype.is_timestamp() or 
left_dtype.is_date(): - formatted_left = f"cast({formatted_left} as timestamp)" - - right_dtype = op.right.dtype - if right_dtype.is_timestamp() or right_dtype.is_date(): - formatted_right = f"cast({formatted_right} as timestamp)" - - return f"{func}({formatted_left}, {formatted_right})" - - return _formatter - - -def timestamp_diff(translator, op): - return "unix_timestamp({}) - unix_timestamp({})".format( - translator.translate(op.left), translator.translate(op.right) - ) - - -def _from_unixtime(translator, expr): - arg = translator.translate(expr) - return f'from_unixtime({arg}, "yyyy-MM-dd HH:mm:ss")' - - -def timestamp_from_unix(translator, op): - val, unit = op.args - val = util.convert_unit(val, unit.short, "s").to_expr().cast("int32").op() - arg = _from_unixtime(translator, val) - return f"CAST({arg} AS timestamp)" - - -def day_of_week_index(t, op): - return f"pmod(dayofweek({t.translate(op.arg)}) - 2, 7)" - - -def strftime(t, op): - import sqlglot as sg - - hive_dialect = sg.dialects.hive.Hive - if (time_mapping := getattr(hive_dialect, "TIME_MAPPING", None)) is None: - time_mapping = hive_dialect.time_mapping - reverse_hive_mapping = {v: k for k, v in time_mapping.items()} - format_str = sg.time.format_time(op.format_str.value, reverse_hive_mapping) - targ = t.translate(ops.Cast(op.arg, to=dt.string)) - return f"from_unixtime(unix_timestamp({targ}), {format_str!r})" - - -def day_of_week_name(t, op): - return f"dayname({t.translate(op.arg)})" diff --git a/ibis/backends/base/sql/registry/window.py b/ibis/backends/base/sql/registry/window.py deleted file mode 100644 index eb711a9d0e68..000000000000 --- a/ibis/backends/base/sql/registry/window.py +++ /dev/null @@ -1,157 +0,0 @@ -from __future__ import annotations - -import ibis.common.exceptions as com -import ibis.expr.datatypes as dt -import ibis.expr.operations as ops - -_map_interval_to_microseconds = { - "W": 604800000000, - "D": 86400000000, - "h": 3600000000, - "m": 60000000, - "s": 1000000, - "ms": 1000, - "us": 1, -} - - -def interval_boundary_to_integer(boundary): - if boundary is None: - return None - elif boundary.dtype.is_numeric(): - return boundary - - value = boundary.value - try: - multiplier = _map_interval_to_microseconds[value.dtype.unit.short] - except KeyError: - raise com.IbisInputError(f"Unsupported interval unit: {value.dtype.unit}") - - if isinstance(value, ops.Literal): - value = ops.Literal(value.value * multiplier, dt.int64) - else: - left = ops.Cast(value, to=dt.int64) - value = ops.Multiply(left, multiplier) - - return boundary.copy(value=value) - - -def time_range_to_range_window(frame): - # Check that ORDER BY column is a single time column: - if len(frame.order_by) > 1: - raise com.IbisInputError( - f"Expected 1 order-by variable, got {len(frame.order_by)}" - ) - - order_by = frame.order_by[0] - order_by = order_by.copy(expr=ops.Cast(order_by.expr, dt.int64)) - start = interval_boundary_to_integer(frame.start) - end = interval_boundary_to_integer(frame.end) - - return frame.copy(order_by=(order_by,), start=start, end=end) - - -def format_window_boundary(translator, boundary): - if isinstance(boundary.value, ops.Literal) and boundary.value.value == 0: - return "CURRENT ROW" - - value = translator.translate(boundary.value) - direction = "PRECEDING" if boundary.preceding else "FOLLOWING" - - return f"{value} {direction}" - - -def format_window_frame(translator, func, frame): - components = [] - - if frame.group_by: - partition_args = ", ".join(map(translator.translate, frame.group_by)) - 
components.append(f"PARTITION BY {partition_args}") - - if frame.order_by: - order_args = ", ".join(map(translator.translate, frame.order_by)) - components.append(f"ORDER BY {order_args}") - - if frame.start is None and frame.end is None: - # no-op, default is full sample - pass - elif not isinstance(func, translator._forbids_frame_clause): - if frame.start is None: - start = "UNBOUNDED PRECEDING" - else: - start = format_window_boundary(translator, frame.start) - - if frame.end is None: - end = "UNBOUNDED FOLLOWING" - else: - end = format_window_boundary(translator, frame.end) - - frame = f"{frame.how.upper()} BETWEEN {start} AND {end}" - components.append(frame) - - return "OVER ({})".format(" ".join(components)) - - -def window(translator, op): - _unsupported_reductions = translator._unsupported_reductions - - func = op.func.__window_op__ - - if isinstance(func, _unsupported_reductions): - raise com.UnsupportedOperationError( - f"{type(func)} is not supported in window functions" - ) - - # Some analytic functions need to have the expression of interest in - # the ORDER BY part of the window clause - frame = op.frame - if isinstance(func, translator._require_order_by) and not frame.order_by: - frame = frame.copy(order_by=(func.arg,)) - - # Time ranges need to be converted to microseconds. - if isinstance(frame, ops.RangeWindowFrame): - if any(c.dtype.is_temporal() for c in frame.order_by): - frame = time_range_to_range_window(frame) - elif isinstance(frame, ops.RowsWindowFrame): - if frame.max_lookback is not None: - raise NotImplementedError( - "Rows with max lookback is not implemented for SQL-based backends." - ) - - window_formatted = format_window_frame(translator, func, frame) - - arg_formatted = translator.translate(func.__window_op__) - result = f"{arg_formatted} {window_formatted}" - - if isinstance(func, (ops.RankBase, ops.NTile)): - return f"({result} - 1)" - else: - return result - - -def shift_like(name): - def formatter(translator, op): - arg, offset, default = op.args - - arg_formatted = translator.translate(arg) - - if default is not None: - if offset is None: - offset_formatted = "1" - else: - offset_formatted = translator.translate(offset) - - default_formatted = translator.translate(default) - - return f"{name}({arg_formatted}, {offset_formatted}, {default_formatted})" - elif offset is not None: - offset_formatted = translator.translate(offset) - return f"{name}({arg_formatted}, {offset_formatted})" - else: - return f"{name}({arg_formatted})" - - return formatter - - -def ntile(translator, op): - return f"ntile({translator.translate(op.buckets)})" diff --git a/ibis/backends/base/sqlglot/ddl.py b/ibis/backends/base/sqlglot/ddl.py new file mode 100644 index 000000000000..863b79ef6b87 --- /dev/null +++ b/ibis/backends/base/sqlglot/ddl.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import re +from abc import ABC, abstractmethod + +import sqlglot as sg + +import ibis.expr.datatypes as dt + +fully_qualified_re = re.compile(r"(.*)\.(?:`(.*)`|(.*))") + + +# TODO(kszucs): the following classes should produce SQLGlot expressions + + +def is_fully_qualified(x): + return bool(fully_qualified_re.search(x)) + + +def _is_quoted(x): + regex = re.compile(r"(?:`(.*)`|(.*))") + quoted, _ = regex.match(x).groups() + return quoted is not None + + +class Base(ABC): + @property + @abstractmethod + def dialect(self): + ... + + @abstractmethod + def compile(self): + ... 
+ + def quote(self, ident): + return sg.to_identifier(ident, quoted=True).sql(dialect=self.dialect) + + def scoped_name( + self, obj_name: str, database: str | None = None, catalog: str | None = None + ) -> str: + if is_fully_qualified(obj_name): + return obj_name + if _is_quoted(obj_name): + obj_name = obj_name[1:-1] + return sg.table(obj_name, db=database, catalog=catalog, quoted=True).sql( + dialect=self.dialect + ) + + @abstractmethod + def format_dtype(self, dtype): + ... + + def format_schema(self, schema): + elements = [ + f"{self.quote(name)} {self.format_dtype(t)}" + for name, t in zip(schema.names, schema.types) + ] + return "({})".format(",\n ".join(elements)) + + def format_partition(self, partition, partition_schema): + def _format_partition_kv(k, v, type): + if type == dt.string: + value_formatted = f'"{v}"' + else: + value_formatted = str(v) + + return f"{k}={value_formatted}" + + tokens = [] + if isinstance(partition, dict): + for name in partition_schema: + if name in partition: + tok = _format_partition_kv( + name, partition[name], partition_schema[name] + ) + else: + # dynamic partitioning + tok = name + tokens.append(tok) + else: + for name, value in zip(partition_schema, partition): + tok = _format_partition_kv(name, value, partition_schema[name]) + tokens.append(tok) + + return "PARTITION ({})".format(", ".join(tokens)) + + +class DML(Base): + pass + + +class DDL(Base): + pass + + +class CreateDDL(DDL): + def _if_exists(self): + return "IF NOT EXISTS " if self.can_exist else "" + + +class DropObject(DDL): + def __init__(self, must_exist=True): + self.must_exist = must_exist + + def compile(self): + if_exists = "" if self.must_exist else "IF EXISTS " + object_name = self._object_name() + return f"DROP {self._object_type} {if_exists}{object_name}" + + +class DropFunction(DropObject): + def __init__(self, name, inputs, must_exist=True, aggregate=False, database=None): + super().__init__(must_exist=must_exist) + self.name = name + self.inputs = tuple(map(dt.dtype, inputs)) + self.must_exist = must_exist + self.aggregate = aggregate + self.database = database + + def _object_name(self): + return self.name + + def compile(self): + tokens = ["DROP"] + if self.aggregate: + tokens.append("AGGREGATE") + tokens.append("FUNCTION") + if not self.must_exist: + tokens.append("IF EXISTS") + + tokens.append(self._impala_signature()) + return " ".join(tokens) diff --git a/ibis/backends/flink/__init__.py b/ibis/backends/flink/__init__.py index e052d3d6ad53..1d631ddca52d 100644 --- a/ibis/backends/flink/__init__.py +++ b/ibis/backends/flink/__init__.py @@ -15,7 +15,7 @@ from ibis.backends.flink.compiler import FlinkCompiler from ibis.backends.flink.ddl import ( CreateDatabase, - CreateTableFromConnector, + CreateTableWithSchema, DropDatabase, DropTable, DropView, @@ -456,7 +456,7 @@ def create_table( # TODO (mehmet): Given that we rely on default catalog if one is not specified, # is there any point to support temporary tables? 
- statement = CreateTableFromConnector( + statement = CreateTableWithSchema( table_name=name, schema=schema, tbl_properties=tbl_properties, diff --git a/ibis/backends/flink/ddl.py b/ibis/backends/flink/ddl.py index 13e90e6eb75e..b1cbd1da0db1 100644 --- a/ibis/backends/flink/ddl.py +++ b/ibis/backends/flink/ddl.py @@ -6,20 +6,8 @@ import ibis.common.exceptions as exc import ibis.expr.schema as sch -from ibis.backends.base.sql.ddl import ( - CreateTable, - CreateTableWithSchema, - DropObject, - InsertSelect, - RenameTable, - _CreateDDL, - _format_properties, - _is_quoted, - format_partition, - is_fully_qualified, -) -from ibis.backends.base.sql.registry import quote_identifier from ibis.backends.base.sqlglot.datatypes import FlinkType +from ibis.backends.base.sqlglot.ddl import DDL, DML, CreateDDL, DropObject from ibis.util import promote_list if TYPE_CHECKING: @@ -28,107 +16,90 @@ from ibis.expr.api import Watermark -def format_schema(schema: sch.Schema): - elements = [ - _format_schema_element(name, t) for name, t in zip(schema.names, schema.types) - ] +class FlinkBase: + dialect = "hive" - return "({})".format(",\n ".join(elements)) - - -def _format_schema_element(name, t): - return f"{quote_identifier(name, force=True)} {type_to_flink_sql_string(t)}" - - -def type_to_flink_sql_string(tval): - sql_string = FlinkType.from_ibis(tval) - if tval.is_timestamp(): - return f"TIMESTAMP({tval.scale})" if tval.scale is not None else "TIMESTAMP" - else: - return sql_string.sql("flink") + " NOT NULL" * (not tval.nullable) - - -def _format_watermark_strategy(watermark: Watermark) -> str: - from ibis.backends.flink.utils import translate_literal - - if watermark.allowed_delay is None: - return watermark.time_col - return f"{watermark.time_col} - {translate_literal(watermark.allowed_delay.op())}" + def format_dtype(self, dtype): + sql_string = FlinkType.from_ibis(dtype) + if dtype.is_timestamp(): + return ( + f"TIMESTAMP({dtype.scale})" if dtype.scale is not None else "TIMESTAMP" + ) + else: + return sql_string.sql("flink") + " NOT NULL" * (not dtype.nullable) + def format_properties(self, props): + tokens = [] + for k, v in sorted(props.items()): + tokens.append(f" '{k}'='{v}'") + return "(\n{}\n)".format(",\n".join(tokens)) -def format_schema_with_watermark( - schema: sch.Schema, - watermark: Watermark | None = None, - primary_keys: Sequence[str] | None = None, -) -> str: - elements = [ - _format_schema_element(name, t) for name, t in zip(schema.names, schema.types) - ] + def format_watermark_strategy(self, watermark: Watermark) -> str: + from ibis.backends.flink.utils import translate_literal - if watermark is not None: - elements.append( - f"WATERMARK FOR {watermark.time_col} AS {_format_watermark_strategy(watermark)}" + if watermark.allowed_delay is None: + return watermark.time_col + return ( + f"{watermark.time_col} - {translate_literal(watermark.allowed_delay.op())}" ) - if primary_keys is not None and primary_keys: - # Note (mehmet): Currently supports "NOT ENFORCED" only. For the reason - # of this choice, the following quote from Flink docs is self-explanatory: - # "SQL standard specifies that a constraint can either be ENFORCED or - # NOT ENFORCED. This controls if the constraint checks are performed on - # the incoming/outgoing data. Flink does not own the data therefore the - # only mode we want to support is the NOT ENFORCED mode. It is up to the - # user to ensure that the query enforces key integrity." 
- # Ref: https://nightlies.apache.org/flink/flink-docs-release-1.18/docs/dev/table/sql/create/#primary-key - comma_separated_keys = ", ".join(f"`{key}`" for key in primary_keys) - elements.append(f"PRIMARY KEY ({comma_separated_keys}) NOT ENFORCED") - - return "({})".format(",\n ".join(elements)) + def format_schema_with_watermark( + self, + schema: sch.Schema, + watermark: Watermark | None = None, + primary_keys: Sequence[str] | None = None, + ) -> str: + elements = [ + f"{self.quote(name)} {self.format_dtype(t)}" + for name, t in zip(schema.names, schema.types) + ] + + if watermark is not None: + elements.append( + f"WATERMARK FOR {watermark.time_col} AS {self.format_watermark_strategy(watermark)}" + ) + if primary_keys is not None and primary_keys: + # Note (mehmet): Currently supports "NOT ENFORCED" only. For the reason + # of this choice, the following quote from Flink docs is self-explanatory: + # "SQL standard specifies that a constraint can either be ENFORCED or + # NOT ENFORCED. This controls if the constraint checks are performed on + # the incoming/outgoing data. Flink does not own the data therefore the + # only mode we want to support is the NOT ENFORCED mode. It is up to the + # user to ensure that the query enforces key integrity." + # Ref: https://nightlies.apache.org/flink/flink-docs-release-1.18/docs/dev/table/sql/create/#primary-key + comma_separated_keys = ", ".join(f"`{key}`" for key in primary_keys) + elements.append(f"PRIMARY KEY ({comma_separated_keys}) NOT ENFORCED") -class _CatalogAwareBaseQualifiedSQLStatement: - def _get_scoped_name( - self, obj_name: str, database: str | None = None, catalog: str | None = None - ) -> str: - if is_fully_qualified(obj_name): - return obj_name - if _is_quoted(obj_name): - obj_name = obj_name[1:-1] - return sg.table(obj_name, db=database, catalog=catalog, quoted=True).sql( - dialect="hive" - ) + return "({})".format(",\n ".join(elements)) -class CreateTableFromConnector( - _CatalogAwareBaseQualifiedSQLStatement, CreateTableWithSchema -): +class CreateTableWithSchema(FlinkBase, CreateDDL): def __init__( self, table_name: str, schema: sch.Schema, - tbl_properties: dict, - watermark: Watermark | None = None, + database=None, + catalog=None, + can_exist=False, + external=False, + partition=None, primary_key: str | Sequence[str] | None = None, - database: str | None = None, - catalog: str | None = None, + tbl_properties=None, temporary: bool = False, - **kwargs, + watermark: Watermark | None = None, ): - super().__init__( - table_name=table_name, - database=database, - schema=schema, - table_format=None, - format=None, - path=None, - tbl_properties=tbl_properties, - **kwargs, - ) + self.can_exist = can_exist self.catalog = catalog + self.database = database + self.partition = partition + self.primary_keys = promote_list(primary_key) + self.schema = schema + self.table_name = table_name + self.tbl_properties = tbl_properties self.temporary = temporary self.watermark = watermark - self.primary_keys = promote_list(primary_key) - # Check if `primary_keys` is a subset of the columns in `schema`. 
if self.primary_keys and not set(self.primary_keys) <= set(schema.names): raise exc.IbisError( @@ -137,12 +108,6 @@ def __init__( f"\t schema.names= {schema.names}" ) - def _storage(self) -> str: - return f"STORED AS {self.format}" if self.format else None - - def _format_tbl_properties(self) -> str: - return f"WITH {_format_properties(self.tbl_properties)}" - @property def _prefix(self) -> str: # `TEMPORARY` is not documented in Flink's documentation @@ -150,9 +115,7 @@ def _prefix(self) -> str: return f"CREATE{modifier} TABLE" def _create_line(self) -> str: - scoped_name = self._get_scoped_name( - self.table_name, self.database, self.catalog - ) + scoped_name = self.scoped_name(self.table_name, self.database, self.catalog) return f"{self._prefix} {self._if_exists()}{scoped_name}" @property @@ -172,19 +135,27 @@ def _pieces(self): } main_schema = sch.Schema(fields) - yield format_schema_with_watermark( + yield self.format_schema_with_watermark( main_schema, self.watermark, self.primary_keys ) - yield f"PARTITIONED BY {format_schema(part_schema)}" + yield f"PARTITIONED BY {self.format_schema(part_schema)}" else: - yield format_schema_with_watermark( + yield self.format_schema_with_watermark( self.schema, self.watermark, self.primary_keys ) - yield self._format_tbl_properties() + yield f"WITH {self.format_properties(self.tbl_properties)}" + @property + def pieces(self): + yield self._create_line() + yield from filter(None, self._pieces) -class CreateView(_CatalogAwareBaseQualifiedSQLStatement, CreateTable): + def compile(self): + return "\n".join(self.pieces) + + +class CreateView(FlinkBase, CreateDDL): def __init__( self, name: str, @@ -212,7 +183,7 @@ def _prefix(self): return "CREATE VIEW" def _create_line(self): - scoped_name = self._get_scoped_name(self.name, self.database, self.catalog) + scoped_name = self.scoped_name(self.name, self.database, self.catalog) return f"{self._prefix} {self._if_exists()}{scoped_name}" @property @@ -220,8 +191,11 @@ def pieces(self): yield self._create_line() yield f"AS {self.query_expression}" + def compile(self): + return "\n".join(self.pieces) + -class DropTable(_CatalogAwareBaseQualifiedSQLStatement, DropObject): +class DropTable(FlinkBase, DropObject): _object_type = "TABLE" def __init__( @@ -239,7 +213,7 @@ def __init__( self.temporary = temporary def _object_name(self): - return self._get_scoped_name(self.table_name, self.database, self.catalog) + return self.scoped_name(self.table_name, self.database, self.catalog) def compile(self): temporary = "TEMPORARY " if self.temporary else "" @@ -268,36 +242,30 @@ def __init__( ) -class RenameTable(RenameTable): - def __init__( - self, - old_name: str, - new_name: str, - old_database: str | None = None, - new_database: str | None = None, - must_exist: bool = True, - ): - super().__init__( - old_name=old_name, - new_name=new_name, - old_database=old_database, - new_database=new_database, - ) +class RenameTable(FlinkBase, DDL): + def __init__(self, old_name: str, new_name: str, must_exist: bool = True): + self.old_name = old_name + self.new_name = new_name self.must_exist = must_exist def compile(self): if_exists = "" if self.must_exist else "IF EXISTS" - return f"ALTER TABLE {if_exists} {self._old} RENAME TO {self._new}" + return f"ALTER TABLE {if_exists} {self.old_name} RENAME TO {self.new_name}" class _DatabaseObject: def _object_name(self): - scoped_name = f"{quote_identifier(self.catalog)}." 
if self.catalog else "" - scoped_name += quote_identifier(self.name) - return scoped_name + name = sg.to_identifier(self.name, quoted=True).sql(dialect=self.dialect) + if self.catalog: + catalog = sg.to_identifier(self.catalog, quoted=True).sql( + dialect=self.dialect + ) + return f"{catalog}.{name}" + else: + return name -class CreateDatabase(_DatabaseObject, _CreateDDL): +class CreateDatabase(FlinkBase, _DatabaseObject, CreateDDL): def __init__( self, name: str, @@ -313,7 +281,7 @@ def __init__( def _format_db_properties(self) -> str: return ( - f"WITH {_format_properties(self.db_properties)}" + f"WITH {self.format_properties(self.db_properties)}" if self.db_properties else "" ) @@ -325,7 +293,7 @@ def compile(self): return f"{create_line}\n{self._format_db_properties()}" -class DropDatabase(_DatabaseObject, DropObject): +class DropDatabase(FlinkBase, _DatabaseObject, DropObject): _object_type = "DATABASE" def __init__(self, name: str, catalog: str | None = None, must_exist: bool = True): @@ -334,7 +302,7 @@ def __init__(self, name: str, catalog: str | None = None, must_exist: bool = Tru self.catalog = catalog -class InsertSelect(_CatalogAwareBaseQualifiedSQLStatement, InsertSelect): +class InsertSelect(FlinkBase, DML): def __init__( self, table_name, @@ -345,10 +313,13 @@ def __init__( partition_schema=None, overwrite=False, ): - super().__init__( - table_name, select_expr, database, partition, partition_schema, overwrite - ) + self.table_name = table_name + self.database = database self.catalog = catalog + self.select = select_expr + self.partition = partition + self.partition_schema = partition_schema + self.overwrite = overwrite def compile(self): if self.overwrite: @@ -357,13 +328,11 @@ def compile(self): cmd = "INSERT INTO" if self.partition is not None: - part = format_partition(self.partition, self.partition_schema) + part = self.format_partition(self.partition, self.partition_schema) partition = f" {part} " else: partition = "" select_query = self.select - scoped_name = self._get_scoped_name( - self.table_name, self.database, self.catalog - ) + scoped_name = self.scoped_name(self.table_name, self.database, self.catalog) return f"{cmd} {scoped_name}{partition}\n{select_query}" diff --git a/ibis/backends/impala/__init__.py b/ibis/backends/impala/__init__.py index 9be53bd1532e..855bd65f7f3f 100644 --- a/ibis/backends/impala/__init__.py +++ b/ibis/backends/impala/__init__.py @@ -20,7 +20,11 @@ import ibis.expr.schema as sch import ibis.expr.types as ir from ibis import util -from ibis.backends.base.sql.ddl import ( +from ibis.backends.base.sqlglot import SQLGlotBackend +from ibis.backends.impala import ddl, udf +from ibis.backends.impala.client import ImpalaTable +from ibis.backends.impala.compiler import ImpalaCompiler +from ibis.backends.impala.ddl import ( CTAS, CreateDatabase, CreateTableWithSchema, @@ -31,10 +35,6 @@ RenameTable, TruncateTable, ) -from ibis.backends.base.sqlglot import SQLGlotBackend -from ibis.backends.impala import ddl, udf -from ibis.backends.impala.client import ImpalaTable -from ibis.backends.impala.compiler import ImpalaCompiler from ibis.backends.impala.udf import ( aggregate_function, scalar_function, @@ -51,7 +51,6 @@ import pyarrow as pa import ibis.expr.operations as ops - from ibis.backends.base.sql.compiler import DDL, DML __all__ = ( @@ -264,7 +263,7 @@ def _fetch_from_cursor(self, cursor, schema): return PandasData.convert_table(results, schema) @contextlib.contextmanager - def _safe_raw_sql(self, query: str | DDL | DML): + def 
_safe_raw_sql(self, query: str): if not isinstance(query, str): try: query = query.sql(dialect=self.dialect) diff --git a/ibis/backends/impala/client.py b/ibis/backends/impala/client.py index 66f54be14deb..78be3f28d8e7 100644 --- a/ibis/backends/impala/client.py +++ b/ibis/backends/impala/client.py @@ -8,8 +8,8 @@ import ibis.common.exceptions as com import ibis.expr.schema as sch import ibis.expr.types as ir -from ibis.backends.base.sql.ddl import AlterTable, InsertSelect from ibis.backends.impala import ddl +from ibis.backends.impala.ddl import AlterTable, InsertSelect if TYPE_CHECKING: import pandas as pd diff --git a/ibis/backends/impala/ddl.py b/ibis/backends/impala/ddl.py index 91f03e48c564..495ed48408eb 100644 --- a/ibis/backends/impala/ddl.py +++ b/ibis/backends/impala/ddl.py @@ -1,31 +1,305 @@ from __future__ import annotations -# Copyright 2014 Cloudera Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import json -from ibis.backends.base.sql.ddl import ( - AlterTable, - BaseDDL, - CreateTable, - CreateTableWithSchema, - DropFunction, - format_partition, - format_schema, - format_tblproperties, -) -from ibis.backends.base.sql.registry import type_to_sql_string +import sqlglot as sg + +import ibis.expr.schema as sch +from ibis.backends.base.sqlglot.datatypes import ImpalaType +from ibis.backends.base.sqlglot.ddl import DDL, DML, CreateDDL, DropFunction, DropObject + + +class ImpalaBase: + dialect = "hive" + + def sanitize_format(self, format): + _format_aliases = {"TEXT": "TEXTFILE"} + + if format is None: + return None + format = format.upper() + format = _format_aliases.get(format, format) + if format not in ("PARQUET", "AVRO", "TEXTFILE"): + raise ValueError(f"Invalid format: {format!r}") + + return format + + def format_dtype(self, dtype): + return ImpalaType.to_string(dtype) + + def format_properties(self, props): + tokens = [] + for k, v in sorted(props.items()): + tokens.append(f" '{k}'='{v}'") + return "(\n{}\n)".format(",\n".join(tokens)) + + def format_tblproperties(self, props): + formatted_props = self.format_properties(props) + return f"TBLPROPERTIES {formatted_props}" + + def format_serdeproperties(self, props): + formatted_props = self.format_properties(props) + return f"SERDEPROPERTIES {formatted_props}" + + +class CreateDatabase(ImpalaBase, CreateDDL): + def __init__(self, name, path=None, can_exist=False): + self.name = name + self.path = path + self.can_exist = can_exist + + def compile(self): + name = self.quote(self.name) + + create_decl = "CREATE DATABASE" + create_line = f"{create_decl} {self._if_exists()}{name}" + if self.path is not None: + create_line += f"\nLOCATION '{self.path}'" + + return create_line + + +class DropDatabase(ImpalaBase, DropObject): + _object_type = "DATABASE" + + def __init__(self, name, must_exist=True): + super().__init__(must_exist=must_exist) + self.name = name + + def _object_name(self): + return self.name + + +class CreateTable(ImpalaBase, CreateDDL): + def __init__( + self, + table_name, + database=None, + external=False, + 
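The `format_properties` helper on `ImpalaBase` above sorts keys and renders one quoted key/value pair per line; `format_tblproperties` and `format_serdeproperties` just prefix the result. A plain-Python restatement of that logic (exact indentation is approximate):

```python
# Plain-Python restatement of ImpalaBase.format_properties / format_tblproperties.
def format_properties(props):
    tokens = [f"  '{k}'='{v}'" for k, v in sorted(props.items())]
    return "(\n{}\n)".format(",\n".join(tokens))


print("TBLPROPERTIES " + format_properties({"b": "2", "a": "1"}))
# TBLPROPERTIES (
#   'a'='1',
#   'b'='2'
# )
```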
format="parquet", + can_exist=False, + partition=None, + path=None, + tbl_properties=None, + ): + self.table_name = table_name + self.database = database + self.partition = partition + self.path = path + self.external = external + self.can_exist = can_exist + self.format = self.sanitize_format(format) + self.tbl_properties = tbl_properties + + @property + def _prefix(self): + if self.external: + return "CREATE EXTERNAL TABLE" + else: + return "CREATE TABLE" + + def _create_line(self): + scoped_name = self.scoped_name(self.table_name, self.database) + return f"{self._prefix} {self._if_exists()}{scoped_name}" + + def _location(self): + return f"LOCATION '{self.path}'" if self.path else None + + def _storage(self): + # By the time we're here, we have a valid format + return f"STORED AS {self.format}" + + @property + def pieces(self): + yield self._create_line() + yield from filter(None, self._pieces) + + def compile(self): + return "\n".join(self.pieces) + + +class CreateTableWithSchema(CreateTable): + def __init__(self, table_name, schema, table_format=None, **kwargs): + super().__init__(table_name, **kwargs) + self.schema = schema + self.table_format = table_format + + @property + def _pieces(self): + if self.partition is not None: + main_schema = self.schema + part_schema = self.partition + if not isinstance(part_schema, sch.Schema): + part_fields = {name: self.schema[name] for name in part_schema} + part_schema = sch.Schema(part_fields) + + to_delete = {name for name in self.partition if name in self.schema} + fields = { + name: dtype + for name, dtype in main_schema.items() + if name not in to_delete + } + main_schema = sch.Schema(fields) + + yield self.format_schema(main_schema) + yield f"PARTITIONED BY {self.format_schema(part_schema)}" + else: + yield self.format_schema(self.schema) + + if self.table_format is not None: + yield "\n".join(self.table_format.to_ddl()) + else: + yield self._storage() + + yield self._location() + + +class AlterTable(ImpalaBase, DDL): + def __init__( + self, + table, + location=None, + format=None, + tbl_properties=None, + serde_properties=None, + ): + self.table = table + self.location = location + self.format = self.sanitize_format(format) + self.tbl_properties = tbl_properties + self.serde_properties = serde_properties + + def _wrap_command(self, cmd): + return f"ALTER TABLE {cmd}" + + def _format_properties(self, prefix=""): + tokens = [] + + if self.location is not None: + tokens.append(f"LOCATION '{self.location}'") + + if self.format is not None: + tokens.append(f"FILEFORMAT {self.format}") + + if self.tbl_properties is not None: + tokens.append(self.format_tblproperties(self.tbl_properties)) + + if self.serde_properties is not None: + tokens.append(self.format_serdeproperties(self.serde_properties)) + + if len(tokens) > 0: + return "\n{}{}".format(prefix, "\n".join(tokens)) + else: + return "" + + def compile(self): + props = self._format_properties() + action = f"{self.table} SET {props}" + return self._wrap_command(action) + + +class RenameTable(AlterTable): + def __init__( + self, + old_name: str, + new_name: str, + old_database: str | None = None, + new_database: str | None = None, + ): + self._old = sg.table(old_name, db=old_database, quoted=True).sql( + dialect=self.dialect + ) + self._new = sg.table(new_name, db=new_database, quoted=True).sql( + dialect=self.dialect + ) + + def compile(self): + return self._wrap_command(f"{self._old} RENAME TO {self._new}") + + +class DropTable(ImpalaBase, DropObject): + _object_type = "TABLE" + + def 
__init__(self, table_name, database=None, must_exist=True): + super().__init__(must_exist=must_exist) + self.table_name = table_name + self.database = database + + def _object_name(self): + return self.scoped_name(self.table_name, self.database) + + +class TruncateTable(ImpalaBase, DDL): + _object_type = "TABLE" + + def __init__(self, table_name, database=None): + self.table_name = table_name + self.database = database + + def compile(self): + name = self.scoped_name(self.table_name, self.database) + return f"TRUNCATE TABLE {name}" + + +class DropView(DropTable): + _object_type = "VIEW" + + +class CTAS(CreateTable): + """Create Table As Select.""" + + def __init__( + self, + table_name, + select, + database=None, + external=False, + format="parquet", + can_exist=False, + path=None, + partition=None, + ): + super().__init__( + table_name, + database=database, + external=external, + format=format, + can_exist=can_exist, + path=path, + partition=partition, + ) + self.select = select + + @property + def _pieces(self): + yield self._partitioned_by() + yield self._storage() + yield self._location() + yield "AS" + yield self.select + + def _partitioned_by(self): + if self.partition is not None: + return "PARTITIONED BY ({})".format( + ", ".join(self.quote(expr.get_name()) for expr in self.partition) + ) + return None + + +class CreateView(CTAS): + """Create a view.""" + + def __init__(self, table_name, select, database=None, can_exist=False): + super().__init__(table_name, select, database=database, can_exist=can_exist) + + @property + def _pieces(self): + yield "AS" + yield self.select + + @property + def _prefix(self): + return "CREATE VIEW" class CreateTableParquet(CreateTable): @@ -57,7 +331,7 @@ def _pieces(self): elif self.example_table is not None: yield f"LIKE {self.example_table}" elif self.schema is not None: - yield format_schema(self.schema) + yield self.format_schema(self.schema) else: raise NotImplementedError @@ -65,7 +339,7 @@ def _pieces(self): yield self._location() -class DelimitedFormat: +class DelimitedFormat(ImpalaBase): def __init__( self, path, @@ -97,10 +371,10 @@ def to_ddl(self): if self.na_rep is not None: props = {"serialization.null.format": self.na_rep} - yield format_tblproperties(props) + yield self.format_tblproperties(props) -class AvroFormat: +class AvroFormat(ImpalaBase): def __init__(self, path, avro_schema): self.path = path self.avro_schema = avro_schema @@ -113,10 +387,10 @@ def to_ddl(self): schema = "\n".join(x.rstrip() for x in schema.splitlines()) props = {"avro.schema.literal": schema} - yield format_tblproperties(props) + yield self.format_tblproperties(props) -class ParquetFormat: +class ParquetFormat(ImpalaBase): def __init__(self, path): self.path = path @@ -158,7 +432,7 @@ def _pieces(self): yield "\n".join(self.table_format.to_ddl()) -class LoadData(BaseDDL): +class LoadData(ImpalaBase, DDL): """Generate DDL for LOAD DATA command. 
Cannot be cancelled @@ -186,11 +460,13 @@ def compile(self): overwrite = "OVERWRITE " if self.overwrite else "" if self.partition is not None: - partition = "\n" + format_partition(self.partition, self.partition_schema) + partition = "\n" + self.format_partition( + self.partition, self.partition_schema + ) else: partition = "" - scoped_name = self._get_scoped_name(self.table_name, self.database) + scoped_name = self.scoped_name(self.table_name, self.database) return "LOAD DATA INPATH '{}' {}INTO TABLE {}{}".format( self.path, overwrite, scoped_name, partition ) @@ -218,7 +494,7 @@ def __init__( self.partition_schema = partition_schema def _compile(self, cmd, property_prefix=""): - part = format_partition(self.partition, self.partition_schema) + part = self.format_partition(self.partition, self.partition_schema) if cmd: part = f"{cmd} {part}" @@ -228,6 +504,8 @@ def _compile(self, cmd, property_prefix=""): class AddPartition(PartitionProperties): + dialect = "hive" + def __init__(self, table, partition, partition_schema, location=None): super().__init__(table, partition, partition_schema, location=location) @@ -236,11 +514,15 @@ def compile(self): class AlterPartition(PartitionProperties): + dialect = "hive" + def compile(self): return self._compile("", "SET ") class DropPartition(PartitionProperties): + dialect = "hive" + def __init__(self, table, partition, partition_schema): super().__init__(table, partition, partition_schema) @@ -248,18 +530,18 @@ def compile(self): return self._compile("DROP") -class CacheTable(BaseDDL): +class CacheTable(ImpalaBase, DDL): def __init__(self, table_name, database=None, pool="default"): self.table_name = table_name self.database = database self.pool = pool def compile(self): - scoped_name = self._get_scoped_name(self.table_name, self.database) + scoped_name = self.scoped_name(self.table_name, self.database) return f"ALTER TABLE {scoped_name} SET CACHED IN '{self.pool}'" -class CreateFunction(BaseDDL): +class CreateFunction(ImpalaBase, DDL): _object_type = "FUNCTION" def __init__(self, func, name=None, database=None): @@ -268,9 +550,9 @@ def __init__(self, func, name=None, database=None): self.database = database def _impala_signature(self): - scoped_name = self._get_scoped_name(self.name, self.database) - input_sig = _impala_input_signature(self.func.inputs) - output_sig = type_to_sql_string(self.func.output) + scoped_name = self.scoped_name(self.name, self.database) + input_sig = ", ".join(map(self.format_dtype, self.func.inputs)) + output_sig = self.format_dtype(self.func.output) return f"{scoped_name}({input_sig}) returns {output_sig}" @@ -306,14 +588,14 @@ def compile(self): return f"{create_decl} {impala_sig} {joined_tokens}" -class DropFunction(DropFunction): +class DropFunction(ImpalaBase, DropFunction): def _impala_signature(self): - full_name = self._get_scoped_name(self.name, self.database) - input_sig = _impala_input_signature(self.inputs) + full_name = self.scoped_name(self.name, self.database) + input_sig = ", ".join(map(self.format_dtype, self.inputs)) return f"{full_name}({input_sig})" -class ListFunction(BaseDDL): +class ListFunction(ImpalaBase, DDL): def __init__(self, database, like=None, aggregate=False): self.database = database self.like = like @@ -329,6 +611,37 @@ def compile(self): return statement -def _impala_input_signature(inputs): - # TODO: varargs '{}...'.format(val) - return ", ".join(map(type_to_sql_string, inputs)) +class InsertSelect(ImpalaBase, DML): + def __init__( + self, + table_name, + select_expr, + database=None, + 
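`CreateFunction._impala_signature` above now maps the input and output types through `format_dtype` (i.e. `ImpalaType.to_string`) instead of the old `type_to_sql_string` helper, which is what flips the type names in the snapshots to uppercase. A hedged sketch of the signature assembly, using `str.upper` as a stand-in formatter:

```python
# Illustrative signature assembly mirroring _impala_signature above; the
# function name and formatter are stand-ins, not ibis API.
def impala_signature(scoped_name, input_types, output_type, fmt):
    input_sig = ", ".join(map(fmt, input_types))
    return f"{scoped_name}({input_sig}) returns {fmt(output_type)}"


print(impala_signature("`bar`.`test_name`", ["string", "string"], "bigint", str.upper))
# `bar`.`test_name`(STRING, STRING) returns BIGINT
```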
partition=None, + partition_schema=None, + overwrite=False, + ): + self.table_name = table_name + self.database = database + self.select = select_expr + + self.partition = partition + self.partition_schema = partition_schema + + self.overwrite = overwrite + + def compile(self): + if self.overwrite: + cmd = "INSERT OVERWRITE" + else: + cmd = "INSERT INTO" + + if self.partition is not None: + part = self.format_partition(self.partition, self.partition_schema) + partition = f" {part} " + else: + partition = "" + + select_query = self.select + scoped_name = self.scoped_name(self.table_name, self.database) + return f"{cmd} {scoped_name}{partition}\n{select_query}" diff --git a/ibis/backends/impala/tests/snapshots/test_ddl_compilation/test_create_table_delimited/out.sql b/ibis/backends/impala/tests/snapshots/test_ddl_compilation/test_create_table_delimited/out.sql index 74b26ff98c2a..a43eb6528713 100644 Binary files a/ibis/backends/impala/tests/snapshots/test_ddl_compilation/test_create_table_delimited/out.sql and b/ibis/backends/impala/tests/snapshots/test_ddl_compilation/test_create_table_delimited/out.sql differ diff --git a/ibis/backends/impala/tests/snapshots/test_ddl_compilation/test_create_table_parquet_with_schema/out.sql b/ibis/backends/impala/tests/snapshots/test_ddl_compilation/test_create_table_parquet_with_schema/out.sql index d5113e93e32c..50ceb3c8b5c4 100644 --- a/ibis/backends/impala/tests/snapshots/test_ddl_compilation/test_create_table_parquet_with_schema/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_ddl_compilation/test_create_table_parquet_with_schema/out.sql @@ -1,6 +1,6 @@ CREATE EXTERNAL TABLE IF NOT EXISTS `foo`.`new_table` -(`foo` string, - `bar` tinyint, - `baz` smallint) +(`foo` STRING, + `bar` TINYINT, + `baz` SMALLINT) STORED AS PARQUET LOCATION '/path/to/' \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_ddl_compilation/test_create_table_with_location_compile/out.sql b/ibis/backends/impala/tests/snapshots/test_ddl_compilation/test_create_table_with_location_compile/out.sql index e044b30cc1e6..86d38923ae17 100644 --- a/ibis/backends/impala/tests/snapshots/test_ddl_compilation/test_create_table_with_location_compile/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_ddl_compilation/test_create_table_with_location_compile/out.sql @@ -1,6 +1,6 @@ CREATE TABLE `foo`.`another_table` -(`foo` string, - `bar` tinyint, - `baz` smallint) +(`foo` STRING, + `bar` TINYINT, + `baz` SMALLINT) STORED AS PARQUET LOCATION '/path/to/table' \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_udf/test_create_uda/False/out.sql b/ibis/backends/impala/tests/snapshots/test_udf/test_create_uda/False/out.sql index ef59d670b5b0..bf781d5800db 100644 --- a/ibis/backends/impala/tests/snapshots/test_udf/test_create_uda/False/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_udf/test_create_uda/False/out.sql @@ -1,4 +1,4 @@ -CREATE AGGREGATE FUNCTION `bar`.`test_name`(string, string) returns bigint location '/foo/bar.so' +CREATE AGGREGATE FUNCTION `bar`.`test_name`(STRING, STRING) returns BIGINT location '/foo/bar.so' init_fn='Init' update_fn='Update' merge_fn='Merge' diff --git a/ibis/backends/impala/tests/snapshots/test_udf/test_create_uda/True/out.sql b/ibis/backends/impala/tests/snapshots/test_udf/test_create_uda/True/out.sql index a60787098563..56db248a1fc9 100644 --- a/ibis/backends/impala/tests/snapshots/test_udf/test_create_uda/True/out.sql +++ 
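The ported `InsertSelect` keeps the same compiled shape as the class it replaces. A hedged usage sketch against the internal module — this is internal API as of this change and subject to move, and the exact quoting comes from the sqlglot DDL base:

```python
# Internal API as of this change; illustrative only.
from ibis.backends.impala.ddl import InsertSelect

stmt = InsertSelect("t", "SELECT * FROM s", database="d", overwrite=True)
print(stmt.compile())
# Roughly:
# INSERT OVERWRITE `d`.`t`
# SELECT * FROM s
```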
b/ibis/backends/impala/tests/snapshots/test_udf/test_create_uda/True/out.sql @@ -1,4 +1,4 @@ -CREATE AGGREGATE FUNCTION `bar`.`test_name`(string, string) returns bigint location '/foo/bar.so' +CREATE AGGREGATE FUNCTION `bar`.`test_name`(STRING, STRING) returns BIGINT location '/foo/bar.so' init_fn='Init' update_fn='Update' merge_fn='Merge' diff --git a/ibis/backends/impala/tests/snapshots/test_udf/test_create_udf/out.sql b/ibis/backends/impala/tests/snapshots/test_udf/test_create_udf/out.sql index 45094d4661d2..3a4357055815 100644 --- a/ibis/backends/impala/tests/snapshots/test_udf/test_create_udf/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_udf/test_create_udf/out.sql @@ -1 +1 @@ -CREATE FUNCTION `test_name`(string, string) returns bigint location '/foo/bar.so' symbol='testFunc' \ No newline at end of file +CREATE FUNCTION `test_name`(STRING, STRING) returns BIGINT location '/foo/bar.so' symbol='testFunc' \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_udf/test_create_udf_type_conversions/out.sql b/ibis/backends/impala/tests/snapshots/test_udf/test_create_udf_type_conversions/out.sql index 544c0a4295b8..d7844869ca6b 100644 --- a/ibis/backends/impala/tests/snapshots/test_udf/test_create_udf_type_conversions/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_udf/test_create_udf_type_conversions/out.sql @@ -1 +1 @@ -CREATE FUNCTION `test_name`(string, tinyint, smallint, int) returns bigint location '/foo/bar.so' symbol='testFunc' \ No newline at end of file +CREATE FUNCTION `test_name`(STRING, TINYINT, SMALLINT, INT) returns BIGINT location '/foo/bar.so' symbol='testFunc' \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_aggregate/out.sql b/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_aggregate/out.sql index e187efd69b54..e88182886606 100644 --- a/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_aggregate/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_aggregate/out.sql @@ -1 +1 @@ -DROP AGGREGATE FUNCTION `test_name`(string, string) \ No newline at end of file +DROP AGGREGATE FUNCTION `test_name`(STRING, STRING) \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_db/out.sql b/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_db/out.sql index 6fd67c8afe29..b4388e852246 100644 --- a/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_db/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_db/out.sql @@ -1 +1 @@ -DROP FUNCTION `test`.`test_name`(string, string) \ No newline at end of file +DROP FUNCTION `test`.`test_name`(STRING, STRING) \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_if_exists/out.sql b/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_if_exists/out.sql index 03483ec89f73..84de38e192a6 100644 --- a/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_if_exists/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_if_exists/out.sql @@ -1 +1 @@ -DROP FUNCTION IF EXISTS `test_name`(string, string) \ No newline at end of file +DROP FUNCTION IF EXISTS `test_name`(STRING, STRING) \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_simple/out.sql b/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_simple/out.sql index 5e58a5e64d81..e8a92719582d 100644 --- 
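The remaining snapshot churn above and below is purely cosmetic: the sqlglot-derived type mapper emits canonical uppercase Impala type names where the old registry emitted lowercase ones. A hedged check — the module path is internal as of this change:

```python
import ibis.expr.datatypes as dt
from ibis.backends.base.sqlglot.datatypes import ImpalaType  # internal path

print(ImpalaType.to_string(dt.string))  # STRING
print(ImpalaType.to_string(dt.int16))   # SMALLINT
```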
a/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_simple/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_udf/test_delete_udf_simple/out.sql @@ -1 +1 @@ -DROP FUNCTION `test_name`(string, string) \ No newline at end of file +DROP FUNCTION `test_name`(STRING, STRING) \ No newline at end of file diff --git a/ibis/backends/impala/tests/test_ddl_compilation.py b/ibis/backends/impala/tests/test_ddl_compilation.py index 4b96400e6e8f..929075d92aa6 100644 --- a/ibis/backends/impala/tests/test_ddl_compilation.py +++ b/ibis/backends/impala/tests/test_ddl_compilation.py @@ -3,13 +3,13 @@ import pytest import ibis -from ibis.backends.base.sql.ddl import ( +from ibis.backends.impala import ddl +from ibis.backends.impala.ddl import ( CTAS, CreateTableWithSchema, DropTable, InsertSelect, ) -from ibis.backends.impala import ddl @pytest.fixture diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 6728cb6633fa..62ffb72c0ccf 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -27,6 +27,7 @@ from ibis.backends.conftest import ALL_BACKENDS from ibis.backends.tests.errors import ( ExaQueryError, + ImpalaHiveServer2Error, OracleDatabaseError, PsycoPg2InternalError, PsycoPg2UndefinedObject, @@ -639,14 +640,9 @@ def test_list_databases(con): assert test_databases[con.name] <= result -@pytest.mark.notyet( - ["postgres", "snowflake"], - raises=TypeError, - reason="backend does not support unsigned integer types", -) @pytest.mark.notyet(["mssql"], raises=PyODBCProgrammingError) @pytest.mark.notyet(["pyspark"], raises=com.IbisTypeError) -@pytest.mark.notyet(["bigquery", "impala"], raises=com.UnsupportedBackendType) +@pytest.mark.notyet(["bigquery"], raises=com.UnsupportedBackendType) @pytest.mark.notyet( ["postgres"], raises=PsycoPg2UndefinedObject, reason="no unsigned int types" ) @@ -657,6 +653,7 @@ def test_list_databases(con): @pytest.mark.notyet(["datafusion"], raises=Exception, reason="no unsigned int types") @pytest.mark.notyet(["druid"], raises=NotImplementedError) @pytest.mark.notyet(["snowflake"], raises=SnowflakeProgrammingError) +@pytest.mark.notyet(["impala"], raises=ImpalaHiveServer2Error) @pytest.mark.notyet( ["risingwave"], raises=PsycoPg2InternalError,