From 0cc694862154d0ae62b92151b8e08f2a0a8e34b9 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 6 Aug 2022 10:16:43 -0400 Subject: [PATCH] feat(api): add `ibis.memtable` API for constructing in-memory table expressions --- ibis/backends/base/sql/alchemy/__init__.py | 12 +- .../base/sql/alchemy/query_builder.py | 6 +- .../base/sql/compiler/query_builder.py | 7 + .../base/sql/compiler/select_builder.py | 5 + ibis/backends/duckdb/compiler.py | 2 + ibis/backends/pandas/client.py | 25 ++ ibis/backends/pyspark/__init__.py | 5 +- ibis/backends/pyspark/compiler.py | 355 ++++++++++-------- ibis/backends/tests/test_client.py | 75 ++++ ibis/common/grounds.py | 19 +- ibis/expr/api.py | 107 +++++- ibis/expr/format.py | 33 ++ ibis/expr/operations/relations.py | 12 + ibis/expr/rules.py | 32 +- ibis/expr/types/relations.py | 7 +- ibis/tests/expr/test_table.py | 9 + 16 files changed, 518 insertions(+), 193 deletions(-) diff --git a/ibis/backends/base/sql/alchemy/__init__.py b/ibis/backends/base/sql/alchemy/__init__.py index 8ebcbecf51a3..f72648e56a86 100644 --- a/ibis/backends/base/sql/alchemy/__init__.py +++ b/ibis/backends/base/sql/alchemy/__init__.py @@ -9,6 +9,7 @@ import ibis import ibis.expr.datatypes as dt +import ibis.expr.operations as ops import ibis.expr.schema as sch import ibis.expr.types as ir import ibis.util as util @@ -212,9 +213,14 @@ def create_table( with self.begin() as bind: t.create(bind=bind, checkfirst=force) if expr is not None: - bind.execute( - t.insert().from_select(list(expr.columns), expr.compile()) - ) + compiled = self.compile(expr) + insert = t.insert() + if isinstance(expr.op(), ops.InMemoryTable): + compiled = compiled.get_final_froms()[0] + sa_expr = insert.values(*compiled._data) + else: + sa_expr = insert.from_select(list(expr.columns), compiled) + bind.execute(sa_expr) def _columns_from_schema( self, name: str, schema: sch.Schema diff --git a/ibis/backends/base/sql/alchemy/query_builder.py b/ibis/backends/base/sql/alchemy/query_builder.py index ded2dfc68422..a3ce29506ca0 100644 --- a/ibis/backends/base/sql/alchemy/query_builder.py +++ b/ibis/backends/base/sql/alchemy/query_builder.py @@ -107,6 +107,10 @@ def _format_table(self, expr): ) backend = ref_op.child._find_backend() backend._create_temp_view(view=result, definition=definition) + elif isinstance(ref_op, ops.InMemoryTable): + columns = _schema_to_sqlalchemy_columns(ref_op.schema) + rows = list(ref_op.data.itertuples(index=False)) + result = sa.values(*columns).data(rows) else: # A subquery if ctx.is_extracted(ref_expr): @@ -194,7 +198,7 @@ def _compile_subqueries(self): def _compile_table_set(self): if self.table_set is not None: - helper = _AlchemyTableSetFormatter(self, self.table_set) + helper = self.table_set_formatter_class(self, self.table_set) result = helper.get_result() if isinstance(result, sql.selectable.Select): return result.subquery() diff --git a/ibis/backends/base/sql/compiler/query_builder.py b/ibis/backends/base/sql/compiler/query_builder.py index 0504cea2ebfd..a1dfa1b45bee 100644 --- a/ibis/backends/base/sql/compiler/query_builder.py +++ b/ibis/backends/base/sql/compiler/query_builder.py @@ -116,6 +116,13 @@ def _format_table(self, expr): raise com.RelationError(f'Table did not have a name: {expr!r}') result = self._quote_identifier(name) is_subquery = False + elif isinstance(ref_op, ops.InMemoryTable): + rows = ", ".join( + f"({', '.join(map(repr, col))})" + for col in ref_op.data.itertuples(index=False) + ) + result = f"(VALUES {rows})" + 
is_subquery = True else: # A subquery if ctx.is_extracted(ref_expr): diff --git a/ibis/backends/base/sql/compiler/select_builder.py b/ibis/backends/base/sql/compiler/select_builder.py index d143519efc33..13b65c923b03 100644 --- a/ibis/backends/base/sql/compiler/select_builder.py +++ b/ibis/backends/base/sql/compiler/select_builder.py @@ -513,6 +513,11 @@ def _collect_Selection(self, expr, toplevel=False): self.table_set = table self.filters = filters + def _collect_PandasInMemoryTable(self, expr, toplevel=False): + if toplevel: + self.select_set = [expr] + self.table_set = expr + def _convert_group_by(self, exprs): return list(range(len(exprs))) diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index ea6a921272c2..b2fd2db28482 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from sqlalchemy.ext.compiler import compiles import ibis.backends.base.sql.alchemy.datatypes as sat diff --git a/ibis/backends/pandas/client.py b/ibis/backends/pandas/client.py index 42c553f717ff..89bd1f87f206 100644 --- a/ibis/backends/pandas/client.py +++ b/ibis/backends/pandas/client.py @@ -9,8 +9,11 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops +import ibis.expr.rules as rlz import ibis.expr.schema as sch +from ibis import util from ibis.backends.base import Database +from ibis.common.grounds import Immutable infer_pandas_dtype = pd.api.types.infer_dtype @@ -298,6 +301,28 @@ def convert_array_to_series(in_dtype, out_dtype, column): sch.Schema.to_pandas = ibis_schema_to_pandas # type: ignore +class DataFrameProxy(Immutable): + __slots__ = ('_df', '_hash') + + def __init__(self, df): + object.__setattr__(self, "_df", df) + object.__setattr__(self, "_hash", hash((type(df), id(df)))) + + def __getattr__(self, name): + return getattr(self._df, name) + + def __hash__(self): + return self._hash + + def __repr__(self): + df_repr = util.indent(repr(self._df), spaces=2) + return f"{self.__class__.__name__}:\n{df_repr}" + + +class PandasInMemoryTable(ops.InMemoryTable): + data = rlz.instance_of(DataFrameProxy) + + class PandasTable(ops.DatabaseTable): pass diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index c4e197917f46..91fb0b9b9367 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -191,7 +191,10 @@ def compile(self, expr, timecontext=None, params=None, *args, **kwargs): timecontext, ) return PySparkExprTranslator().translate( - expr, scope=scope, timecontext=timecontext + expr, + scope=scope, + timecontext=timecontext, + session=self._session, ) def execute( diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index f2f8277f0aa2..73b0dae8de75 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -6,6 +6,7 @@ import pandas as pd import pyspark import pyspark.sql.functions as F +import pyspark.sql.types as pt from pyspark.sql import Window from pyspark.sql.functions import PandasUDFType, pandas_udf @@ -15,6 +16,7 @@ import ibis.expr.types as ir import ibis.expr.types as types from ibis import interval +from ibis.backends.pandas.client import PandasInMemoryTable from ibis.backends.pandas.execution import execute from ibis.backends.pyspark.datatypes import ( ibis_array_dtype_to_spark_dtype, @@ -86,14 +88,14 @@ def translate(self, expr, scope, timecontext, **kwargs): @compiles(PySparkDatabaseTable) -def compile_datasource(t, expr, scope, 
timecontext): +def compile_datasource(t, expr, scope, timecontext, **_): op = expr.op() name, _, client = op.args return filter_by_time_context(client._session.table(name), timecontext) @compiles(ops.SQLQueryResult) -def compile_sql_query_result(t, expr, scope, timecontext, **kwargs): +def compile_sql_query_result(t, expr, scope, timecontext, **_): op = expr.op() query, _, client = op.args return client._session.sql(query) @@ -141,7 +143,7 @@ def compile_selection(t, expr, scope, timecontext, **kwargs): # in this case, we use the original timecontext if not adjusted_timecontext: adjusted_timecontext = timecontext - src_table = t.translate(op.table, scope, adjusted_timecontext) + src_table = t.translate(op.table, scope, adjusted_timecontext, **kwargs) col_in_selection_order = [] col_to_drop = [] @@ -150,7 +152,9 @@ def compile_selection(t, expr, scope, timecontext, **kwargs): if isinstance(selection, types.Table): col_in_selection_order.extend(selection.columns) elif isinstance(selection, types.DestructColumn): - struct_col = t.translate(selection, scope, adjusted_timecontext) + struct_col = t.translate( + selection, scope, adjusted_timecontext, **kwargs + ) # assign struct col and drop it later # This is a work around to ensure that the struct_col # is only executed once @@ -170,7 +174,7 @@ def compile_selection(t, expr, scope, timecontext, **kwargs): col_in_selection_order.append(selection.get_name()) else: col = t.translate( - selection, scope, adjusted_timecontext + selection, scope, adjusted_timecontext, **kwargs ).alias(selection.get_name()) col_in_selection_order.append(col) else: @@ -184,7 +188,7 @@ def compile_selection(t, expr, scope, timecontext, **kwargs): result_table = result_table.drop(*col_to_drop) for predicate in op.predicates: - col = t.translate(predicate, scope, timecontext) + col = t.translate(predicate, scope, timecontext, **kwargs) # Due to an upstream Spark issue (SPARK-33057) we cannot # directly use filter with a window operation. 
The workaround # here is to assign a temporary column for the filter predicate, @@ -196,7 +200,8 @@ def compile_selection(t, expr, scope, timecontext, **kwargs): if op.sort_keys: sort_cols = [ - t.translate(key, scope, timecontext) for key in op.sort_keys + t.translate(key, scope, timecontext, **kwargs) + for key in op.sort_keys ] result_table = result_table.sort(*sort_cols) @@ -208,7 +213,7 @@ def compile_selection(t, expr, scope, timecontext, **kwargs): @compiles(ops.SortKey) def compile_sort_key(t, expr, scope, timecontext, **kwargs): op = expr.op() - col = t.translate(op.expr, scope, timecontext) + col = t.translate(op.expr, scope, timecontext, **kwargs) if op.ascending: return col.asc() @@ -234,14 +239,14 @@ def wrapper(t, expr, *args, **kwargs): @compile_nan_as_null def compile_column(t, expr, scope, timecontext, **kwargs): op = expr.op() - table = t.translate(op.table, scope, timecontext) + table = t.translate(op.table, scope, timecontext, **kwargs) return table[op.name] @compiles(ops.SelfReference) def compile_self_reference(t, expr, scope, timecontext, **kwargs): op = expr.op() - return t.translate(op.table, scope, timecontext) + return t.translate(op.table, scope, timecontext, **kwargs) @compiles(ops.Cast) @@ -262,7 +267,7 @@ def compile_cast(t, expr, scope, timecontext, **kwargs): else: cast_type = ibis_dtype_to_spark_dtype(op.to) - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return src_column.cast(cast_type) @@ -274,101 +279,101 @@ def compile_limit(t, expr, scope, timecontext, **kwargs): 'PySpark backend does not support non-zero offset is for ' 'limit operation. Got offset {}.'.format(op.offset) ) - df = t.translate(op.table, scope, timecontext) + df = t.translate(op.table, scope, timecontext, **kwargs) return df.limit(op.n) @compiles(ops.And) def compile_and(t, expr, scope, timecontext, **kwargs): op = expr.op() - return t.translate(op.left, scope, timecontext) & t.translate( - op.right, scope, timecontext + return t.translate(op.left, scope, timecontext, **kwargs) & t.translate( + op.right, scope, timecontext, **kwargs ) @compiles(ops.Or) def compile_or(t, expr, scope, timecontext, **kwargs): op = expr.op() - return t.translate(op.left, scope, timecontext) | t.translate( - op.right, scope, timecontext + return t.translate(op.left, scope, timecontext, **kwargs) | t.translate( + op.right, scope, timecontext, **kwargs ) @compiles(ops.Xor) def compile_xor(t, expr, scope, timecontext, **kwargs): op = expr.op() - left = t.translate(op.left, scope, timecontext) - right = t.translate(op.right, scope, timecontext) + left = t.translate(op.left, scope, timecontext, **kwargs) + right = t.translate(op.right, scope, timecontext, **kwargs) return (left | right) & ~(left & right) @compiles(ops.Equals) def compile_equals(t, expr, scope, timecontext, **kwargs): op = expr.op() - return t.translate(op.left, scope, timecontext) == t.translate( - op.right, scope, timecontext + return t.translate(op.left, scope, timecontext, **kwargs) == t.translate( + op.right, scope, timecontext, **kwargs ) @compiles(ops.Not) def compile_not(t, expr, scope, timecontext, **kwargs): op = expr.op() - return ~t.translate(op.arg, scope, timecontext) + return ~t.translate(op.arg, scope, timecontext, **kwargs) @compiles(ops.NotEquals) def compile_not_equals(t, expr, scope, timecontext, **kwargs): op = expr.op() - return t.translate(op.left, scope, timecontext) != t.translate( - op.right, scope, timecontext + return t.translate(op.left, scope, 
timecontext, **kwargs) != t.translate( + op.right, scope, timecontext, **kwargs ) @compiles(ops.Greater) def compile_greater(t, expr, scope, timecontext, **kwargs): op = expr.op() - return t.translate(op.left, scope, timecontext) > t.translate( - op.right, scope, timecontext + return t.translate(op.left, scope, timecontext, **kwargs) > t.translate( + op.right, scope, timecontext, **kwargs ) @compiles(ops.GreaterEqual) def compile_greater_equal(t, expr, scope, timecontext, **kwargs): op = expr.op() - return t.translate(op.left, scope, timecontext) >= t.translate( - op.right, scope, timecontext + return t.translate(op.left, scope, timecontext, **kwargs) >= t.translate( + op.right, scope, timecontext, **kwargs ) @compiles(ops.Less) def compile_less(t, expr, scope, timecontext, **kwargs): op = expr.op() - return t.translate(op.left, scope, timecontext) < t.translate( - op.right, scope, timecontext + return t.translate(op.left, scope, timecontext, **kwargs) < t.translate( + op.right, scope, timecontext, **kwargs ) @compiles(ops.LessEqual) def compile_less_equal(t, expr, scope, timecontext, **kwargs): op = expr.op() - return t.translate(op.left, scope, timecontext) <= t.translate( - op.right, scope, timecontext + return t.translate(op.left, scope, timecontext, **kwargs) <= t.translate( + op.right, scope, timecontext, **kwargs ) @compiles(ops.Multiply) def compile_multiply(t, expr, scope, timecontext, **kwargs): op = expr.op() - return t.translate(op.left, scope, timecontext) * t.translate( - op.right, scope, timecontext + return t.translate(op.left, scope, timecontext, **kwargs) * t.translate( + op.right, scope, timecontext, **kwargs ) @compiles(ops.Subtract) def compile_subtract(t, expr, scope, timecontext, **kwargs): op = expr.op() - return t.translate(op.left, scope, timecontext) - t.translate( - op.right, scope, timecontext + return t.translate(op.left, scope, timecontext, **kwargs) - t.translate( + op.right, scope, timecontext, **kwargs ) @@ -402,7 +407,7 @@ def compile_literal(t, expr, scope, timecontext, raw=False, **kwargs): def _compile_agg(t, agg_expr, scope, timecontext, *, context, **kwargs): - agg = t.translate(agg_expr, scope, timecontext, context=context) + agg = t.translate(agg_expr, scope, timecontext, context=context, **kwargs) if agg_expr.has_name(): return agg.alias(agg_expr.get_name()) return agg @@ -412,7 +417,7 @@ def _compile_agg(t, agg_expr, scope, timecontext, *, context, **kwargs): def compile_aggregation(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_table = t.translate(op.table, scope, timecontext) + src_table = t.translate(op.table, scope, timecontext, **kwargs) if op.predicates: src_table = src_table.filter( @@ -420,12 +425,13 @@ def compile_aggregation(t, expr, scope, timecontext, **kwargs): functools.reduce(operator.and_, op.predicates), scope, timecontext, + **kwargs, ) ) if op.by: context = AggregationContext.GROUP - bys = [t.translate(b, scope, timecontext) for b in op.by] + bys = [t.translate(b, scope, timecontext, **kwargs) for b in op.by] src_table = src_table.groupby(*bys) else: context = AggregationContext.ENTIRE @@ -465,30 +471,30 @@ def compile_difference(t, expr, scope, timecontext, **kwargs): @compiles(ops.Contains) def compile_contains(t, expr, scope, timecontext, **kwargs): op = expr.op() - col = t.translate(op.value, scope, timecontext) - return col.isin(t.translate(op.options, scope, timecontext)) + col = t.translate(op.value, scope, timecontext, **kwargs) + return col.isin(t.translate(op.options, scope, timecontext, **kwargs)) 
@compiles(ops.NotContains) def compile_not_contains(t, expr, scope, timecontext, **kwargs): op = expr.op() - col = t.translate(op.value, scope, timecontext) - return ~(col.isin(t.translate(op.options, scope, timecontext))) + col = t.translate(op.value, scope, timecontext, **kwargs) + return ~(col.isin(t.translate(op.options, scope, timecontext, **kwargs))) @compiles(ops.StartsWith) def compile_startswith(t, expr, scope, timecontext, **kwargs): op = expr.op() - col = t.translate(op.arg, scope, timecontext) - start = t.translate(op.start, scope, timecontext) + col = t.translate(op.arg, scope, timecontext, **kwargs) + start = t.translate(op.start, scope, timecontext, **kwargs) return col.startswith(start) @compiles(ops.EndsWith) def compile_endswith(t, expr, scope, timecontext, **kwargs): op = expr.op() - col = t.translate(op.arg, scope, timecontext) - end = t.translate(op.end, scope, timecontext) + col = t.translate(op.arg, scope, timecontext, **kwargs) + end = t.translate(op.end, scope, timecontext, **kwargs) return col.endswith(end) @@ -504,12 +510,12 @@ def compile_aggregator( ): op = expr.op() if (where := getattr(op, 'where', None)) is not None: - condition = t.translate(where, scope, timecontext) + condition = t.translate(where, scope, timecontext, **kwargs) else: condition = None def translate_arg(arg): - src_col = t.translate(arg, scope, timecontext) + src_col = t.translate(arg, scope, timecontext, **kwargs) if condition is not None: src_col = F.when(condition, src_col) @@ -535,7 +541,9 @@ def translate_arg(arg): (src_col,) = src_cols return src_col.select(col) (table_op,) = op.root_tables() - return t.translate(table_op.to_expr(), scope, timecontext).select(col) + return t.translate( + table_op.to_expr(), scope, timecontext, **kwargs + ).select(col) @compiles(ops.GroupConcat) @@ -783,7 +791,7 @@ def compile_arbitrary(t, expr, scope, timecontext, context=None, **kwargs): def compile_coalesce(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_columns = t.translate(op.arg, scope, timecontext) + src_columns = t.translate(op.arg, scope, timecontext, **kwargs) if len(src_columns) == 1: return src_columns[0] else: @@ -794,7 +802,7 @@ def compile_coalesce(t, expr, scope, timecontext, **kwargs): def compile_greatest(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_columns = t.translate(op.arg, scope, timecontext) + src_columns = t.translate(op.arg, scope, timecontext, **kwargs) if len(src_columns) == 1: return src_columns[0] else: @@ -805,7 +813,7 @@ def compile_greatest(t, expr, scope, timecontext, **kwargs): def compile_least(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_columns = t.translate(op.arg, scope, timecontext) + src_columns = t.translate(op.arg, scope, timecontext, **kwargs) if len(src_columns) == 1: return src_columns[0] else: @@ -816,7 +824,7 @@ def compile_least(t, expr, scope, timecontext, **kwargs): def compile_abs(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.abs(src_column) @@ -824,14 +832,14 @@ def compile_abs(t, expr, scope, timecontext, **kwargs): def compile_clip(t, expr, scope, timecontext, **kwargs): op = expr.op() spark_dtype = ibis_dtype_to_spark_dtype(expr.type()) - col = t.translate(op.arg, scope, timecontext) + col = t.translate(op.arg, scope, timecontext, **kwargs) upper = ( - t.translate(op.upper, scope, timecontext) + t.translate(op.upper, scope, timecontext, **kwargs) if op.upper is
not None else float('inf') ) lower = ( - t.translate(op.lower, scope, timecontext) + t.translate(op.lower, scope, timecontext, **kwargs) if op.lower is not None else float('-inf') ) @@ -858,7 +866,7 @@ def clip(column, lower_value, upper_value): def compile_round(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) scale = ( t.translate(op.digits, scope, timecontext, raw=True) if op.digits is not None @@ -874,7 +882,7 @@ def compile_round(t, expr, scope, timecontext, **kwargs): def compile_ceil(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.ceil(src_column) @@ -882,7 +890,7 @@ def compile_ceil(t, expr, scope, timecontext, **kwargs): def compile_floor(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.floor(src_column) @@ -890,7 +898,7 @@ def compile_floor(t, expr, scope, timecontext, **kwargs): def compile_exp(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.exp(src_column) @@ -898,7 +906,7 @@ def compile_exp(t, expr, scope, timecontext, **kwargs): def compile_sign(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.when(src_column == 0, F.lit(0.0)).otherwise( F.when(src_column > 0, F.lit(1.0)).otherwise(-1.0) @@ -909,7 +917,7 @@ def compile_sign(t, expr, scope, timecontext, **kwargs): def compile_sqrt(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.sqrt(src_column) @@ -917,7 +925,7 @@ def compile_sqrt(t, expr, scope, timecontext, **kwargs): def compile_log(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) raw_base = t.translate(op.base, scope, timecontext, raw=True) try: base = float(raw_base) @@ -931,7 +939,7 @@ def compile_log(t, expr, scope, timecontext, **kwargs): def compile_ln(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.log(src_column) @@ -939,7 +947,7 @@ def compile_ln(t, expr, scope, timecontext, **kwargs): def compile_log2(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.log2(src_column) @@ -947,7 +955,7 @@ def compile_log2(t, expr, scope, timecontext, **kwargs): def compile_log10(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.log10(src_column) @@ -955,8 +963,8 @@ def compile_log10(t, expr, scope, timecontext, **kwargs): def compile_modulus(t, expr, scope, timecontext, **kwargs): op = expr.op() - left = t.translate(op.left, scope, timecontext) - right = 
t.translate(op.right, scope, timecontext) + left = t.translate(op.left, scope, timecontext, **kwargs) + right = t.translate(op.right, scope, timecontext, **kwargs) return left % right @@ -964,7 +972,7 @@ def compile_modulus(t, expr, scope, timecontext, **kwargs): def compile_negate(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) if expr.type() == dtypes.boolean: return ~src_column return -src_column @@ -974,8 +982,8 @@ def compile_negate(t, expr, scope, timecontext, **kwargs): def compile_add(t, expr, scope, timecontext, **kwargs): op = expr.op() - left = t.translate(op.left, scope, timecontext) - right = t.translate(op.right, scope, timecontext) + left = t.translate(op.left, scope, timecontext, **kwargs) + right = t.translate(op.right, scope, timecontext, **kwargs) return left + right @@ -983,8 +991,8 @@ def compile_add(t, expr, scope, timecontext, **kwargs): def compile_divide(t, expr, scope, timecontext, **kwargs): op = expr.op() - left = t.translate(op.left, scope, timecontext) - right = t.translate(op.right, scope, timecontext) + left = t.translate(op.left, scope, timecontext, **kwargs) + right = t.translate(op.right, scope, timecontext, **kwargs) return left / right @@ -992,8 +1000,8 @@ def compile_divide(t, expr, scope, timecontext, **kwargs): def compile_floor_divide(t, expr, scope, timecontext, **kwargs): op = expr.op() - left = t.translate(op.left, scope, timecontext) - right = t.translate(op.right, scope, timecontext) + left = t.translate(op.left, scope, timecontext, **kwargs) + right = t.translate(op.right, scope, timecontext, **kwargs) return F.floor(left / right) @@ -1001,8 +1009,8 @@ def compile_floor_divide(t, expr, scope, timecontext, **kwargs): def compile_power(t, expr, scope, timecontext, **kwargs): op = expr.op() - left = t.translate(op.left, scope, timecontext) - right = t.translate(op.right, scope, timecontext) + left = t.translate(op.left, scope, timecontext, **kwargs) + right = t.translate(op.right, scope, timecontext, **kwargs) return F.pow(left, right) @@ -1010,7 +1018,7 @@ def compile_power(t, expr, scope, timecontext, **kwargs): def compile_isnan(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.isnan(src_column) | F.isnull(src_column) @@ -1018,7 +1026,7 @@ def compile_isnan(t, expr, scope, timecontext, **kwargs): def compile_isinf(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return (src_column == float('inf')) | (src_column == float('-inf')) @@ -1026,7 +1034,7 @@ def compile_isinf(t, expr, scope, timecontext, **kwargs): def compile_uppercase(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.upper(src_column) @@ -1034,7 +1042,7 @@ def compile_uppercase(t, expr, scope, timecontext, **kwargs): def compile_lowercase(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.lower(src_column) @@ -1042,7 +1050,7 @@ def compile_lowercase(t, expr, scope, timecontext, **kwargs): def compile_reverse(t, expr, scope, timecontext, 
**kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.reverse(src_column) @@ -1050,7 +1058,7 @@ def compile_reverse(t, expr, scope, timecontext, **kwargs): def compile_strip(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.trim(src_column) @@ -1058,7 +1066,7 @@ def compile_strip(t, expr, scope, timecontext, **kwargs): def compile_lstrip(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.ltrim(src_column) @@ -1066,7 +1074,7 @@ def compile_lstrip(t, expr, scope, timecontext, **kwargs): def compile_rstrip(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.rtrim(src_column) @@ -1074,14 +1082,14 @@ def compile_rstrip(t, expr, scope, timecontext, **kwargs): def compile_capitalize(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.initcap(src_column) @compiles(ops.Substring) def compile_substring(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) start = t.translate(op.start, scope, timecontext, raw=True) + 1 length = t.translate(op.length, scope, timecontext, raw=True) @@ -1100,7 +1108,7 @@ def compile_substring(t, expr, scope, timecontext, **kwargs): def compile_string_length(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.length(src_column) @@ -1112,8 +1120,8 @@ def compile_str_right(t, expr, scope, timecontext, **kwargs): def str_right(s, nchars): return s[-nchars:] - src_column = t.translate(op.arg, scope, timecontext) - nchars_column = t.translate(op.nchars, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) + nchars_column = t.translate(op.nchars, scope, timecontext, **kwargs) return str_right(src_column, nchars_column) @@ -1125,8 +1133,8 @@ def compile_repeat(t, expr, scope, timecontext, **kwargs): def repeat(s, times): return s * times - src_column = t.translate(op.arg, scope, timecontext) - times_column = t.translate(op.times, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) + times_column = t.translate(op.times, scope, timecontext, **kwargs) return repeat(src_column, times_column) @@ -1138,13 +1146,17 @@ def compile_string_find(t, expr, scope, timecontext, **kwargs): def str_find(s, substr, start, end): return s.find(substr, start, end) - src_column = t.translate(op.arg, scope, timecontext) - substr_column = t.translate(op.substr, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) + substr_column = t.translate(op.substr, scope, timecontext, **kwargs) start_column = ( - t.translate(op.start, scope, timecontext) if op.start else F.lit(None) + t.translate(op.start, scope, timecontext, **kwargs) + if op.start + else F.lit(None) ) end_column = ( - t.translate(op.end, scope, timecontext) if op.end else 
F.lit(None) + t.translate(op.end, scope, timecontext, **kwargs) + if op.end + else F.lit(None) ) return str_find(src_column, substr_column, start_column, end_column) @@ -1153,7 +1165,7 @@ def str_find(s, substr, start, end): def compile_translate(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) from_str = op.from_str.op().value to_str = op.to_str.op().value return F.translate(src_column, from_str, to_str) @@ -1163,7 +1175,7 @@ def compile_translate(t, expr, scope, timecontext, **kwargs): def compile_lpad(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) length = op.length.op().value pad = op.pad.op().value return F.lpad(src_column, length, pad) @@ -1173,7 +1185,7 @@ def compile_lpad(t, expr, scope, timecontext, **kwargs): def compile_rpad(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) length = op.length.op().value pad = op.pad.op().value return F.rpad(src_column, length, pad) @@ -1187,8 +1199,8 @@ def compile_string_join(t, expr, scope, timecontext, **kwargs): def join(sep, arr): return sep.join(arr) - sep_column = t.translate(op.sep, scope, timecontext) - arg = t.translate(op.arg, scope, timecontext) + sep_column = t.translate(op.sep, scope, timecontext, **kwargs) + arg = t.translate(op.arg, scope, timecontext, **kwargs) return join(sep_column, F.array(arg)) @@ -1202,8 +1214,8 @@ def compile_regex_search(t, expr, scope, timecontext, **kwargs): def regex_search(s, pattern): return True if re.search(pattern, s) else False - src_column = t.translate(op.arg, scope, timecontext) - pattern = t.translate(op.pattern, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) + pattern = t.translate(op.pattern, scope, timecontext, **kwargs) return regex_search(src_column, pattern) @@ -1211,7 +1223,7 @@ def regex_search(s, pattern): def compile_regex_extract(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) pattern = op.pattern.op().value idx = op.index.op().value return F.regexp_extract(src_column, pattern, idx) @@ -1221,7 +1233,7 @@ def compile_regex_extract(t, expr, scope, timecontext, **kwargs): def compile_regex_replace(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) pattern = op.pattern.op().value replacement = op.replacement.op().value return F.regexp_replace(src_column, pattern, replacement) @@ -1236,7 +1248,7 @@ def compile_string_replace(t, expr, scope, timecontext, **kwargs): def compile_string_split(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) delimiter = op.delimiter.op().value return F.split(src_column, delimiter) @@ -1245,7 +1257,7 @@ def compile_string_split(t, expr, scope, timecontext, **kwargs): def compile_string_concat(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_columns = t.translate(op.arg, scope, timecontext) + src_columns = t.translate(op.arg, scope, timecontext, **kwargs) return 
F.concat(*src_columns) @@ -1253,7 +1265,7 @@ def compile_string_concat(t, expr, scope, timecontext, **kwargs): def compile_string_ascii(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.ascii(src_column) @@ -1261,7 +1273,7 @@ def compile_string_ascii(t, expr, scope, timecontext, **kwargs): def compile_string_like(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) pattern = op.pattern.op().value return src_column.like(pattern) @@ -1269,7 +1281,13 @@ def compile_string_like(t, expr, scope, timecontext, **kwargs): @compiles(ops.ValueList) def compile_value_list(t, expr, scope, timecontext, **kwargs): op = expr.op() - return [t.translate(col, scope, timecontext) for col in op.values] + # ignore the `raw` argument when compiling a list, otherwise pyspark fails, + # because it doesn't automatically upcast literals into expressions + kwargs.pop("raw", None) + return [ + t.translate(col, scope, timecontext, raw=False, **kwargs) + for col in op.values + ] @compiles(ops.InnerJoin) @@ -1323,9 +1341,9 @@ def compile_join(t, expr, scope, timecontext, *, how): @compiles(ops.Distinct) -def compile_distinct(t, expr, scope, timecontext): +def compile_distinct(t, expr, scope, timecontext, **kwargs): op = expr.op() - return t.translate(op.table, scope, timecontext).distinct() + return t.translate(op.table, scope, timecontext, **kwargs).distinct() def _canonicalize_interval(t, interval, scope, timecontext, **kwargs): @@ -1356,7 +1374,7 @@ def compile_window_op(t, expr, scope, timecontext, **kwargs): grouping_keys = [ key_op.name if isinstance(key_op, ops.TableColumn) - else t.translate(key, scope, timecontext) + else t.translate(key, scope, timecontext, **kwargs) for key, key_op in zip( group_by, map(operator.methodcaller('op'), group_by) ) @@ -1420,7 +1438,7 @@ def compile_window_op(t, expr, scope, timecontext, **kwargs): def _handle_shift_operation(t, expr, scope, timecontext, *, fn, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) default = op.default.op().value if op.default is not None else op.default offset = op.offset.op().value if op.offset is not None else op.offset @@ -1476,21 +1494,21 @@ def compile_ntile(t, expr, scope, timecontext, **kwargs): @compiles(ops.FirstValue) def compile_first_value(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.first(src_column) @compiles(ops.LastValue) def compile_last_value(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.last(src_column) @compiles(ops.NthValue) def compile_nth_value(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) nth = t.translate(op.nth, scope, timecontext, raw=True) return F.nth_value(src_column, nth + 1) @@ -1518,7 +1536,7 @@ def compile_row_number(t, expr, scope, timecontext, **kwargs): @compiles(ops.Date) def compile_date(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = 
t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.to_date(src_column).cast('timestamp') @@ -1526,7 +1544,7 @@ def _extract_component_from_datetime( t, expr, scope, timecontext, *, extract_fn, **kwargs ): op = expr.op() - date_col = t.translate(op.arg, scope, timecontext) + date_col = t.translate(op.arg, scope, timecontext, **kwargs) return extract_fn(date_col).cast('integer') @@ -1618,7 +1636,7 @@ def compile_date_truncate(t, expr, scope, timecontext, **kwargs): f'{op.unit!r} unit is not supported in timestamp truncate' ) - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.date_trunc(unit, src_column) @@ -1636,14 +1654,14 @@ def compile_strftime(t, expr, scope, timecontext, **kwargs): def strftime(timestamps): return timestamps.dt.strftime(format_str) - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return strftime(src_column) @compiles(ops.TimestampFromUNIX) def compile_timestamp_from_unix(t, expr, scope, timecontext, **kwargs): op = expr.op() - unixtime = t.translate(op.arg, scope, timecontext) + unixtime = t.translate(op.arg, scope, timecontext, **kwargs) if not op.unit: return F.to_timestamp(F.from_unixtime(unixtime)) elif op.unit == 's': @@ -1657,7 +1675,7 @@ def compile_timestamp_from_unix(t, expr, scope, timecontext, **kwargs): @compiles(ops.TimestampNow) -def compile_timestamp_now(t, expr, scope, timecontext, **kwargs): +def compile_timestamp_now(t, expr, scope, timecontext, **_): return F.current_timestamp() @@ -1665,7 +1683,7 @@ def compile_timestamp_now(t, expr, scope, timecontext, **kwargs): def compile_string_to_timestamp(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) fmt = op.format_str.op().value if op.timezone is not None and op.timezone.op().value != "UTC": @@ -1685,7 +1703,7 @@ def compile_day_of_week_index(t, expr, scope, timecontext, **kwargs): def day_of_week(s): return s.dt.dayofweek - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return day_of_week(src_column.cast('timestamp')) @@ -1697,24 +1715,24 @@ def compiles_day_of_week_name(t, expr, scope, timecontext, **kwargs): def day_name(s): return s.dt.day_name() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return day_name(src_column.cast('timestamp')) def _get_interval_col( - t, interval_ibis_expr, scope, timecontext, allowed_units=None + t, interval_ibis_expr, scope, timecontext, allowed_units=None, **kwargs ): # if interval expression is a binary op, translate expression into # an interval column and return if isinstance(interval_ibis_expr.op(), ops.IntervalBinary): - return t.translate(interval_ibis_expr, scope, timecontext) + return t.translate(interval_ibis_expr, scope, timecontext, **kwargs) # otherwise, translate expression into a literal op and construct # interval column from literal value and dtype if isinstance(interval_ibis_expr.op(), ops.Literal): op = interval_ibis_expr.op() else: - op = t.translate(interval_ibis_expr, scope, timecontext).op() + op = t.translate(interval_ibis_expr, scope, timecontext, **kwargs).op() dtype = op.dtype if not isinstance(dtype, dtypes.Interval): @@ -1746,8 +1764,10 @@ def _compile_datetime_binop( ): 
op = expr.op() - left = t.translate(op.left, scope, timecontext) - right = _get_interval_col(t, op.right, scope, timecontext, allowed_units) + left = t.translate(op.left, scope, timecontext, **kwargs) + right = _get_interval_col( + t, op.right, scope, timecontext, allowed_units, **kwargs + ) return fn(left, right) @@ -1827,8 +1847,8 @@ def compile_timestamp_diff(t, expr, scope, timecontext, **kwargs): def _compile_interval_binop(t, expr, scope, timecontext, *, fn, **kwargs): op = expr.op() - left = _get_interval_col(t, op.left, scope, timecontext) - right = _get_interval_col(t, op.right, scope, timecontext) + left = _get_interval_col(t, op.left, scope, timecontext, **kwargs) + right = _get_interval_col(t, op.right, scope, timecontext, **kwargs) return fn(left, right) @@ -1861,7 +1881,7 @@ def compile_interval_from_integer(t, expr, scope, timecontext, **kwargs): def compile_array_column(t, expr, scope, timecontext, **kwargs): op = expr.op() - cols = t.translate(op.cols, scope, timecontext) + cols = t.translate(op.cols, scope, timecontext, **kwargs) return F.array(cols) @@ -1869,7 +1889,7 @@ def compile_array_column(t, expr, scope, timecontext, **kwargs): def compile_array_length(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.size(src_column) @@ -1884,7 +1904,7 @@ def compile_array_slice(t, expr, scope, timecontext, **kwargs): def array_slice(array): return array[start:stop] - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return array_slice(src_column) @@ -1892,7 +1912,7 @@ def array_slice(array): def compile_array_index(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) index = op.index.op().value + 1 return F.element_at(src_column, index) @@ -1901,8 +1921,8 @@ def compile_array_index(t, expr, scope, timecontext, **kwargs): def compile_array_concat(t, expr, scope, timecontext, **kwargs): op = expr.op() - left = t.translate(op.left, scope, timecontext) - right = t.translate(op.right, scope, timecontext) + left = t.translate(op.left, scope, timecontext, **kwargs) + right = t.translate(op.right, scope, timecontext, **kwargs) return F.concat(left, right) @@ -1910,7 +1930,7 @@ def compile_array_concat(t, expr, scope, timecontext, **kwargs): def compile_array_repeat(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) times = op.times.op().value return F.flatten(F.array_repeat(src_column, times)) @@ -1919,7 +1939,7 @@ def compile_array_repeat(t, expr, scope, timecontext, **kwargs): def compile_array_collect(t, expr, scope, timecontext, **kwargs): op = expr.op() - src_column = t.translate(op.arg, scope, timecontext) + src_column = t.translate(op.arg, scope, timecontext, **kwargs) return F.collect_list(src_column) @@ -1934,37 +1954,37 @@ def compile_null_literal(t, expr, scope, timecontext, **kwargs): @compiles(ops.IfNull) def compile_if_null(t, expr, scope, timecontext, **kwargs): op = expr.op() - col = t.translate(op.arg, scope, timecontext) - ifnull_col = t.translate(op.ifnull_expr, scope, timecontext) + col = t.translate(op.arg, scope, timecontext, **kwargs) + ifnull_col = t.translate(op.ifnull_expr, scope, timecontext, **kwargs) return 
F.when(col.isNull() | F.isnan(col), ifnull_col).otherwise(col) @compiles(ops.NullIf) def compile_null_if(t, expr, scope, timecontext, **kwargs): op = expr.op() - col = t.translate(op.arg, scope, timecontext) - nullif_col = t.translate(op.null_if_expr, scope, timecontext) + col = t.translate(op.arg, scope, timecontext, **kwargs) + nullif_col = t.translate(op.null_if_expr, scope, timecontext, **kwargs) return F.when(col == nullif_col, F.lit(None)).otherwise(col) @compiles(ops.IsNull) def compile_is_null(t, expr, scope, timecontext, **kwargs): op = expr.op() - col = t.translate(op.arg, scope, timecontext) + col = t.translate(op.arg, scope, timecontext, **kwargs) return F.isnull(col) | F.isnan(col) @compiles(ops.NotNull) def compile_not_null(t, expr, scope, timecontext, **kwargs): op = expr.op() - col = t.translate(op.arg, scope, timecontext) + col = t.translate(op.arg, scope, timecontext, **kwargs) return ~F.isnull(col) & ~F.isnan(col) @compiles(ops.DropNa) def compile_dropna_table(t, expr, scope, timecontext, **kwargs): op = expr.op() - table = t.translate(op.table, scope, timecontext) + table = t.translate(op.table, scope, timecontext, **kwargs) subset = [col.get_name() for col in op.subset] if op.subset else None return table.dropna(how=op.how, subset=subset) @@ -1972,7 +1992,7 @@ def compile_dropna_table(t, expr, scope, timecontext, **kwargs): @compiles(ops.FillNa) def compile_fillna_table(t, expr, scope, timecontext, **kwargs): op = expr.op() - table = t.translate(op.table, scope, timecontext) + table = t.translate(op.table, scope, timecontext, **kwargs) raw_replacements = op.replacements replacements = ( dict(raw_replacements) @@ -1991,7 +2011,9 @@ def compile_elementwise_udf(t, expr, scope, timecontext, **kwargs): spark_output_type = spark_dtype(op.return_type) func = op.func spark_udf = pandas_udf(func, spark_output_type, PandasUDFType.SCALAR) - func_args = (t.translate(arg, scope, timecontext) for arg in op.func_args) + func_args = ( + t.translate(arg, scope, timecontext, **kwargs) for arg in op.func_args + ) return spark_udf(*func_args) @@ -2003,13 +2025,17 @@ def compile_reduction_udf(t, expr, scope, timecontext, context=None, **kwargs): spark_udf = pandas_udf( op.func, spark_output_type, PandasUDFType.GROUPED_AGG ) - func_args = (t.translate(arg, scope, timecontext) for arg in op.func_args) + func_args = ( + t.translate(arg, scope, timecontext, **kwargs) for arg in op.func_args + ) col = spark_udf(*func_args) if context: return col else: - src_table = t.translate(op.func_args[0].op().table, scope, timecontext) + src_table = t.translate( + op.func_args[0].op().table, scope, timecontext, **kwargs + ) return src_table.agg(col) @@ -2036,26 +2062,24 @@ def compile_searched_case(t, expr, scope, timecontext, **kwargs): @compiles(ops.View) -def compile_view(t, expr, scope, timecontext, **kwargs): +def compile_view(t, expr, scope, timecontext, session, **kwargs): op = expr.op() name = op.name child = op.child - backend = child._find_backend() - tables = backend._session.catalog.listTables() + tables = session.catalog.listTables() if any(name == table.name and not table.isTemporary for table in tables): raise ValueError( f"table or non-temporary view `{name}` already exists" ) - result = t.translate(child, scope, timecontext, **kwargs) + result = t.translate(child, scope, timecontext, session=session, **kwargs) result.createOrReplaceTempView(name) return result @compiles(ops.SQLStringView) -def compile_sql_view(t, expr, scope, timecontext, **kwargs): +def compile_sql_view(t, expr, scope, 
timecontext, session, **kwargs): op = expr.op() - backend = op.child._find_backend() - result = backend._session.sql(op.query) + result = session.sql(op.query) result.createOrReplaceTempView(op.name) return result @@ -2139,3 +2163,16 @@ def compile_where(t, expr, scope, timecontext, **kwargs): @compiles(ops.RandomScalar) def compile_random(*args, **kwargs): return F.rand() + + +@compiles(PandasInMemoryTable) +def compile_in_memory_table(t, expr, scope, timecontext, session, **kwargs): + op = expr.op() + fields = [ + pt.StructField(name, ibis_dtype_to_spark_dtype(dtype), dtype.nullable) + for name, dtype in op.schema.items() + ] + return session.createDataFrame( + data=op.data._df, + schema=pt.StructType(fields), + ) diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 04b88128dfe6..8e0dfd932781 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -558,3 +558,78 @@ def test_invalid_connect(): def test_deprecated_path_argument(backend, tmp_path): with pytest.warns(UserWarning, match="The `path` argument is deprecated"): getattr(ibis, backend.name()).connect(path=str(tmp_path / "test.db")) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param( + ibis.memtable([(1, 2.0, "3")], columns=list("abc")), + pd.DataFrame([(1, 2.0, "3")], columns=list("abc")), + id="simple", + ), + param( + ibis.memtable([(1, 2.0, "3")]), + pd.DataFrame([(1, 2.0, "3")], columns=["col0", "col1", "col2"]), + id="simple_auto_named", + ), + param( + ibis.memtable( + [(1, 2.0, "3")], + schema=ibis.schema(dict(a="int8", b="float32", c="string")), + ), + pd.DataFrame([(1, 2.0, "3")], columns=list("abc")).astype( + {"a": "int8", "b": "float32"} + ), + id="simple_schema", + ), + param( + ibis.memtable( + pd.DataFrame({"a": [1], "b": [2.0], "c": ["3"]}).astype( + {"a": "int8", "b": "float32"} + ) + ), + pd.DataFrame([(1, 2.0, "3")], columns=list("abc")).astype( + {"a": "int8", "b": "float32"} + ), + id="dataframe", + ), + ], +) +@pytest.mark.notyet( + ["clickhouse"], + reason="ClickHouse doesn't support a VALUES construct", +) +@pytest.mark.notyet( + ["mysql", "sqlite"], + reason="SQLAlchemy generates incorrect code for `VALUES` projections.", + raises=(sa.exc.ProgrammingError, sa.exc.OperationalError), +) +@pytest.mark.notimpl(["dask", "datafusion", "pandas"]) +def test_in_memory_table(backend, con, expr, expected): + result = con.execute(expr) + backend.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "t", + [ + param( + ibis.memtable([("a", 1.0)], columns=["a", "b"]), + id="python", + ), + param( + ibis.memtable(pd.DataFrame([("a", 1.0)], columns=["a", "b"])), + id="pandas", + ), + ], +) +@pytest.mark.notimpl(["clickhouse", "dask", "datafusion", "pandas"]) +def test_create_from_in_memory_table(con, t): + tmp_name = guid() + con.create_table(tmp_name, t) + try: + assert tmp_name in con.list_tables() + finally: + con.drop_table(tmp_name) + assert tmp_name not in con.list_tables() diff --git a/ibis/common/grounds.py b/ibis/common/grounds.py index 3c2e3aba59b1..661b09d34af4 100644 --- a/ibis/common/grounds.py +++ b/ibis/common/grounds.py @@ -66,6 +66,17 @@ def validate(self, this, arg): return self.validator(arg, this=this) +class Immutable(Hashable): + + __slots__ = () + + def __setattr__(self, name: str, _: Any) -> None: + raise TypeError( + f"Attribute {name!r} cannot be assigned to immutable instance of " + f"type {type(self)}" + ) + + class AnnotableMeta(BaseMeta): """ Metaclass to turn class annotations into a 
validatable function signature. @@ -132,7 +143,7 @@ def __new__(metacls, clsname, bases, dct): return super().__new__(metacls, clsname, bases, attribs) -class Annotable(Base, Hashable, metaclass=AnnotableMeta): +class Annotable(Base, Immutable, metaclass=AnnotableMeta): """Base class for objects with custom validation rules.""" __slots__ = ("args", "_hash") @@ -182,12 +193,6 @@ def __hash__(self): def __eq__(self, other): return super().__eq__(other) - def __setattr__(self, name: str, _: Any) -> None: - raise TypeError( - f"Attribute {name!r} cannot be assigned to immutable instance of " - f"type {type(self)}" - ) - def __repr__(self) -> str: args = ", ".join( f"{name}={value!r}" diff --git a/ibis/expr/api.py b/ibis/expr/api.py index d845063ff15f..f84830b73477 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -188,6 +188,7 @@ 'least', 'literal', 'map', + 'memtable', 'NA', 'negate', 'now', @@ -335,10 +336,10 @@ def schema( def table( - schema: SupportsSchema, + schema: SupportsSchema | None = None, name: str | None = None, ) -> ir.Table: - """Create an unbound table for building expressions without data. + """Create a table literal or an abstract table without data. Parameters ---------- @@ -350,10 +351,106 @@ def table( Returns ------- Table - An unbound table expression + A table expression + + Examples + -------- + Create a table with no data backing it + + >>> t = ibis.table(schema=dict(a="int", b="string")) + >>> t + UnboundTable: unbound_table_0 + a int64 + b string + """ + if schema is not None: + schema = sch.schema(schema) + return ops.UnboundTable(schema=schema, name=name).to_expr() + + +@functools.singledispatch +def memtable( + data, + *, + columns: Iterable[str] | None = None, + schema: SupportsSchema | None = None, + name: str | None = None, +) -> Table: + """Construct an ibis table expression from in-memory data. + + Parameters + ---------- + data + Any data accepted by the `pandas.DataFrame` constructor. + + The use of `DataFrame` underneath should **not** be relied upon and is + free to change across non-major releases. + columns + Optional [`Iterable`][typing.Iterable] of [`str`][str] column names. + schema + Optional [`Schema`][ibis.expr.schema.Schema]. The function uses `data` + to infer a schema if not passed. + name + Optional name of the table. + + Returns + ------- + Table + A table expression backed by in-memory data.
+ + Examples + -------- + >>> import ibis + >>> t = ibis.memtable([{"a": 1}, {"a": 2}]) + >>> t + + >>> t = ibis.memtable([{"a": 1, "b": "foo"}, {"a": 2, "b": "baz"}]) + >>> t + PandasInMemoryTable + data: + ((1, 'foo'), (2, 'baz')) + schema: + a int8 + b string + + Create a table literal without column names embedded in the data and pass + `columns` + + >>> t = ibis.memtable([(1, "foo"), (2, "baz")], columns=["a", "b"]) + >>> t + PandasInMemoryTable + data: + ((1, 'foo'), (2, 'baz')) + schema: + a int8 + b string """ - node = ops.UnboundTable(sch.schema(schema), name=name) - return node.to_expr() + if columns is not None and schema is not None: + raise NotImplementedError( + "passing `columns` and `schema` is ambiguous; " + "pass one or the other but not both" + ) + df = pd.DataFrame(data, columns=columns) + if isinstance(data, (list, tuple)) and columns is None: + df = df.rename(columns={col: f"col{col:d}" for col in df.columns if not isinstance(col, str)}) + return memtable(df, name=name, schema=schema) + + +@memtable.register(pd.DataFrame) +def _memtable_from_dataframe( + df: pd.DataFrame, + *, + name: str | None = None, + schema: SupportsSchema | None = None, +) -> Table: + from ibis.backends.pandas.client import DataFrameProxy, PandasInMemoryTable + + op = PandasInMemoryTable( + name=name, + schema=sch.infer(df) if schema is None else schema, + data=DataFrameProxy(df), + ) + return op.to_expr() def desc(expr: ir.Column | str) -> ir.SortExpr | ops.DeferredSortKey: diff --git a/ibis/expr/format.py b/ibis/expr/format.py index 857224d0b5a4..14cf4f1cd0f9 100644 --- a/ibis/expr/format.py +++ b/ibis/expr/format.py @@ -6,6 +6,8 @@ import types from typing import Any, Callable, Deque, Iterable, Mapping, Tuple +import rich.pretty + import ibis import ibis.expr.datatypes as dt import ibis.expr.operations as ops @@ -430,6 +432,20 @@ def _fmt_table_op_limit(op: ops.Limit, *, aliases: Aliases, **_: Any) -> str: return f"{op.__class__.__name__}[{', '.join(params)}]" +@fmt_table_op.register +def _fmt_table_op_in_memory_table(op: ops.InMemoryTable, **_: Any) -> str: + # arbitrary limit, but some value is needed to avoid a huge repr + max_length = 10 + pretty_data = rich.pretty.pretty_repr(op.data, max_length=max_length) + return "\n".join( + [ + op.__class__.__name__, + util.indent("data:", spaces=2), + util.indent(pretty_data, spaces=4), + ] + ) + + @functools.singledispatch def fmt_selection_column(value_expr: ir.Value, **_: Any) -> str: assert False, ( @@ -653,6 +669,23 @@ def _fmt_value_table_node( return f"{aliases[op.table.op()]}" +_JOIN_SYMS = { + ops.InnerJoin: "⋈", + ops.LeftJoin: "⟕", + ops.RightJoin: "⟖", + ops.OuterJoin: "⟗", + ops.CrossJoin: "×", +} + + +@fmt_value.register +def _fmt_value_join(op: ops.Join, *, aliases: Aliases, **_: Any) -> str: + """Format a join as value.""" + left = aliases[op.left.op()] + right = aliases[op.right.op()] + return f"{left} {_JOIN_SYMS[type(op)]} {right}" + + @fmt_value.register def _fmt_value_string_sql_like( op: ops.StringSQLLike, *, aliases: Aliases diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index da36eee06a53..dc6d3e8fbf63 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -2,6 +2,7 @@ import collections import itertools +from abc import abstractmethod from functools import cached_property from public import public @@ -96,6 +97,17 @@ def blocks(self): return True +@public +class InMemoryTable(TableNode, sch.HasSchema): + name = rlz.optional(rlz.instance_of(str)) + schema = rlz.instance_of(sch.Schema)
+ + @property + @abstractmethod + def data(self): + """Return the data of an in-memory table.""" + + def _make_distinct_join_predicates(left, right, predicates): import ibis.expr.analysis as L diff --git a/ibis/expr/rules.py b/ibis/expr/rules.py index f697e8904591..51402be4870e 100644 --- a/ibis/expr/rules.py +++ b/ibis/expr/rules.py @@ -351,25 +351,29 @@ def table(arg, *, schema=None, **kwargs): Parameters ---------- - schema : Union[sch.Schema, List[Tuple[str, dt.DataType], None] + schema A validator for the table's columns. Only column subset validators are currently supported. Accepts any arguments that `sch.schema` accepts. See the example for usage. - arg : The validatable argument. - - Examples - -------- - The following op will accept an argument named ``'table'``. Note that the - ``schema`` argument specifies rules for columns that are required to be in - the table: ``time``, ``group`` and ``value1``. These must match the types - specified in the column rules. Column ``value2`` is optional, but if - present it must be of the specified type. The table may have extra columns - not specified in the schema. + arg + An argument + + The following op will accept an argument named `'table'`. Note that the + `schema` argument specifies rules for columns that are required to be in + the table: `time`, `group` and `value1`. These must match the types + specified in the column rules. Column `value2` is optional, but if present + it must be of the specified type. The table may have extra columns not + specified in the schema. """ + import ibis + if not isinstance(arg, ir.Table): - raise com.IbisTypeError( - f'Argument is not a table; got type {type(arg).__name__}' - ) + try: + return ibis.memtable(arg, schema=schema) + except Exception as e: + raise com.IbisTypeError( + f'Argument is not a table; got type {type(arg).__name__}' + ) from e if schema is not None: if arg.schema() >= sch.schema(schema): diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index c9414ca7e786..121699381096 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -1110,7 +1110,8 @@ def join( if isinstance(predicates, ir.Expr): predicates = an.flatten_predicate(predicates) - expr = klass(left, right, predicates).to_expr() + op = klass(left, right, predicates) + expr = op.to_expr() # semi/anti join only give access to the left table's fields, so # there's never overlap @@ -1119,8 +1120,8 @@ def join( return ops.relations._dedup_join_columns( expr, - left=left, - right=right, + left=op.left, + right=op.right, suffixes=suffixes, ) diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index d4a29dde3bfc..bca7b31a899c 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -1535,3 +1535,12 @@ def test_exprs_to_select(): with pytest.warns(FutureWarning, match="Passing `exprs`"): result = t.select(exprs=exprs) assert result.equals(t.select(len=t.a.length())) + + +def test_python_table_ambiguous(): + with pytest.raises(NotImplementedError): + ibis.memtable( + [(1,)], + schema=ibis.schema(dict(a="int8")), + columns=["a"], + )
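Usage sketch for the new API, based on the docstring and tests above (illustrative only; the DuckDB connection is an assumption for the demo — any backend not excluded in the test matrix should behave the same):

    import ibis
    import pandas as pd

    # column names come from the dict keys; dtypes are inferred from the values
    t = ibis.memtable([{"a": 1, "b": "foo"}, {"a": 2, "b": "baz"}])

    # names supplied via `columns` when the rows are plain tuples
    t2 = ibis.memtable([(1, "foo"), (2, "baz")], columns=["a", "b"])

    # or pass a schema instead (but not both -- that raises NotImplementedError)
    t3 = ibis.memtable([(1, "foo")], schema=ibis.schema(dict(a="int8", b="string")))

    # an existing DataFrame works as well
    t4 = ibis.memtable(pd.DataFrame({"a": [1, 2], "b": ["foo", "baz"]}))

    con = ibis.duckdb.connect()  # assumed backend for the demo
    print(con.execute(t2))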
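For the string-generating SQL compilers, an in-memory table is inlined as a `VALUES` list by the `InMemoryTable` branch added to `_format_table` in `ibis/backends/base/sql/compiler/query_builder.py`. A minimal sketch of that rendering, with plain tuples standing in for the proxied DataFrame rows:

    # one parenthesized tuple of repr'd values per row, joined by commas
    rows = [(1, "foo"), (2, "baz")]
    body = ", ".join(f"({', '.join(map(repr, row))})" for row in rows)
    print(f"(VALUES {body})")  # -> (VALUES (1, 'foo'), (2, 'baz'))

Python's repr happens to produce single-quoted strings, which is what makes this simple approach line up with SQL string literals for the basic cases exercised in the tests.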
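The `DataFrameProxy` wrapper added in `ibis/backends/pandas/client.py` exists because operation arguments must be hashable and immutable while `pandas.DataFrame` is neither; hashing `(type(df), id(df))` gives cheap identity-based semantics. A self-contained sketch of the same idea (the class name here is hypothetical, not part of the patch):

    import pandas as pd

    class FrameProxy:
        """Hashable, attribute-forwarding wrapper around a DataFrame."""

        __slots__ = ("_df", "_hash")

        def __init__(self, df: pd.DataFrame) -> None:
            # object.__setattr__ sidesteps immutability guards, as in the patch
            object.__setattr__(self, "_df", df)
            # identity-based hash: DataFrames themselves are unhashable
            object.__setattr__(self, "_hash", hash((type(df), id(df))))

        def __hash__(self) -> int:
            return self._hash

        def __getattr__(self, name):
            # delegate everything else (e.g. .itertuples) to the wrapped frame
            return getattr(self._df, name)

    proxy = FrameProxy(pd.DataFrame({"a": [1]}))
    assert hash(proxy) == hash(proxy)  # stable for the frame's lifetime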