diff --git a/ibis/backends/base/sqlglot/__init__.py b/ibis/backends/base/sqlglot/__init__.py index f8cc80e5c76f..b9055e3aa292 100644 --- a/ibis/backends/base/sqlglot/__init__.py +++ b/ibis/backends/base/sqlglot/__init__.py @@ -13,7 +13,9 @@ from ibis.backends.base.sqlglot.compiler import STAR if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Iterable, Iterator, Mapping + + import pyarrow as pa import ibis.expr.datatypes as dt import ibis.expr.types as ir @@ -60,7 +62,7 @@ def table( ).to_expr() def _to_sqlglot( - self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any + self, expr: ir.Expr, *, limit: str | None = None, params=None, **_: Any ): """Compile an Ibis expression to a sqlglot object.""" table_expr = expr.as_table() @@ -206,13 +208,17 @@ def _clean_up_cached_table(self, op): self.drop_table(op.name) def execute( - self, expr: ir.Expr, limit: str | None = "default", **kwargs: Any + self, + expr: ir.Expr, + params: Mapping | None = None, + limit: str | None = "default", + **kwargs: Any, ) -> Any: """Execute an expression.""" self._run_pre_execute_hooks(expr) table = expr.as_table() - sql = self.compile(table, limit=limit, **kwargs) + sql = self.compile(table, params=params, limit=limit, **kwargs) schema = table.schema() @@ -236,3 +242,63 @@ def drop_table( ) with self._safe_raw_sql(drop_stmt): pass + + def _cursor_batches( + self, + expr: ir.Expr, + params: Mapping[ir.Scalar, Any] | None = None, + limit: int | str | None = None, + chunk_size: int = 1 << 20, + ) -> Iterable[list]: + self._run_pre_execute_hooks(expr) + + with self._safe_raw_sql( + self.compile(expr, limit=limit, params=params) + ) as cursor: + while batch := cursor.fetchmany(chunk_size): + yield batch + + def to_pyarrow_batches( + self, + expr: ir.Expr, + *, + params: Mapping[ir.Scalar, Any] | None = None, + limit: int | str | None = None, + chunk_size: int = 1_000_000, + **_: Any, + ) -> pa.ipc.RecordBatchReader: + """Execute expression and return an iterator of pyarrow record batches. + + This method is eager and will execute the associated expression + immediately. + + Parameters + ---------- + expr + Ibis expression to export to pyarrow + limit + An integer to effect a specific row limit. A value of `None` means + "no limit". The default is in `ibis/config.py`. + params + Mapping of scalar parameter expressions to value. + chunk_size + Maximum number of rows in each returned record batch. + + Returns + ------- + RecordBatchReader + Collection of pyarrow `RecordBatch`s. 
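Editor's note on the method documented above: the returned reader is driven lazily by `_cursor_batches`, so rows are fetched from the DBAPI cursor `chunk_size` at a time rather than materialized all at once. A minimal consumption sketch — the DuckDB connection and the `functional_alltypes` table are assumptions for illustration, not part of this diff:

```python
import ibis

# any backend built on SQLGlotBackend exposes the same method
con = ibis.duckdb.connect()
expr = con.tables.functional_alltypes.select("id", "string_col")

reader = con.to_pyarrow_batches(expr, chunk_size=10_000)
for batch in reader:  # each element is a pyarrow.RecordBatch of at most 10_000 rows
    print(batch.num_rows)
```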
+ """ + pa = self._import_pyarrow() + + schema = expr.as_table().schema() + array_type = schema.as_struct().to_pyarrow() + arrays = ( + pa.array(map(tuple, batch), type=array_type) + for batch in self._cursor_batches( + expr, params=params, limit=limit, chunk_size=chunk_size + ) + ) + batches = map(pa.RecordBatch.from_struct_array, arrays) + + return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), batches) diff --git a/ibis/backends/base/sqlglot/compiler.py b/ibis/backends/base/sqlglot/compiler.py index 687d21f83ed3..b7e8dbc9a83c 100644 --- a/ibis/backends/base/sqlglot/compiler.py +++ b/ibis/backends/base/sqlglot/compiler.py @@ -49,6 +49,16 @@ def __getitem__(self, key: str) -> partial: return getattr(self, key) +class VarGen: + __slots__ = () + + def __getattr__(self, name: str) -> sge.Var: + return sge.Var(this=name) + + def __getitem__(self, key: str) -> sge.Var: + return sge.Var(this=key) + + class FuncGen: __slots__ = ("namespace",) @@ -110,7 +120,7 @@ def parenthesize(op, arg): @public class SQLGlotCompiler(abc.ABC): - __slots__ = "agg", "f" + __slots__ = "agg", "f", "v" rewrites: tuple = ( empty_in_values_right_side, @@ -138,6 +148,7 @@ class SQLGlotCompiler(abc.ABC): def __init__(self) -> None: self.agg = AggGen(aggfunc=self._aggregate) self.f = FuncGen() + self.v = VarGen() @property @abc.abstractmethod @@ -258,14 +269,56 @@ def visit_Alias(self, op, *, arg, name): return arg @visit_node.register(ops.Literal) - def visit_Literal(self, op, *, value, dtype, **kw): + def visit_Literal(self, op, *, value, dtype): + """Compile a literal value. + + This is the default implementation for compiling literal values. + + Most backends should not need to override this method unless they want + to handle NULL literals as well as every other type of non-null literal + including integers, floating point numbers, decimals, strings, etc. + + The logic here is: + + 1. If the value is None and the type is nullable, return NULL + 1. If the value is None and the type is not nullable, raise an error + 1. Call `visit_NonNullLiteral` method. + 1. If the previous returns `None`, call `visit_DefaultLiteral` method + else return the result of the previous step. + """ if value is None: if dtype.nullable: return NULL if dtype.is_null() else self.cast(NULL, dtype) raise com.UnsupportedOperationError( f"Unsupported NULL for non-nullable type: {dtype!r}" ) - elif dtype.is_integer(): + else: + result = self.visit_NonNullLiteral(op, value=value, dtype=dtype) + if result is None: + return self.visit_DefaultLiteral(op, value=value, dtype=dtype) + return result + + def visit_NonNullLiteral(self, op, *, value, dtype): + """Compile a non-null literal differently than the default implementation. + + Most backends should implement this, but only when they need to handle + some non-null literal differently than the default implementation + (`visit_DefaultLiteral`). + + Return `None` from an override of this method to fall back to + `visit_DefaultLiteral`. + """ + return self.visit_DefaultLiteral(op, value=value, dtype=dtype) + + def visit_DefaultLiteral(self, op, *, value, dtype): + """Compile a literal with a non-null value. + + This is the default implementation for compiling non-null literals. + + Most backends should not need to override this method unless they want + to handle compiling every kind of non-null literal value. 
+ """ + if dtype.is_integer(): return sge.convert(value) elif dtype.is_floating(): if math.isnan(value): @@ -274,7 +327,7 @@ def visit_Literal(self, op, *, value, dtype, **kw): return self.POS_INF if value < 0 else self.NEG_INF return sge.convert(value) elif dtype.is_decimal(): - return self.cast(sge.convert(str(value)), dtype) + return self.cast(str(value), dtype) elif dtype.is_interval(): return sge.Interval( this=sge.convert(str(value)), unit=dtype.resolution.upper() @@ -304,7 +357,7 @@ def visit_Literal(self, op, *, value, dtype, **kw): keys = self.f.array( *( self.visit_Literal( - ops.Literal(k, key_type), value=k, dtype=key_type, **kw + ops.Literal(k, key_type), value=k, dtype=key_type ) for k in value.keys() ) @@ -314,7 +367,7 @@ def visit_Literal(self, op, *, value, dtype, **kw): values = self.f.array( *( self.visit_Literal( - ops.Literal(v, value_type), value=v, dtype=value_type, **kw + ops.Literal(v, value_type), value=v, dtype=value_type ) for v in value.values() ) @@ -323,15 +376,14 @@ def visit_Literal(self, op, *, value, dtype, **kw): return self.f.map(keys, values) elif dtype.is_struct(): items = [ - sge.Slice( - this=sge.convert(k), - expression=self.visit_Literal( - ops.Literal(v, field_dtype), value=v, dtype=field_dtype, **kw - ), - ) + self.visit_Literal( + ops.Literal(v, field_dtype), value=v, dtype=field_dtype + ).as_(k, quoted=self.quoted) for field_dtype, (k, v) in zip(dtype.types, value.items()) ] return sge.Struct.from_arg_list(items) + elif dtype.is_uuid(): + return self.cast(str(value), dtype) else: raise NotImplementedError(f"Unsupported type: {dtype!r}") @@ -403,14 +455,6 @@ def visit_Not(self, op, *, arg): ### Timey McTimeFace - @visit_node.register(ops.Date) - def visit_Date(self, op, *, arg): - return sge.Date(this=arg) - - @visit_node.register(ops.DateFromYMD) - def visit_DateFromYMD(self, op, *, year, month, day): - return sge.DateFromParts(year=year, month=month, day=day) - @visit_node.register(ops.Time) def visit_Time(self, op, *, arg): return self.cast(arg, to=dt.time) @@ -429,39 +473,39 @@ def visit_ExtractEpochSeconds(self, op, *, arg): @visit_node.register(ops.ExtractYear) def visit_ExtractYear(self, op, *, arg): - return self.f.extract("year", arg) + return self.f.extract(self.v.year, arg) @visit_node.register(ops.ExtractMonth) def visit_ExtractMonth(self, op, *, arg): - return self.f.extract("month", arg) + return self.f.extract(self.v.month, arg) @visit_node.register(ops.ExtractDay) def visit_ExtractDay(self, op, *, arg): - return self.f.extract("day", arg) + return self.f.extract(self.v.day, arg) @visit_node.register(ops.ExtractDayOfYear) def visit_ExtractDayOfYear(self, op, *, arg): - return self.f.extract("dayofyear", arg) + return self.f.extract(self.v.dayofyear, arg) @visit_node.register(ops.ExtractQuarter) def visit_ExtractQuarter(self, op, *, arg): - return self.f.extract("quarter", arg) + return self.f.extract(self.v.quarter, arg) @visit_node.register(ops.ExtractWeekOfYear) def visit_ExtractWeekOfYear(self, op, *, arg): - return self.f.extract("week", arg) + return self.f.extract(self.v.week, arg) @visit_node.register(ops.ExtractHour) def visit_ExtractHour(self, op, *, arg): - return self.f.extract("hour", arg) + return self.f.extract(self.v.hour, arg) @visit_node.register(ops.ExtractMinute) def visit_ExtractMinute(self, op, *, arg): - return self.f.extract("minute", arg) + return self.f.extract(self.v.minute, arg) @visit_node.register(ops.ExtractSecond) def visit_ExtractSecond(self, op, *, arg): - return self.f.extract("second", arg) + 
return self.f.extract(self.v.second, arg) @visit_node.register(ops.TimestampTruncate) @visit_node.register(ops.DateTruncate) @@ -479,11 +523,10 @@ def visit_TimestampTruncate(self, op, *, arg, unit): "us": "us", } - unit = unit.short - if (duckunit := unit_mapping.get(unit)) is None: + if (unit := unit_mapping.get(unit.short)) is None: raise com.UnsupportedOperationError(f"Unsupported truncate unit {unit}") - return self.f.date_trunc(duckunit, arg) + return self.f.date_trunc(unit, arg) @visit_node.register(ops.DayOfWeekIndex) def visit_DayOfWeekIndex(self, op, *, arg): @@ -521,7 +564,6 @@ def visit_LStrip(self, op, *, arg): def visit_Substring(self, op, *, arg, start, length): if_pos = sge.Substring(this=arg, start=start + 1, length=length) if_neg = sge.Substring(this=arg, start=start, length=length) - return self.if_(start >= 0, if_pos, if_neg) @visit_node.register(ops.StringFind) @@ -538,18 +580,10 @@ def visit_StringFind(self, op, *, arg, substr, start, end): return self.f.strpos(arg, substr) - @visit_node.register(ops.RegexSearch) - def visit_RegexSearch(self, op, *, arg, pattern): - return sge.RegexpLike(this=arg, expression=pattern, flag=sge.convert("s")) - @visit_node.register(ops.RegexReplace) def visit_RegexReplace(self, op, *, arg, pattern, replacement): return self.f.regexp_replace(arg, pattern, replacement, "g") - @visit_node.register(ops.RegexExtract) - def visit_RegexExtract(self, op, *, arg, pattern, index): - return self.f.regexp_extract(arg, pattern, index, dialect=self.dialect) - @visit_node.register(ops.StringConcat) def visit_StringConcat(self, op, *, arg): return self.f.concat(*arg) @@ -566,10 +600,6 @@ def visit_StringSQLLike(self, op, *, arg, pattern, escape): def visit_StringSQLILike(self, op, *, arg, pattern, escape): return arg.ilike(pattern) - @visit_node.register(ops.StringToTimestamp) - def visit_StringToTimestamp(self, op, *, arg, format_str): - return sge.StrToTime(this=arg, format=format_str) - ### NULL PLAYER CHARACTER @visit_node.register(ops.IsNull) def visit_IsNull(self, op, *, arg): @@ -583,12 +613,6 @@ def visit_NotNull(self, op, *, arg): def visit_InValues(self, op, *, value, options): return value.isin(*options) - ### Definitely Not Tensors - - @visit_node.register(ops.ArrayStringJoin) - def visit_ArrayStringJoin(self, op, *, sep, arg): - return self.f.array_to_string(arg, sep) - ### Counting @visit_node.register(ops.CountDistinct) @@ -667,15 +691,12 @@ def visit_Array(self, op, *, exprs): @visit_node.register(ops.StructColumn) def visit_StructColumn(self, op, *, names, values): return sge.Struct.from_arg_list( - [ - sge.Slice(this=sge.convert(name), expression=value) - for name, value in zip(names, values) - ] + [value.as_(name, quoted=self.quoted) for name, value in zip(names, values)] ) @visit_node.register(ops.StructField) def visit_StructField(self, op, *, arg, field): - return arg[sge.convert(field)] + return sge.Dot(this=arg, expression=sg.to_identifier(field, quoted=self.quoted)) @visit_node.register(ops.IdenticalTo) def visit_IdenticalTo(self, op, *, left, right): @@ -695,10 +716,6 @@ def visit_Coalesce(self, op, *, arg): ### Ordering and window functions - @visit_node.register(ops.RowNumber) - def visit_RowNumber(self, op): - return sge.RowNumber() - @visit_node.register(ops.SortKey) def visit_SortKey(self, op, *, expr, ascending: bool): return sge.Ordered(this=expr, desc=not ascending) @@ -726,7 +743,7 @@ def visit_Window(self, op, *, how, func, start, end, group_by, order_by): end_side = end.get("side", "FOLLOWING") spec = sge.WindowSpec( 
- kind=op.how.upper(), + kind=how.upper(), start=start_value, start_side=start_side, end=end_value, @@ -735,8 +752,14 @@ def visit_Window(self, op, *, how, func, start, end, group_by, order_by): ) order = sge.Order(expressions=order_by) if order_by else None + spec = self._minimize_spec(op.start, op.end, spec) + return sge.Window(this=func, partition_by=group_by, order=order, spec=spec) + @staticmethod + def _minimize_spec(start, end, spec): + return spec + @visit_node.register(ops.Lag) @visit_node.register(ops.Lead) def visit_LagLead(self, op, *, arg, offset, default): @@ -790,10 +813,6 @@ def visit_TimestampBucket(self, op, *, arg, interval, offset): def visit_ArrayConcat(self, op, *, arg): return sge.ArrayConcat(this=arg[0], expressions=list(arg[1:])) - @visit_node.register(ops.ArrayContains) - def visit_ArrayContains(self, op, *, arg, other): - return sge.ArrayContains(this=arg, expression=other) - ## relations def _dedup_name( @@ -1094,22 +1113,18 @@ def visit_SQLStringView(self, op, *, query: str, name: str, child): def visit_SQLQueryResult(self, op, *, query, schema, source): return sg.parse_one(query, read=self.dialect).subquery() - @visit_node.register(ops.Unnest) - def visit_Unnest(self, op, *, arg): - return sge.Explode(this=arg) - - @visit_node.register(ops.RegexSplit) - def visit_RegexSplit(self, op, *, arg, pattern): - return sge.RegexpSplit(this=arg, expression=pattern) - - @visit_node.register(ops.Levenshtein) - def visit_Levenshtein(self, op, *, left, right): - return sge.Levenshtein(this=left, expression=right) - @visit_node.register(ops.JoinTable) def visit_JoinTable(self, op, *, parent, index): return parent + @visit_node.register(ops.Cast) + def visit_Cast(self, op, *, arg, to): + return self.cast(arg, to) + + @visit_node.register(ops.Value) + def visit_Undefined(self, op, **_): + raise com.OperationNotDefinedError(type(op).__name__) + _SIMPLE_OPS = { ops.All: "bool_and", @@ -1117,7 +1132,6 @@ def visit_JoinTable(self, op, *, parent, index): ops.ArgMax: "max_by", ops.ArgMin: "min_by", ops.Power: "pow", - # Unary operations ops.IsNan: "isnan", ops.IsInf: "isinf", ops.Abs: "abs", @@ -1137,7 +1151,6 @@ def visit_JoinTable(self, op, *, parent, index): ops.Pi: "pi", ops.RandomScalar: "random", ops.Sign: "sign", - # Unary aggregates ops.ApproxCountDistinct: "approx_distinct", ops.Median: "median", ops.Mean: "avg", @@ -1152,14 +1165,12 @@ def visit_JoinTable(self, op, *, parent, index): ops.Any: "bool_or", ops.ArrayCollect: "array_agg", ops.GroupConcat: "group_concat", - # string operations ops.StringContains: "contains", ops.StringLength: "length", ops.Lowercase: "lower", ops.Uppercase: "upper", ops.StartsWith: "starts_with", ops.StrRight: "right", - # Other operations ops.IfElse: "if", ops.ArrayLength: "length", ops.NullIf: "nullif", @@ -1167,7 +1178,6 @@ def visit_JoinTable(self, op, *, parent, index): ops.Map: "map", ops.JSONGetItem: "json_extract", ops.ArrayFlatten: "flatten", - # common enough to be in the base, but not modeled in sqlglot ops.NTile: "ntile", ops.Degrees: "degrees", ops.Radians: "radians", @@ -1185,6 +1195,17 @@ def visit_JoinTable(self, op, *, parent, index): ops.StringReplace: "replace", ops.Reverse: "reverse", ops.StringSplit: "split", + ops.RegexSearch: "regexp_like", + ops.DateFromYMD: "datefromparts", + ops.Date: "date", + ops.RowNumber: "row_number", + ops.StringToTimestamp: "str_to_time", + ops.ArrayStringJoin: "array_to_string", + ops.Levenshtein: "levenshtein", + ops.Unnest: "explode", + ops.RegexSplit: "regexp_split", + ops.ArrayContains: 
"array_contains", + ops.RegexExtract: "regexp_extract", } _BINARY_INFIX_OPS = { diff --git a/ibis/backends/base/sqlglot/datatypes.py b/ibis/backends/base/sqlglot/datatypes.py index 46a7c996c996..9c4a7fe531df 100644 --- a/ibis/backends/base/sqlglot/datatypes.py +++ b/ibis/backends/base/sqlglot/datatypes.py @@ -274,7 +274,7 @@ def _from_ibis_Interval(cls, dtype: dt.Interval) -> sge.DataType: assert dtype.unit is not None, "interval unit cannot be None" return sge.DataType( this=typecode.INTERVAL, - expressions=[sge.IntervalSpan(this=sge.Var(this=dtype.unit.name))], + expressions=[sge.Var(this=dtype.unit.name)], ) @classmethod diff --git a/ibis/backends/clickhouse/compiler.py b/ibis/backends/clickhouse/compiler.py index 44760bf8ed8e..7fdb5428f1d8 100644 --- a/ibis/backends/clickhouse/compiler.py +++ b/ibis/backends/clickhouse/compiler.py @@ -41,7 +41,19 @@ def _aggregate(self, funcname: str, *args, where): has_filter = where is not None func = self.f[funcname + "If" * has_filter] args += (where,) * has_filter - return func(*args) + + return func(*args, dialect=self.dialect) + + @staticmethod + def _minimize_spec(start, end, spec): + if ( + start is None + and isinstance(getattr(end, "value", None), ops.Literal) + and end.value.value == 0 + and end.following + ): + return None + return spec @singledispatchmethod def visit_node(self, op, **kw): @@ -223,11 +235,8 @@ def visit_IntervalFromInteger(self, op, *, arg, unit): ) return super().visit_node(op, arg=arg, unit=unit) - @visit_node.register(ops.Literal) - def visit_Literal(self, op, *, value, dtype, **kw): - if value is None: - return super().visit_node(op, value=value, dtype=dtype, **kw) - elif dtype.is_inet(): + def visit_NonNullLiteral(self, op, *, value, dtype): + if dtype.is_inet(): v = str(value) return self.f.toIPv6(v) if ":" in v else self.f.toIPv4(v) elif dtype.is_string(): @@ -286,7 +295,7 @@ def visit_Literal(self, op, *, value, dtype, **kw): value_type = dtype.value_type values = [ self.visit_Literal( - ops.Literal(v, dtype=value_type), value=v, dtype=value_type, **kw + ops.Literal(v, dtype=value_type), value=v, dtype=value_type ) for v in value ] @@ -303,7 +312,6 @@ def visit_Literal(self, op, *, value, dtype, **kw): ops.Literal(v, dtype=value_type), value=v, dtype=value_type, - **kw, ) ) @@ -311,13 +319,13 @@ def visit_Literal(self, op, *, value, dtype, **kw): elif dtype.is_struct(): fields = [ self.visit_Literal( - ops.Literal(v, dtype=field_type), value=v, dtype=field_type, **kw + ops.Literal(v, dtype=field_type), value=v, dtype=field_type ) for field_type, v in zip(dtype.types, value.values()) ] return self.f.tuple(*fields) else: - return super().visit_node(op, value=value, dtype=dtype, **kw) + return None @visit_node.register(ops.TimestampFromUNIX) def visit_TimestampFromUNIX(self, op, *, arg, unit): diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index 3166c1c3b4d3..d82ad23e5597 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -73,11 +73,8 @@ def _to_timestamp(self, value, target_dtype, literal=False): def visit_node(self, op, **kw): return super().visit_node(op, **kw) - @visit_node.register(ops.Literal) - def visit_Literal(self, op, *, value, dtype, **kw): - if value is None: - return super().visit_node(op, value=value, dtype=dtype, **kw) - elif dtype.is_decimal(): + def visit_NonNullLiteral(self, op, *, value, dtype): + if dtype.is_decimal(): return self.cast( sg.exp.convert(str(value)), dt.Decimal(precision=dtype.precision or 38, 
scale=dtype.scale or 9), @@ -106,7 +103,7 @@ def visit_Literal(self, op, *, value, dtype, **kw): elif dtype.is_binary(): return sg.exp.HexString(this=value.hex()) else: - return super().visit_node(op, value=value, dtype=dtype, **kw) + return None @visit_node.register(ops.Cast) def visit_Cast(self, op, *, arg, to): diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 44e4fe78c7f4..b25257e3979d 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -28,7 +28,6 @@ from ibis.backends.base import CanCreateSchema from ibis.backends.base.sqlglot import SQLGlotBackend from ibis.backends.base.sqlglot.compiler import STAR, C, F -from ibis.backends.base.sqlglot.datatypes import DuckDBType from ibis.backends.duckdb.compiler import DuckDBCompiler from ibis.backends.duckdb.datatypes import DuckDBPandasData from ibis.expr.operations.udf import InputType @@ -311,7 +310,7 @@ def get_schema( return sch.Schema( { - name: DuckDBType.from_string(typ, nullable=nullable) + name: self.compiler.type_mapper.from_string(typ, nullable=nullable) for name, typ, nullable in zip(names, types, nullables) } ) @@ -1394,7 +1393,10 @@ def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]: for name, typ, null in zip( rows["column_name"], rows["column_type"], rows["null"] ): - yield name, DuckDBType.from_string(typ, nullable=null == "YES") + yield ( + name, + self.compiler.type_mapper.from_string(typ, nullable=null == "YES"), + ) def _register_in_memory_tables(self, expr: ir.Expr) -> None: for memtable in expr.op().find(ops.InMemoryTable): @@ -1434,10 +1436,10 @@ def _compile_udf(self, udf_node: ops.ScalarUDF) -> None: func = udf_node.__func__ name = func.__name__ input_types = [ - DuckDBType.to_string(param.annotation.pattern.dtype) + self.compiler.type_mapper.to_string(param.annotation.pattern.dtype) for param in udf_node.__signature__.parameters.values() ] - output_type = DuckDBType.to_string(udf_node.dtype) + output_type = self.compiler.type_mapper.to_string(udf_node.dtype) def register_udf(con): return con.create_function( diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index 6afe634d4280..86d60785895f 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -239,11 +239,8 @@ def visit_Cast(self, op, *, arg, to): return self.cast(arg, to) - @visit_node.register(ops.Literal) - def visit_Literal(self, op, *, value, dtype, **kw): - if value is None: - return super().visit_node(op, value=value, dtype=dtype, **kw) - elif dtype.is_interval(): + def visit_NonNullLiteral(self, op, *, value, dtype): + if dtype.is_interval(): if dtype.unit.short == "ns": raise com.UnsupportedOperationError( f"{self.dialect} doesn't support nanosecond interval resolutions" @@ -288,7 +285,7 @@ def visit_Literal(self, op, *, value, dtype, **kw): return self.f[funcname](*args) else: - return super().visit_node(op, value=value, dtype=dtype, **kw) + return None @visit_node.register(ops.Capitalize) def visit_Capitalize(self, op, *, arg): @@ -340,6 +337,10 @@ def visit_TimestampNow(self, op): """DuckDB current timestamp defaults to timestamp + tz.""" return self.cast(super().visit_TimestampNow(op), dt.timestamp) + @visit_node.register(ops.RegexExtract) + def visit_RegexExtract(self, op, *, arg, pattern, index): + return self.f.regexp_extract(arg, pattern, index, dialect=self.dialect) + _SIMPLE_OPS = { ops.ArrayPosition: "list_indexof", diff --git a/ibis/backends/pandas/kernels.py b/ibis/backends/pandas/kernels.py 
b/ibis/backends/pandas/kernels.py
index 1e28095c1ee2..7bfea9883fdd 100644 --- a/ibis/backends/pandas/kernels.py +++ b/ibis/backends/pandas/kernels.py @@ -308,7 +308,7 @@ def round_serieswise(arg, digits): ops.Pi: lambda: np.pi, ops.TimestampNow: lambda: pd.Timestamp("now", tz="UTC").tz_localize(None), ops.StringConcat: lambda xs: reduce(operator.add, xs), - ops.StringJoin: lambda sep, xs: reduce(lambda x, y: x + sep + y, xs), + ops.StringJoin: lambda xs, sep: reduce(lambda x, y: x + sep + y, xs), ops.Log: lambda x, base: np.log(x) if base is None else np.log(x) / np.log(base), } diff --git a/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql b/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql index 6eaa105c4a49..ffa8c03c59cf 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql @@ -27,8 +27,8 @@ FROM ( FROM ( SELECT t1.field_of_study, - t1.__pivoted__['years'] AS years, - t1.__pivoted__['degrees'] AS degrees + t1.__pivoted__.years AS years, + t1.__pivoted__.degrees AS degrees FROM ( SELECT t0.field_of_study, @@ -72,8 +72,8 @@ FROM ( FROM ( SELECT t1.field_of_study, - t1.__pivoted__['years'] AS years, - t1.__pivoted__['degrees'] AS degrees + t1.__pivoted__.years AS years, + t1.__pivoted__.degrees AS degrees FROM ( SELECT t0.field_of_study, diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_filter_predicates/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_filter_predicates/out.sql index fb4bf6a1c3ff..918ff235ee2a 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_filter_predicates/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_filter_predicates/out.sql @@ -4,4 +4,4 @@ FROM t AS t0 WHERE LOWER(t0.color) LIKE '%de%' AND CONTAINS(LOWER(t0.color), 'de') - AND REGEXP_MATCHES(LOWER(t0.color), '.*ge.*', 's') \ No newline at end of file + AND REGEXP_MATCHES(LOWER(t0.color), '.*ge.*') \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_with_self_join/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_with_self_join/out.sql index 26824f377a3e..b8cfc5063ba5 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_with_self_join/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_with_self_join/out.sql @@ -30,5 +30,5 @@ FROM ( t3.month AS month_right FROM functional_alltypes AS t1 INNER JOIN functional_alltypes AS t3 - ON t1.tinyint_col < EXTRACT('minute' FROM t3.timestamp_col) + ON t1.tinyint_col < EXTRACT(minute FROM t3.timestamp_col) ) AS t5 \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_tpch_self_join_failure/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_tpch_self_join_failure/out.sql index aa75e2be0ae1..feacfd23da7e 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_tpch_self_join_failure/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_tpch_self_join_failure/out.sql @@ -5,7 +5,7 @@ SELECT FROM ( SELECT t11.region, - EXTRACT('year' FROM t11.odate) AS year, + EXTRACT(year FROM t11.odate) AS year, CAST(SUM(t11.amount) AS DOUBLE) AS total FROM ( SELECT @@ -28,7 +28,7 @@ FROM ( INNER JOIN ( SELECT t11.region, - EXTRACT('year' FROM t11.odate) AS year, + EXTRACT(year FROM t11.odate) AS year, CAST(SUM(t11.amount) AS DOUBLE) AS total FROM ( SELECT diff --git a/ibis/backends/tests/test_aggregation.py 
b/ibis/backends/tests/test_aggregation.py index be97ad419d92..d403f613146e 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -1350,33 +1350,6 @@ def test_date_quantile(alltypes, func): assert result == date(2009, 12, 31) -@pytest.mark.parametrize( - ("result_fn", "expected_fn"), - [ - param( - lambda t, where, sep: ( - t.group_by("bigint_col") - .aggregate(tmp=lambda t: t.string_col.group_concat(sep, where=where)) - .order_by("bigint_col") - ), - lambda t, where, sep: ( - ( - t - if isinstance(where, slice) - else t.assign(string_col=t.string_col.where(where)) - ) - .groupby("bigint_col") - .string_col.agg( - lambda s: (np.nan if pd.isna(s).all() else sep.join(s.values)) - ) - .rename("tmp") - .sort_index() - .reset_index() - ), - id="group_concat", - ) - ], -) @pytest.mark.parametrize( ("ibis_sep", "pandas_sep"), [ @@ -1422,8 +1395,7 @@ def test_date_quantile(alltypes, func): ], ) @pytest.mark.notimpl( - ["datafusion", "polars", "mssql"], - raises=com.OperationNotDefinedError, + ["datafusion", "polars", "mssql"], raises=com.OperationNotDefinedError ) @pytest.mark.notimpl( ["druid"], @@ -1442,19 +1414,30 @@ def test_date_quantile(alltypes, func): reason='SQL parse failed. Encountered "group_concat ("', ) def test_group_concat( - backend, - alltypes, - df, - result_fn, - expected_fn, - ibis_cond, - pandas_cond, - ibis_sep, - pandas_sep, + backend, alltypes, df, ibis_cond, pandas_cond, ibis_sep, pandas_sep ): - expr = result_fn(alltypes, ibis_cond(alltypes), ibis_sep) + expr = ( + alltypes.group_by("bigint_col") + .aggregate( + tmp=lambda t: t.string_col.group_concat(ibis_sep, where=ibis_cond(t)) + ) + .order_by("bigint_col") + ) result = expr.execute() - expected = expected_fn(df, pandas_cond(df), pandas_sep) + expected = ( + ( + df + if isinstance(pandas_cond(df), slice) + else df.assign(string_col=df.string_col.where(pandas_cond(df))) + ) + .groupby("bigint_col") + .string_col.agg( + lambda s: (np.nan if pd.isna(s).all() else pandas_sep.join(s.values)) + ) + .rename("tmp") + .sort_index() + .reset_index() + ) backend.assert_frame_equal(result.fillna(pd.NA), expected.fillna(pd.NA)) diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py index b7aa81c43dd1..7fc9696aa9ab 100644 --- a/ibis/backends/tests/test_param.py +++ b/ibis/backends/tests/test_param.py @@ -247,9 +247,7 @@ def test_scalar_param_date(backend, alltypes, value): "datafusion", "clickhouse", "polars", - "duckdb", "sqlite", - "snowflake", "impala", "oracle", "pyspark", diff --git a/ibis/backends/tests/test_uuid.py b/ibis/backends/tests/test_uuid.py index ea9064dd0d74..9a4dce517afa 100644 --- a/ibis/backends/tests/test_uuid.py +++ b/ibis/backends/tests/test_uuid.py @@ -25,6 +25,7 @@ "snowflake": "VARCHAR", "trino": "varchar(32)" if SQLALCHEMY2 else "uuid", "postgres": "uuid", + "clickhouse": "Nullable(UUID)", } UUID_EXPECTED_VALUES = { @@ -41,6 +42,7 @@ "oracle": TEST_UUID, "flink": TEST_UUID, "exasol": TEST_UUID, + "clickhouse": TEST_UUID, } pytestmark = pytest.mark.notimpl( @@ -64,7 +66,7 @@ raises=sqlalchemy.exc.NotSupportedError, ) @pytest.mark.notimpl( - ["impala", "datafusion", "polars", "clickhouse"], raises=NotImplementedError + ["impala", "datafusion", "polars"], raises=NotImplementedError ) @pytest.mark.notimpl( ["risingwave"], diff --git a/ibis/backends/tests/test_window.py b/ibis/backends/tests/test_window.py index f7cd6cdfcd80..ce9ffbb613c6 100644 --- a/ibis/backends/tests/test_window.py +++ b/ibis/backends/tests/test_window.py @@ -177,11 
+177,6 @@ def calc_zscore(s): ["dask", "pandas", "polars"], raises=com.OperationNotDefinedError, ), - pytest.mark.notyet( - ["clickhouse"], - raises=ClickHouseDatabaseError, - reason="ClickHouse requires a specific window frame: unbounded preceding and unbounded following ONLY", - ), pytest.mark.broken( ["impala"], raises=AssertionError, diff --git a/ibis/backends/tests/tpch/conftest.py b/ibis/backends/tests/tpch/conftest.py index cfb85452841b..b2a88f9370a1 100644 --- a/ibis/backends/tests/tpch/conftest.py +++ b/ibis/backends/tests/tpch/conftest.py @@ -85,7 +85,14 @@ def wrapper(*args, backend, snapshot, **kwargs): assert not expected.empty assert len(expected) == len(result) - backend.assert_frame_equal(result, expected, check_dtype=False) + assert result.columns.tolist() == expected.columns.tolist() + for column in result.columns: + left = result.loc[:, column] + right = expected.loc[:, column] + assert ( + pytest.approx(left.values.tolist(), nan_ok=True) + == right.values.tolist() + ) # only write sql if the execution passes snapshot.assert_match(ibis_sql, sql_path_name) diff --git a/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/duckdb/h07.sql b/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/duckdb/h07.sql index 35411472de9a..ea7f9f6eb7fe 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/duckdb/h07.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/duckdb/h07.sql @@ -25,7 +25,7 @@ FROM ( t6.l_shipdate, t6.l_extendedprice, t6.l_discount, - EXTRACT('year' FROM t6.l_shipdate) AS l_year, + EXTRACT(year FROM t6.l_shipdate) AS l_year, t6.l_extendedprice * ( CAST(1 AS TINYINT) - t6.l_discount ) AS volume diff --git a/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/snowflake/h07.sql b/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/snowflake/h07.sql index 48269d09259c..ce954992953d 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/snowflake/h07.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/snowflake/h07.sql @@ -25,7 +25,7 @@ FROM ( "t11"."l_shipdate", "t11"."l_extendedprice", "t11"."l_discount", - DATE_PART('year', "t11"."l_shipdate") AS "l_year", + DATE_PART(year, "t11"."l_shipdate") AS "l_year", "t11"."l_extendedprice" * ( 1 - "t11"."l_discount" ) AS "volume" diff --git a/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/duckdb/h08.sql b/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/duckdb/h08.sql index 97b1be133851..99ba095e07ae 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/duckdb/h08.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/duckdb/h08.sql @@ -16,7 +16,7 @@ FROM ( CASE WHEN t23.nation = 'BRAZIL' THEN t23.volume ELSE CAST(0 AS TINYINT) END AS nation_volume FROM ( SELECT - EXTRACT('year' FROM t10.o_orderdate) AS o_year, + EXTRACT(year FROM t10.o_orderdate) AS o_year, t8.l_extendedprice * ( CAST(1 AS TINYINT) - t8.l_discount ) AS volume, diff --git a/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/snowflake/h08.sql b/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/snowflake/h08.sql index 8d25f3b2df17..e6b90d1f7a6e 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/snowflake/h08.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/snowflake/h08.sql @@ -16,7 +16,7 @@ FROM ( CASE WHEN "t30"."nation" = 'BRAZIL' THEN "t30"."volume" ELSE 0 END AS "nation_volume" FROM ( SELECT - DATE_PART('year', "t17"."o_orderdate") AS "o_year", + DATE_PART(year, "t17"."o_orderdate") AS 
"o_year", "t15"."l_extendedprice" * ( 1 - "t15"."l_discount" ) AS "volume", diff --git a/ibis/backends/tests/tpch/snapshots/test_h09/test_tpc_h09/duckdb/h09.sql b/ibis/backends/tests/tpch/snapshots/test_h09/test_tpc_h09/duckdb/h09.sql index 21489f03313d..b305db73e0ae 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h09/test_tpc_h09/duckdb/h09.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h09/test_tpc_h09/duckdb/h09.sql @@ -22,7 +22,7 @@ FROM ( ) - ( t8.ps_supplycost * t6.l_quantity ) AS amount, - EXTRACT('year' FROM t10.o_orderdate) AS o_year, + EXTRACT(year FROM t10.o_orderdate) AS o_year, t11.n_name AS nation, t9.p_name FROM lineitem AS t6 diff --git a/ibis/backends/tests/tpch/snapshots/test_h09/test_tpc_h09/snowflake/h09.sql b/ibis/backends/tests/tpch/snapshots/test_h09/test_tpc_h09/snowflake/h09.sql index a57563a10289..b828b08644bc 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h09/test_tpc_h09/snowflake/h09.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h09/test_tpc_h09/snowflake/h09.sql @@ -22,7 +22,7 @@ FROM ( ) - ( "t14"."ps_supplycost" * "t12"."l_quantity" ) AS "amount", - DATE_PART('year', "t16"."o_orderdate") AS "o_year", + DATE_PART(year, "t16"."o_orderdate") AS "o_year", "t17"."n_name" AS "nation", "t15"."p_name" FROM ( diff --git a/ibis/expr/operations/strings.py b/ibis/expr/operations/strings.py index 9b40261b9d2a..bbaef67021c4 100644 --- a/ibis/expr/operations/strings.py +++ b/ibis/expr/operations/strings.py @@ -133,8 +133,8 @@ class FindInSet(Value): @public class StringJoin(Value): - sep: Value[dt.String] arg: VarTuple[Value[dt.String]] + sep: Value[dt.String] dtype = dt.string @@ -145,8 +145,8 @@ def shape(self): @public class ArrayStringJoin(Value): - sep: Value[dt.String] arg: Value[dt.Array[dt.String]] + sep: Value[dt.String] dtype = dt.string shape = rlz.shape_like("args") diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index add931bfeb2d..096aaad97417 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -356,7 +356,7 @@ def join(self, sep: str | ir.StringValue) -> ir.StringValue: -------- [`StringValue.join`](./expression-strings.qmd#ibis.expr.types.strings.StringValue.join) """ - return ops.ArrayStringJoin(sep, self).to_expr() + return ops.ArrayStringJoin(self, sep=sep).to_expr() def map(self, func: Deferred | Callable[[ir.Value], ir.Value]) -> ir.ArrayValue: """Apply a `func` or `Deferred` to each element of this array expression. diff --git a/ibis/expr/types/strings.py b/ibis/expr/types/strings.py index ce44800760a0..629123cab09e 100644 --- a/ibis/expr/types/strings.py +++ b/ibis/expr/types/strings.py @@ -864,7 +864,7 @@ def join(self, strings: Sequence[str | StringValue] | ir.ArrayValue) -> StringVa cls = ops.ArrayStringJoin else: cls = ops.StringJoin - return cls(self, strings).to_expr() + return cls(strings, sep=self).to_expr() def startswith(self, start: str | StringValue) -> ir.BooleanValue: """Determine whether `self` starts with `end`. 
diff --git a/ibis/formats/pandas.py b/ibis/formats/pandas.py index d1983b1b58a8..e202b48f0621 100644 --- a/ibis/formats/pandas.py +++ b/ibis/formats/pandas.py @@ -112,6 +112,8 @@ def infer_table(cls, df, schema=None): return sch.Schema.from_tuples(pairs) + concat = staticmethod(pd.concat) + @classmethod def convert_table(cls, df, schema): if len(schema) != len(df.columns): @@ -122,7 +124,7 @@ def convert_table(cls, df, schema): columns = [] for (_, series), dtype in zip(df.items(), schema.types): columns.append(cls.convert_column(series, dtype)) - df = pd.concat(columns, axis=1) + df = cls.concat(columns, axis=1) # return data with the schema's columns which may be different than the # input columns @@ -395,6 +397,12 @@ def convert(value): class DaskData(PandasData): + @staticmethod + def concat(*args, **kwargs): + import dask.dataframe as dd + + return dd.concat(*args, **kwargs) + @classmethod def infer_column(cls, s): return PyArrowData.infer_column(s.compute())
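Editor's note on the last hunk: the `concat = staticmethod(pd.concat)` hook is the whole mechanism — `convert_table` now routes through `cls.concat`, so `DaskData` swaps only the concatenation function and inherits the rest of the conversion logic. A standalone sketch of the pattern; the class names here are illustrative stand-ins for `PandasData`/`DaskData`:

```python
import pandas as pd


class PandasLike:
    concat = staticmethod(pd.concat)

    @classmethod
    def combine_columns(cls, columns):
        # subclasses change behavior by overriding `concat` alone
        return cls.concat(columns, axis=1)


class DaskLike(PandasLike):
    @staticmethod
    def concat(*args, **kwargs):
        import dask.dataframe as dd  # deferred import, as in the diff

        return dd.concat(*args, **kwargs)


print(PandasLike.combine_columns(
    [pd.Series([1, 2], name="a"), pd.Series([3, 4], name="b")]
))
```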