diff --git a/ibis/config.py b/ibis/config.py index c2ab73123b53..c8c455c6246b 100644 --- a/ibis/config.py +++ b/ibis/config.py @@ -110,6 +110,8 @@ class Repr(Config): The maximum number of expression nodes to print when repring. table_columns : int The number of columns to show in leaf table expressions. + table_rows : int + The number of rows to show for in memory tables. query_text_length : int The maximum number of characters to show in the `query` field repr of SQLQueryResult operations. @@ -121,6 +123,7 @@ class Repr(Config): depth: Optional[PosInt] = None table_columns: Optional[PosInt] = None + table_rows: PosInt = 10 query_text_length: PosInt = 80 show_types: bool = False interactive: Interactive = Interactive() diff --git a/ibis/expr/format.py b/ibis/expr/format.py index 4f792e39e7bd..ea4bec11ec87 100644 --- a/ibis/expr/format.py +++ b/ibis/expr/format.py @@ -1,684 +1,341 @@ from __future__ import annotations -import collections import functools +import itertools import textwrap -import types # noqa: TCH003 -from typing import Any, Callable, Deque, Iterable, Mapping, Tuple +import types +from typing import Mapping, Sequence import rich.pretty +from public import public import ibis import ibis.expr.datatypes as dt import ibis.expr.operations as ops -import ibis.expr.schema as sch import ibis.expr.types as ir from ibis import util -from ibis.common import graph -Aliases = Mapping[ops.TableNode, int] -Deps = Deque[Tuple[int, ops.TableNode]] +_infix_ops = { + # comparison operations + ops.Equals: "==", + ops.IdenticalTo: "===", + ops.NotEquals: "!=", + ops.Less: "<", + ops.LessEqual: "<=", + ops.Greater: ">", + ops.GreaterEqual: ">=", + # arithmetic operations + ops.Add: "+", + ops.Subtract: "-", + ops.Multiply: "*", + ops.Divide: "/", + ops.FloorDivide: "//", + ops.Modulus: "%", + ops.Power: "**", + # temporal operations + ops.DateAdd: "+", + ops.DateSub: "-", + ops.DateDiff: "-", + ops.TimeAdd: "+", + ops.TimeSub: "-", + ops.TimeDiff: "-", + ops.TimestampAdd: "+", + ops.TimestampSub: "-", + ops.TimestampDiff: "-", + ops.IntervalAdd: "+", + ops.IntervalSubtract: "-", + ops.IntervalMultiply: "*", + ops.IntervalFloorDivide: "//", + # boolean operators + ops.And: "&", + ops.Or: "|", + ops.Xor: "^", +} -class Alias: - __slots__ = ("value",) +def type_info(datatype) -> str: + """Format `datatype` for display next to a column.""" + return f" # {datatype}" * ibis.options.repr.show_types - def __init__(self, value: int) -> None: - self.value = value - def __str__(self) -> str: - return f"r{self.value}" +def truncate(pieces: Sequence[str], limit: int) -> list[str]: + if limit < 1: + raise ValueError("limit must be >= 1") + elif limit == 1: + return pieces[-1:] + elif limit >= len(pieces): + return pieces + + first_n = limit // 2 + last_m = limit - first_n + first, last = pieces[:first_n], pieces[-last_m:] + + maxlen = max(*map(len, first), *map(len, last)) + ellipsis = util.VERTICAL_ELLIPSIS.center(maxlen) + + return [*first, ellipsis, *last] + + +def render(obj, indent_level=0, limit_items=None, key_separator=":"): + if isinstance(obj, Mapping): + rendered = {f"{k}{key_separator}": render(v) for k, v in obj.items() if v} + if not rendered: + return "" + maxlen = max(map(len, rendered.keys())) + lines = [f"{k:<{maxlen}} {v}" for k, v in rendered.items()] + if limit_items is not None: + lines = truncate(lines, limit_items) + result = "\n".join(lines) + elif util.is_iterable(obj): + lines = tuple(render(item) for item in obj) + if limit_items is not None: + lines = truncate(lines, limit_items) + result = "\n".join(lines) + else: + result = str(obj) + return util.indent(result, spaces=indent_level * 2) -def fmt(expr: ir.Expr) -> str: - """Format `expr`. - Main entry point for the `Expr.__repr__` implementation. +def render_fields(fields, indent_level=0, limit_items=None): + rendered = {k: render(v, 1) for k, v in fields.items() if v} + lines = [f"{k}:\n{v}" for k, v in rendered.items()] + if limit_items is not None: + lines = truncate(lines, limit_items) + result = "\n".join(lines) + return util.indent(result, spaces=indent_level * 2) - Returns - ------- - str - Formatted expression - """ - *deps, root = graph.toposort(expr.op()).keys() - deps = collections.deque( - (Alias(alias), dep) - for alias, dep in enumerate( - dep for dep in deps if isinstance(dep, ops.TableNode) - ) - ) - aliases = {dep: alias for alias, dep in deps} - pieces = [] - - while deps: - alias, node = deps.popleft() - formatted = fmt_table_op(node, aliases=aliases, deps=deps) - pieces.append(f"{alias} := {formatted}") - - name = expr.get_name() if expr.has_name() else None - pieces.append(fmt_root(root, name=name, aliases=aliases, deps=deps)) - depth = ibis.options.repr.depth or 0 - if depth and depth < len(pieces): - return fmt_truncated(pieces, depth=depth) - return "\n\n".join(pieces) - - -def fmt_truncated( - pieces: Iterable[str], - *, - depth: int, - sep: str = "\n\n", - ellipsis: str = util.VERTICAL_ELLIPSIS, -) -> str: - if depth == 1: - return pieces[-1] - - first_n = depth // 2 - last_m = depth - first_n - return sep.join([*pieces[:first_n], ellipsis, *pieces[-last_m:]]) - - -def selection_maxlen(nodes: Iterable[ops.Node]) -> int: - """Compute the length of the longest name of input expressions.""" - return max( - (len(node.name) for node in nodes if isinstance(node, ops.Named)), default=0 - ) +def render_schema(schema, indent_level=0, limit_items=None): + if limit_items is None: + limit_items = ibis.options.repr.table_columns + return render(schema, indent_level, limit_items, key_separator="") -@functools.singledispatch -def fmt_root(op: ops.Node, *, aliases: Aliases, **_: Any) -> str: - """Fallback formatting implementation.""" - raw_parts = fmt_fields( - op, - dict.fromkeys(op.argnames, fmt_value), - aliases=aliases, - ) - return f"{op.__class__.__name__}\n{raw_parts}" +def inline(obj): + if isinstance(obj, Mapping): + fields = ", ".join(f"{k!r}: {inline(v)}" for k, v in obj.items()) + return f"{{{fields}}}" + elif util.is_iterable(obj): + elems = ", ".join(inline(item) for item in obj) + return f"[{elems}]" + elif isinstance(obj, types.FunctionType): + return obj.__name__ + elif isinstance(obj, dt.DataType): + return str(obj) + else: + return repr(obj) -@fmt_root.register -def _fmt_root_table_node(op: ops.TableNode, **kwargs: Any) -> str: - return fmt_table_op(op, **kwargs) +def inline_args(fields, prefer_positional=False): + fields = {k: inline(v) for k, v in fields.items() if v} + if fields and prefer_positional: + first, *rest = fields.keys() + if not rest: + return fields[first] + elif first in {"arg", "expr"}: + first = fields[first] + rest = (f"{k}={fields[k]}" for k in rest) + return ", ".join((first, *rest)) -@fmt_root.register -def _fmt_root_value_op(op: ops.Value, *, name: str, aliases: Aliases, **_: Any) -> str: - value = fmt_value(op, aliases=aliases) - prefix = f"{name}: " if name is not None else "" - return f"{prefix}{value}{type_info(op.to_expr().type())}" + return ", ".join(f"{k}={v}" for k, v in fields.items()) -@fmt_root.register -def _fmt_root_literal_op( - op: ops.Literal, *, name: str, aliases: Aliases, **_: Any -) -> str: - value = fmt_value(op, aliases=aliases) - return f"{value}{type_info(op.to_expr().type())}" +class Rendered(str): + def __repr__(self): + return self -@fmt_root.register(ops.SortKey) -def _fmt_root_sort_key(op: ops.SortKey, *, aliases: Aliases, **_: Any) -> str: - return fmt_value(op, aliases=aliases) +@public +def pretty(node): + if isinstance(node, ir.Expr): + node = node.op() + elif not isinstance(node, ops.Node): + raise TypeError(f"Expected an expression , got {type(node)}") + refcnt = itertools.count() + tables = {} -@functools.singledispatch -def fmt_table_op(op: ops.TableNode, **_: Any) -> str: - raise AssertionError(f"`fmt_table_op` not implemented for operation: {type(op)}") - - -@fmt_table_op.register -def _fmt_table_op_physical_table(op: ops.PhysicalTable, **_: Any) -> str: - top = f"{op.__class__.__name__}: {op.name}" - formatted_schema = fmt_schema(op.schema) - return f"{top}\n{formatted_schema}" - - -def fmt_schema(schema: sch.Schema) -> str: - """Format `schema`. - - Parameters - ---------- - schema - Ibis schema to format - - Returns - ------- - str - Formatted schema - """ - names = schema.names - maxlen = max(map(len, names)) - cols = [f"{name:<{maxlen}} {typ}" for name, typ in schema.items()] - depth = ibis.options.repr.table_columns - if depth is not None and depth < len(cols): - first_column_name = names[0] - raw = fmt_truncated( - cols, - depth=depth, - sep="\n", - ellipsis=util.VERTICAL_ELLIPSIS.center(len(first_column_name)), - ) - else: - raw = "\n".join(cols) + def mapper(op, _, **kwargs): + result = fmt(op, **kwargs) + if isinstance(op, ops.Relation): + tables[op] = result + result = f"r{next(refcnt)}" + return Rendered(result) - return util.indent(raw, spaces=2) + results = node.map(mapper) + out = [] + for table, rendered in tables.items(): + if table is not node: + ref = results[table] + out.append(f"{ref} := {rendered}") -@fmt_table_op.register -def _fmt_table_op_sql_query_result(op: ops.SQLQueryResult, **_: Any) -> str: - short_query = textwrap.shorten( - op.query, - ibis.options.repr.query_text_length, - placeholder=f" {util.HORIZONTAL_ELLIPSIS}", - ) - query = f"query: {short_query!r}" - top = op.__class__.__name__ - formatted_schema = fmt_schema(op.schema) - schema_field = util.indent(f"schema:\n{formatted_schema}", spaces=2) - return f"{top}\n{util.indent(query, spaces=2)}\n{schema_field}" - - -@fmt_table_op.register -def _fmt_table_op_view(op: ops.View, *, aliases: Aliases, **_: Any) -> str: - top = op.__class__.__name__ - formatted_schema = fmt_schema(op.schema) - schema_field = util.indent(f"schema:\n{formatted_schema}", spaces=2) - return f"{top}[{aliases[op.child]}]: {op.name}\n{schema_field}" - - -@fmt_table_op.register -def _fmt_table_op_sql_view( - op: ops.SQLStringView, - *, - aliases: Aliases, - **_: Any, -) -> str: - short_query = textwrap.shorten( - op.query, - ibis.options.repr.query_text_length, - placeholder=f" {util.HORIZONTAL_ELLIPSIS}", - ) - query = f"query: {short_query!r}" - top = op.__class__.__name__ - formatted_schema = fmt_schema(op.schema) - schema_field = util.indent(f"schema:\n{formatted_schema}", spaces=2) - components = [ - f"{top}[{aliases[op.child]}]: {op.name}", - util.indent(query, spaces=2), - schema_field, - ] - return "\n".join(components) + res = results[node] + if isinstance(node, ops.Literal): + out.append(res) + elif isinstance(node, ops.Value): + out.append(f"{node.name}: {res}{type_info(node.dtype)}") + elif isinstance(node, ops.Relation): + out.append(tables[node]) + + return "\n\n".join(out) @functools.singledispatch -def fmt_join(op: ops.Join, *, aliases: Aliases) -> tuple[str, str]: - raise AssertionError(f"join type {type(op)} not implemented") - - -@fmt_join.register(ops.Join) -def _fmt_join(op: ops.Join, *, aliases: Aliases) -> tuple[str, str]: - # format the operator and its relation inputs - left = aliases[op.left] - right = aliases[op.right] - top = f"{op.__class__.__name__}[{left}, {right}]" - - # format the join predicates - # if only one, put it directly after the join on the same line - # if more than one put each on a separate line - preds = op.predicates - formatted_preds = [fmt_value(pred, aliases=aliases) for pred in preds] - has_one_pred = len(preds) == 1 - sep = " " if has_one_pred else "\n" - joined_predicates = util.indent( - "\n".join(formatted_preds), - spaces=2 * (not has_one_pred), - ) - trailing_sep = "\n" + "\n" * (not has_one_pred) - return f"{top}{sep}{joined_predicates}", trailing_sep - - -@fmt_join.register(ops.AsOfJoin) -def _fmt_asof_join(op: ops.AsOfJoin, *, aliases: Aliases) -> tuple[str, str]: - left = aliases[op.left] - right = aliases[op.right] - top = f"{op.__class__.__name__}[{left}, {right}]" - raw_parts = fmt_fields( - op, - dict(predicates=fmt_value, by=fmt_value, tolerance=fmt_value), - aliases=aliases, - ) - return f"{top}\n{raw_parts}", "\n\n" - - -@fmt_table_op.register -def _fmt_table_op_join( - op: ops.Join, - *, - aliases: Aliases, - deps: Deps, - **_: Any, -) -> str: - # first, format the current join operation - result, join_sep = fmt_join(op, aliases=aliases) - formatted_joins = [result, join_sep] - - # process until the first non-Join dependency is popped in other words - # process all runs of joins - alias, current = None, None - if deps: - alias, current = deps.popleft() - - while isinstance(current, ops.Join): - # copy the alias so that mutations to the value aren't shared - # format the `current` join - formatted_join, join_sep = fmt_join(current, aliases=aliases) - formatted_joins.append(f"{alias} := {formatted_join}") - formatted_joins.append(join_sep) - - if not deps: - break - - alias, current = deps.popleft() - - if current is not None and not isinstance(current, ops.Join): - # the last node popped from `deps` isn't a join which means we - # still need to process it, so we put it at the front of the queue - deps.appendleft((alias, current)) - - # we don't want the last trailing separator so remove it from the end - formatted_joins.pop() - return "".join(formatted_joins) - - -@fmt_table_op.register -def _(op: ops.CrossJoin, *, aliases: Aliases, **_: Any) -> str: - left = aliases[op.left] - right = aliases[op.right] - return f"{op.__class__.__name__}[{left}, {right}]" - - -def _fmt_set_op( - op: ops.SetOp, - *, - aliases: Aliases, - distinct: bool | None = None, -) -> str: - args = [str(aliases[op.left]), str(aliases[op.right])] - if distinct is not None: - args.append(f"distinct={distinct}") - return f"{op.__class__.__name__}[{', '.join(args)}]" +def fmt(op, **kwargs): + raise NotImplementedError(f"no pretty printer for {type(op)}") -@fmt_table_op.register -def _fmt_table_op_set_op(op: ops.SetOp, *, aliases: Aliases, **_: Any) -> str: - return _fmt_set_op(op, aliases=aliases) - - -@fmt_table_op.register -def _fmt_table_op_union(op: ops.Union, *, aliases: Aliases, **_: Any) -> str: - return _fmt_set_op(op, aliases=aliases, distinct=op.distinct) - - -@fmt_table_op.register(ops.SelfReference) -@fmt_table_op.register(ops.Distinct) -def _fmt_table_op_self_reference_distinct( - op: ops.Distinct | ops.SelfReference, - *, - aliases: Aliases, - **_: Any, -) -> str: - return f"{op.__class__.__name__}[{aliases[op.table]}]" - - -@fmt_table_op.register -def _fmt_table_op_fillna(op: ops.FillNa, *, aliases: Aliases, **_: Any) -> str: - top = f"{op.__class__.__name__}[{aliases[op.table]}]" - raw_parts = fmt_fields(op, dict(replacements=fmt_value), aliases=aliases) - return f"{top}\n{raw_parts}" - - -@fmt_table_op.register -def _fmt_table_op_dropna(op: ops.DropNa, *, aliases: Aliases, **_: Any) -> str: - top = f"{op.__class__.__name__}[{aliases[op.table]}]" - how = f"how: {op.how!r}" - raw_parts = fmt_fields(op, dict(subset=fmt_value), aliases=aliases) - return f"{top}\n{util.indent(how, spaces=2)}\n{raw_parts}" - - -def fmt_fields( - op: ops.TableNode, - fields: Mapping[str, Callable[[Any, Aliases], str]], - *, - aliases: Aliases, -) -> str: - parts = [] - - for field, formatter in fields.items(): - if exprs := [ - expr for expr in util.promote_list(getattr(op, field)) if expr is not None - ]: - field_fmt = [formatter(expr, aliases=aliases) for expr in exprs] - - parts.append(f"{field}:") - parts.append(util.indent("\n".join(field_fmt), spaces=2)) - - return util.indent("\n".join(parts), spaces=2) - - -@fmt_table_op.register -def _fmt_table_op_selection(op: ops.Selection, *, aliases: Aliases, **_: Any) -> str: - top = f"{op.__class__.__name__}[{aliases[op.table]}]" - raw_parts = fmt_fields( - op, - dict( - selections=functools.partial( - fmt_selection_column, - maxlen=selection_maxlen(op.selections), - ), - predicates=fmt_value, - sort_keys=fmt_value, - ), - aliases=aliases, - ) - return f"{top}\n{raw_parts}" - - -@fmt_table_op.register -def _fmt_table_op_aggregation( - op: ops.Aggregation, *, aliases: Aliases, **_: Any -) -> str: - top = f"{op.__class__.__name__}[{aliases[op.table]}]" - raw_parts = fmt_fields( - op, - dict( - metrics=functools.partial( - fmt_selection_column, - maxlen=selection_maxlen(op.metrics), - ), - by=functools.partial( - fmt_selection_column, - maxlen=selection_maxlen(op.by), - ), - having=fmt_value, - predicates=fmt_value, - sort_keys=fmt_value, - ), - aliases=aliases, - ) - return f"{top}\n{raw_parts}" - - -@fmt_table_op.register -def _fmt_table_op_limit(op: ops.Limit, *, aliases: Aliases, **_: Any) -> str: - params = [str(aliases[op.table]), f"n={op.n:d}"] - if offset := op.offset: - params.append(f"offset={offset:d}") - return f"{op.__class__.__name__}[{', '.join(params)}]" - - -@fmt_table_op.register -def _fmt_table_op_in_memory_table(op: ops.InMemoryTable, **_: Any) -> str: - # arbitrary limit, but some value is needed to avoid a huge repr - max_length = 10 - pretty_data = rich.pretty.pretty_repr(op.data, max_length=max_length) - return "\n".join( - [ - op.__class__.__name__, - util.indent("data:", spaces=2), - util.indent(pretty_data, spaces=4), - ] - ) +@fmt.register(ops.Relation) +@fmt.register(ops.DummyTable) +def _relation(op, **kwargs): + schema = render_schema(op.schema, indent_level=1) + return f"{op.__class__.__name__}\n{schema}" -@fmt_table_op.register -def _fmt_table_op_dummy_table(op: ops.DummyTable, **_: Any) -> str: - formatted_schema = fmt_schema(op.schema) - schema_field = util.indent(f"schema:\n{formatted_schema}", spaces=2) - return f"{op.__class__.__name__}\n{schema_field}" +@fmt.register(ops.PhysicalTable) +def _physical_table(op, name, **kwargs): + schema = render_schema(op.schema, indent_level=1) + return f"{op.__class__.__name__}: {name}\n{schema}" -@functools.singledispatch -def fmt_selection_column(value_expr: object, **_: Any) -> str: - raise AssertionError( - f"expression type not implemented for fmt_selection_column: {type(value_expr)}" +@fmt.register(ops.InMemoryTable) +def _in_memory_table(op, data, **kwargs): + name = f"{op.__class__.__name__}\n" + data = rich.pretty.pretty_repr(op.data, max_length=ibis.options.repr.table_columns) + return name + render_fields({"data": data}, 1) + + +@fmt.register(ops.SQLQueryResult) +@fmt.register(ops.SQLStringView) +def _sql_query_result(op, query, **kwargs): + clsname = op.__class__.__name__ + if isinstance(op, ops.SQLStringView): + child, name = kwargs["child"], kwargs["name"] + top = f"{clsname}[{child}]: {name}\n" + else: + top = f"{clsname}\n" + + query = textwrap.shorten( + query, + width=ibis.options.repr.query_text_length, + placeholder=f" {util.HORIZONTAL_ELLIPSIS}", ) + schema = render_schema(op.schema) + return top + render_fields({"query": query, "schema": schema}, 1) -def type_info(datatype: dt.DataType) -> str: - """Format `datatype` for display next to a column.""" - return f" # {datatype}" * ibis.options.repr.show_types +@fmt.register(ops.FillNa) +@fmt.register(ops.DropNa) +def _fill_na(op, table, **kwargs): + name = f"{op.__class__.__name__}[{table}]\n" + return name + render_fields(kwargs, 1) -@fmt_selection_column.register -def _fmt_selection_column_sequence(node: tuple, **kwargs): - return "\n".join(fmt_selection_column(value, **kwargs) for value in node.values) +@fmt.register(ops.Aggregation) +def _aggregation(op, table, **kwargs): + name = f"{op.__class__.__name__}[{table}]\n" + kwargs["by"] = {node.name: r for node, r in zip(op.by, kwargs["by"])} + kwargs["metrics"] = {node.name: r for node, r in zip(op.metrics, kwargs["metrics"])} + return name + render_fields(kwargs, 1) -@fmt_selection_column.register -def _fmt_selection_column_value_expr( - node: ops.Value, *, aliases: Aliases, maxlen: int = 0 -) -> str: - name = f"{node.name}:" - # the additional 1 is for the colon - aligned_name = f"{name:<{maxlen + 1}}" - value = fmt_value(node, aliases=aliases) - dtype = type_info(node.dtype) - return f"{aligned_name} {value}{dtype}" +@fmt.register(ops.Selection) +def _selection(op, table, selections, **kwargs): + name = f"{op.__class__.__name__}[{table}]\n" + # special handling required to support both relation and value selections + rels, values = [], {} + for node, rendered in zip(op.selections, selections): + if isinstance(node, ops.Relation): + rels.append(rendered) + else: + values[node.name] = f"{rendered}{type_info(node.dtype)}" -@fmt_selection_column.register -def _fmt_selection_column_table_expr( - node: ops.TableNode, *, aliases: Aliases, **_: Any -) -> str: - return str(aliases[node]) + segments = filter(None, [render(rels), render(values)]) + kwargs["selections"] = "\n".join(segments) + return name + render_fields(kwargs, 1) -_BIN_OP_CHARS = { - # comparison operations - ops.Equals: "==", - ops.IdenticalTo: "===", - ops.NotEquals: "!=", - ops.Less: "<", - ops.LessEqual: "<=", - ops.Greater: ">", - ops.GreaterEqual: ">=", - # arithmetic - ops.Add: "+", - ops.Subtract: "-", - ops.Multiply: "*", - ops.Divide: "/", - ops.FloorDivide: "//", - ops.Modulus: "%", - ops.Power: "**", - # temporal operations - ops.DateAdd: "+", - ops.DateSub: "-", - ops.DateDiff: "-", - ops.TimeAdd: "+", - ops.TimeSub: "-", - ops.TimeDiff: "-", - ops.TimestampAdd: "+", - ops.TimestampSub: "-", - ops.TimestampDiff: "-", - ops.IntervalAdd: "+", - ops.IntervalSubtract: "-", - ops.IntervalMultiply: "*", - ops.IntervalFloorDivide: "//", - # boolean operators - ops.And: "&", - ops.Or: "|", - ops.Xor: "^", -} +@fmt.register(ops.SetOp) +def _set_op(op, left, right, distinct): + args = [str(left), str(right)] + if op.distinct is not None: + args.append(f"distinct={distinct}") + return f"{op.__class__.__name__}[{', '.join(args)}]" -@functools.singledispatch -def fmt_value(obj, **_: Any) -> str: - """Format a value expression or operation. - [`repr`][repr] the object if we don't have a specific formatting - rule. - """ - return repr(obj) +@fmt.register(ops.Join) +def _join(op, left, right, predicates, **kwargs): + args = [str(left), str(right)] + name = f"{op.__class__.__name__}[{', '.join(args)}]" + if len(predicates) == 1: + # if only one, put it directly after the join on the same line + top = f"{name} {predicates[0]}" + fields = kwargs + else: + top = f"{name}" + fields = {"predicates": predicates, **kwargs} -@fmt_value.register -def _fmt_value_function_type(func: types.FunctionType, **_: Any) -> str: - return func.__name__ + fields = render_fields(fields, 1) + return f"{top}\n{fields}" if fields else top -@fmt_value.register -def _fmt_value_node(op: ops.Node, **_: Any) -> str: - raise AssertionError(f"`fmt_value` not implemented for operation: {type(op)}") +@fmt.register(ops.Limit) +def _limit(op, table, **kwargs): + params = inline_args(kwargs) + return f"{op.__class__.__name__}[{table}, {params}]" -@fmt_value.register -def _fmt_value_sequence(op: tuple, **kwargs: Any) -> str: - return ", ".join([fmt_value(value, **kwargs) for value in op]) +@fmt.register(ops.SelfReference) +@fmt.register(ops.Distinct) +def _self_reference(op, table, **kwargs): + return f"{op.__class__.__name__}[{table}]" -@fmt_value.register -def _fmt_value_expr(op: ops.Value, *, aliases: Aliases) -> str: - """Format a value expression. +@fmt.register(ops.Literal) +def _literal(op, value, **kwargs): + if op.dtype.is_interval(): + return f"{value!r} {op.dtype.unit.short}" + else: + return f"{value!r}" - Forwards the call on to the specific operation dispatch rule. - """ - return fmt_value(op, aliases=aliases) +@fmt.register(ops.TableColumn) +def _table_column(op, table, name): + return f"{table}.{name}" -@fmt_value.register -def _fmt_value_binary_op(op: ops.Binary, *, aliases: Aliases) -> str: - left = fmt_value(op.left, aliases=aliases) - right = fmt_value(op.right, aliases=aliases) + +@fmt.register(ops.Value) +def _value(op, **kwargs): + fields = inline_args(kwargs, prefer_positional=True) + return f"{op.__class__.__name__}({fields})" + + +@fmt.register(ops.Alias) +def _alias(op, arg, name): + return arg + + +@fmt.register(ops.Binary) +def _binary(op, left, right): try: - op_char = _BIN_OP_CHARS[type(op)] + symbol = _infix_ops[op.__class__] except KeyError: - return f"{type(op).__name__}({left}, {right})" + return f"{op.__class__.__name__}({left}, {right})" else: - return f"{left} {op_char} {right}" + return f"{left} {symbol} {right}" -@fmt_value.register -def _fmt_value_negate(op: ops.Negate, *, aliases: Aliases) -> str: - op_name = "Not" if op.dtype.is_boolean() else "Negate" - operand = fmt_value(op.arg, aliases=aliases) - return f"{op_name}({operand})" +@fmt.register(ops.ScalarParameter) +def _scalar_parameter(op, dtype, **kwargs): + return f"$({dtype})" -@fmt_value.register -def _fmt_value_literal(op: ops.Literal, **_: Any) -> str: - if op.dtype.is_interval(): - return f"{op.value} {op.dtype.unit.short}" - return repr(op.value) - - -@fmt_value.register -def _fmt_value_datatype(datatype: dt.DataType, **_: Any) -> str: - return str(datatype) - - -@fmt_value.register -def _fmt_value_value_op(op: ops.Value, *, aliases: Aliases) -> str: - args = [] - # loop over argument names and original expression - for argname, orig_expr in zip(op.argnames, op.args): - # promote argument to a list, so that we don't accidentally repr - # entire subtrees when all we want is the formatted argument value - if exprs := [expr for expr in util.promote_list(orig_expr) if expr is not None]: - # format the individual argument values - formatted_args = ", ".join( - fmt_value(expr, aliases=aliases) for expr in exprs - ) - # if the original argument was a non-string iterable, display it as - # a list - value = ( - f"[{formatted_args}]" if util.is_iterable(orig_expr) else formatted_args - ) - # `arg` and `expr` are noisy, so we ignore printing them as a - # special case - if argname not in ("arg", "expr"): - formatted = f"{argname}={value}" - else: - formatted = value - args.append(formatted) - - return f"{op.__class__.__name__}({', '.join(args)})" - - -@fmt_value.register -def _fmt_value_alias(op: ops.Alias, *, aliases: Aliases) -> str: - return fmt_value(op.arg, aliases=aliases) - - -@fmt_value.register -def _fmt_value_table_column(op: ops.TableColumn, *, aliases: Aliases) -> str: - return f"{aliases[op.table]}.{op.name}" - - -@fmt_value.register -def _fmt_value_scalar_parameter(op: ops.ScalarParameter, **_: Any) -> str: - return f"$({op.dtype})" - - -@fmt_value.register -def _fmt_value_sort_key(op: ops.SortKey, *, aliases: Aliases) -> str: - expr = fmt_value(op.expr, aliases=aliases) - prefix = "asc" if op.ascending else "desc" - return f"{prefix} {expr}" - - -@fmt_value.register -def _fmt_value_physical_table(op: ops.PhysicalTable, **_: Any) -> str: - """Format a table as value. - - This function is called when a table is used in a value expression. - An example is `table.count()`. - """ - return op.name - - -@fmt_value.register -def _fmt_value_table_node(op: ops.TableNode, *, aliases: Aliases, **_: Any) -> str: - """Format a table as value. - - This function is called when a table is used in a value expression. - An example is `table.count()`. - """ - return f"{aliases[op]}" - - -@fmt_value.register -def _fmt_value_string_sql_like(op: ops.StringSQLLike, *, aliases: Aliases) -> str: - expr = fmt_value(op.arg, aliases=aliases) - pattern = fmt_value(op.pattern, aliases=aliases) - prefix = "I" * isinstance(op, ops.StringSQLILike) - return f"{expr} {prefix}LIKE {pattern}" - - -@fmt_value.register -def _fmt_value_window(win: ops.WindowFrame, *, aliases: Aliases) -> str: - args = [] - for field, value in ( - ("group_by", win.group_by), - ("order_by", win.order_by), - ("start", win.start), - ("end", win.end), - ): - disp_field = field.lstrip("_") - if value is not None: - if isinstance(value, tuple): - # don't show empty sequences - if not value: - continue - elements = ", ".join( - fmt_value( - arg.op() if isinstance(arg, ir.Expr) else arg, - aliases=aliases, - ) - for arg in value - ) - formatted = f"[{elements}]" - else: - formatted = fmt_value(value, aliases=aliases) - args.append(f"{disp_field}={formatted}") - return f"{win.__class__.__name__}({', '.join(args)})" +@fmt.register(ops.SortKey) +def _sort_key(op, expr, **kwargs): + return f"{'asc' if op.ascending else 'desc'} {expr}" diff --git a/ibis/expr/tests/conftest.py b/ibis/expr/tests/conftest.py new file mode 100644 index 000000000000..fdbbc8c80490 --- /dev/null +++ b/ibis/expr/tests/conftest.py @@ -0,0 +1,349 @@ +from __future__ import annotations + +import pytest + +import ibis + +MOCK_TABLES = { + "alltypes": [ + ("a", "int8"), + ("b", "int16"), + ("c", "int32"), + ("d", "int64"), + ("e", "float32"), + ("f", "float64"), + ("g", "string"), + ("h", "boolean"), + ("i", "timestamp"), + ("j", "date"), + ("k", "time"), + ], + "foo_t": [("key1", "string"), ("key2", "string"), ("value1", "double")], + "bar_t": [("key1", "string"), ("key2", "string"), ("value2", "double")], + "t1": [("key1", "string"), ("key2", "string"), ("value1", "double")], + "t2": [("key1", "string"), ("key2", "string"), ("value2", "double")], + "t3": [("key2", "string"), ("key3", "string"), ("value3", "double")], + "t4": [("key3", "string"), ("value4", "double")], + "bar": [("x", "double"), ("job", "string")], + "foo": [ + ("job", "string"), + ("dept_id", "string"), + ("year", "int32"), + ("y", "double"), + ], + "star1": [ + ("c", "int32"), + ("f", "double"), + ("foo_id", "string"), + ("bar_id", "string"), + ], + "star2": [ + ("foo_id", "string"), + ("value1", "double"), + ("value3", "double"), + ], + "star3": [("bar_id", "string"), ("value2", "double")], + "test1": [("c", "int32"), ("f", "double"), ("g", "string")], + "test2": [("key", "string"), ("value", "double")], + "tpch_region": [ + ("r_regionkey", "int16"), + ("r_name", "string"), + ("r_comment", "string"), + ], + "tpch_nation": [ + ("n_nationkey", "int16"), + ("n_name", "string"), + ("n_regionkey", "int16"), + ("n_comment", "string"), + ], + "tpch_lineitem": [ + ("l_orderkey", "int64"), + ("l_partkey", "int64"), + ("l_suppkey", "int64"), + ("l_linenumber", "int32"), + ("l_quantity", "decimal(12,2)"), + ("l_extendedprice", "decimal(12,2)"), + ("l_discount", "decimal(12,2)"), + ("l_tax", "decimal(12,2)"), + ("l_returnflag", "string"), + ("l_linestatus", "string"), + ("l_shipdate", "string"), + ("l_commitdate", "string"), + ("l_receiptdate", "string"), + ("l_shipinstruct", "string"), + ("l_shipmode", "string"), + ("l_comment", "string"), + ], + "tpch_customer": [ + ("c_custkey", "int64"), + ("c_name", "string"), + ("c_address", "string"), + ("c_nationkey", "int16"), + ("c_phone", "string"), + ("c_acctbal", "decimal"), + ("c_mktsegment", "string"), + ("c_comment", "string"), + ], + "tpch_orders": [ + ("o_orderkey", "int64"), + ("o_custkey", "int64"), + ("o_orderstatus", "string"), + ("o_totalprice", "decimal(12,2)"), + ("o_orderdate", "string"), + ("o_orderpriority", "string"), + ("o_clerk", "string"), + ("o_shippriority", "int32"), + ("o_comment", "string"), + ], + "functional_alltypes": [ + ("id", "int32"), + ("bool_col", "boolean"), + ("tinyint_col", "int8"), + ("smallint_col", "int16"), + ("int_col", "int32"), + ("bigint_col", "int64"), + ("float_col", "float32"), + ("double_col", "float64"), + ("date_string_col", "string"), + ("string_col", "string"), + ("timestamp_col", "timestamp"), + ("year", "int32"), + ("month", "int32"), + ], + "airlines": [ + ("year", "int32"), + ("month", "int32"), + ("day", "int32"), + ("dayofweek", "int32"), + ("dep_time", "int32"), + ("crs_dep_time", "int32"), + ("arr_time", "int32"), + ("crs_arr_time", "int32"), + ("carrier", "string"), + ("flight_num", "int32"), + ("tail_num", "int32"), + ("actual_elapsed_time", "int32"), + ("crs_elapsed_time", "int32"), + ("airtime", "int32"), + ("arrdelay", "int32"), + ("depdelay", "int32"), + ("origin", "string"), + ("dest", "string"), + ("distance", "int32"), + ("taxi_in", "int32"), + ("taxi_out", "int32"), + ("cancelled", "int32"), + ("cancellation_code", "string"), + ("diverted", "int32"), + ("carrier_delay", "int32"), + ("weather_delay", "int32"), + ("nas_delay", "int32"), + ("security_delay", "int32"), + ("late_aircraft_delay", "int32"), + ], + "tpcds_customer": [ + ("c_customer_sk", "int64"), + ("c_customer_id", "string"), + ("c_current_cdemo_sk", "int32"), + ("c_current_hdemo_sk", "int32"), + ("c_current_addr_sk", "int32"), + ("c_first_shipto_date_sk", "int32"), + ("c_first_sales_date_sk", "int32"), + ("c_salutation", "string"), + ("c_first_name", "string"), + ("c_last_name", "string"), + ("c_preferred_cust_flag", "string"), + ("c_birth_day", "int32"), + ("c_birth_month", "int32"), + ("c_birth_year", "int32"), + ("c_birth_country", "string"), + ("c_login", "string"), + ("c_email_address", "string"), + ("c_last_review_date", "string"), + ], + "tpcds_customer_address": [ + ("ca_address_sk", "bigint"), + ("ca_address_id", "string"), + ("ca_street_number", "string"), + ("ca_street_name", "string"), + ("ca_street_type", "string"), + ("ca_suite_number", "string"), + ("ca_city", "string"), + ("ca_county", "string"), + ("ca_state", "string"), + ("ca_zip", "string"), + ("ca_country", "string"), + ("ca_gmt_offset", "decimal(5,2)"), + ("ca_location_type", "string"), + ], + "tpcds_customer_demographics": [ + ("cd_demo_sk", "bigint"), + ("cd_gender", "string"), + ("cd_marital_status", "string"), + ("cd_education_status", "string"), + ("cd_purchase_estimate", "int"), + ("cd_credit_rating", "string"), + ("cd_dep_count", "int"), + ("cd_dep_employed_count", "int"), + ("cd_dep_college_count", "int"), + ], + "tpcds_date_dim": [ + ("d_date_sk", "bigint"), + ("d_date_id", "string"), + ("d_date", "string"), + ("d_month_seq", "int"), + ("d_week_seq", "int"), + ("d_quarter_seq", "int"), + ("d_year", "int"), + ("d_dow", "int"), + ("d_moy", "int"), + ("d_dom", "int"), + ("d_qoy", "int"), + ("d_fy_year", "int"), + ("d_fy_quarter_seq", "int"), + ("d_fy_week_seq", "int"), + ("d_day_name", "string"), + ("d_quarter_name", "string"), + ("d_holiday", "string"), + ("d_weekend", "string"), + ("d_following_holiday", "string"), + ("d_first_dom", "int"), + ("d_last_dom", "int"), + ("d_same_day_ly", "int"), + ("d_same_day_lq", "int"), + ("d_current_day", "string"), + ("d_current_week", "string"), + ("d_current_month", "string"), + ("d_current_quarter", "string"), + ("d_current_year", "string"), + ], + "tpcds_household_demographics": [ + ("hd_demo_sk", "bigint"), + ("hd_income_band_sk", "int"), + ("hd_buy_potential", "string"), + ("hd_dep_count", "int"), + ("hd_vehicle_count", "int"), + ], + "tpcds_item": [ + ("i_item_sk", "bigint"), + ("i_item_id", "string"), + ("i_rec_start_date", "string"), + ("i_rec_end_date", "string"), + ("i_item_desc", "string"), + ("i_current_price", "decimal(7,2)"), + ("i_wholesale_cost", "decimal(7,2)"), + ("i_brand_id", "int"), + ("i_brand", "string"), + ("i_class_id", "int"), + ("i_class", "string"), + ("i_category_id", "int"), + ("i_category", "string"), + ("i_manufact_id", "int"), + ("i_manufact", "string"), + ("i_size", "string"), + ("i_formulation", "string"), + ("i_color", "string"), + ("i_units", "string"), + ("i_container", "string"), + ("i_manager_id", "int"), + ("i_product_name", "string"), + ], + "tpcds_promotion": [ + ("p_promo_sk", "bigint"), + ("p_promo_id", "string"), + ("p_start_date_sk", "int"), + ("p_end_date_sk", "int"), + ("p_item_sk", "int"), + ("p_cost", "decimal(15,2)"), + ("p_response_target", "int"), + ("p_promo_name", "string"), + ("p_channel_dmail", "string"), + ("p_channel_email", "string"), + ("p_channel_catalog", "string"), + ("p_channel_tv", "string"), + ("p_channel_radio", "string"), + ("p_channel_press", "string"), + ("p_channel_event", "string"), + ("p_channel_demo", "string"), + ("p_channel_details", "string"), + ("p_purpose", "string"), + ("p_discount_active", "string"), + ], + "tpcds_store": [ + ("s_store_sk", "bigint"), + ("s_store_id", "string"), + ("s_rec_start_date", "string"), + ("s_rec_end_date", "string"), + ("s_closed_date_sk", "int"), + ("s_store_name", "string"), + ("s_number_employees", "int"), + ("s_floor_space", "int"), + ("s_hours", "string"), + ("s_manager", "string"), + ("s_market_id", "int"), + ("s_geography_class", "string"), + ("s_market_desc", "string"), + ("s_market_manager", "string"), + ("s_division_id", "int"), + ("s_division_name", "string"), + ("s_company_id", "int"), + ("s_company_name", "string"), + ("s_street_number", "string"), + ("s_street_name", "string"), + ("s_street_type", "string"), + ("s_suite_number", "string"), + ("s_city", "string"), + ("s_county", "string"), + ("s_state", "string"), + ("s_zip", "string"), + ("s_country", "string"), + ("s_gmt_offset", "decimal(5,2)"), + ("s_tax_precentage", "decimal(5,2)"), + ], + "tpcds_store_sales": [ + ("ss_sold_time_sk", "bigint"), + ("ss_item_sk", "bigint"), + ("ss_customer_sk", "bigint"), + ("ss_cdemo_sk", "bigint"), + ("ss_hdemo_sk", "bigint"), + ("ss_addr_sk", "bigint"), + ("ss_store_sk", "bigint"), + ("ss_promo_sk", "bigint"), + ("ss_ticket_number", "int"), + ("ss_quantity", "int"), + ("ss_wholesale_cost", "decimal(7,2)"), + ("ss_list_price", "decimal(7,2)"), + ("ss_sales_price", "decimal(7,2)"), + ("ss_ext_discount_amt", "decimal(7,2)"), + ("ss_ext_sales_price", "decimal(7,2)"), + ("ss_ext_wholesale_cost", "decimal(7,2)"), + ("ss_ext_list_price", "decimal(7,2)"), + ("ss_ext_tax", "decimal(7,2)"), + ("ss_coupon_amt", "decimal(7,2)"), + ("ss_net_paid", "decimal(7,2)"), + ("ss_net_paid_inc_tax", "decimal(7,2)"), + ("ss_net_profit", "decimal(7,2)"), + ("ss_sold_date_sk", "bigint"), + ], + "tpcds_time_dim": [ + ("t_time_sk", "bigint"), + ("t_time_id", "string"), + ("t_time", "int"), + ("t_hour", "int"), + ("t_minute", "int"), + ("t_second", "int"), + ("t_am_pm", "string"), + ("t_shift", "string"), + ("t_sub_shift", "string"), + ("t_meal_time", "string"), + ], +} + + +def table(name): + return ibis.table(name=name, schema=MOCK_TABLES[name]) + + +@pytest.fixture +def alltypes(): + return table("alltypes") diff --git a/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt b/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt new file mode 100644 index 000000000000..44b15ca820f4 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt @@ -0,0 +1,20 @@ +r0 := UnboundTable: alltypes + a int8 + b int16 + c int32 + d int64 + e float32 + f float64 + g string + h boolean + i timestamp + j date + k time + +Aggregation[r0] + metrics: + c: Sum(r0.c) + d: Mean(r0.d) + by: + key1: r0.g + key2: Round(r0.f) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_argument_repr_shows_name/repr.txt b/ibis/expr/tests/snapshots/test_format/test_argument_repr_shows_name/repr.txt new file mode 100644 index 000000000000..f29279a518eb --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_argument_repr_shows_name/repr.txt @@ -0,0 +1,4 @@ +r0 := UnboundTable: fakename2 + fakecolname1 int64 + +NullIf(fakecolname1, 2): NullIf(r0.fakecolname1, null_if_expr=2) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt b/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt new file mode 100644 index 000000000000..aeaba71bcc34 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt @@ -0,0 +1,20 @@ +r0 := UnboundTable: right + time2 int32 + value2 float64 + +r1 := UnboundTable: left + time1 int32 + value float64 + +r2 := AsOfJoin[r1, r0] r1.time1 == r0.time2 + +r3 := InnerJoin[r2, r0] r1.value == r0.value2 + +Selection[r3] + selections: + time1: r2.time1 + value: r2.value + time2: r2.time2 + value2: r2.value2 + time2_right: r0.time2 + value2_right: r0.value2 \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_complex_repr/repr.txt b/ibis/expr/tests/snapshots/test_format/test_complex_repr/repr.txt new file mode 100644 index 000000000000..0f9d5621fa4b --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_complex_repr/repr.txt @@ -0,0 +1,20 @@ +r0 := UnboundTable: t + a int64 + +r1 := Selection[r0] + predicates: + r0.a < 42 + r0.a >= 42 + +r2 := Selection[r1] + selections: + r1 + x: r1.a + 42 + +r3 := Aggregation[r2] + metrics: + y: Sum(r2.a) + by: + x: r2.x + +Limit[r3, n=10] \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_destruct_selection/repr.txt b/ibis/expr/tests/snapshots/test_format/test_destruct_selection/repr.txt new file mode 100644 index 000000000000..013871ecfb27 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_destruct_selection/repr.txt @@ -0,0 +1,7 @@ +r0 := UnboundTable: t + col int64 + +Aggregation[r0] + metrics: + sum: StructField(ReductionVectorizedUDF(func=multi_output_udf, func_args=[r0.col], input_type=[int64], return_type={'sum': int64, 'mean': float64}), field='sum') + mean: StructField(ReductionVectorizedUDF(func=multi_output_udf, func_args=[r0.col], input_type=[int64], return_type={'sum': int64, 'mean': float64}), field='mean') \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_dict_repr.txt b/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_dict_repr.txt new file mode 100644 index 000000000000..960ac1160204 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_dict_repr.txt @@ -0,0 +1,7 @@ +r0 := UnboundTable: t + a int64 + b string + +FillNa[r0] + replacements: + a: 3 \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_int_repr.txt b/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_int_repr.txt new file mode 100644 index 000000000000..d7aa4f2ee692 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_int_repr.txt @@ -0,0 +1,11 @@ +r0 := UnboundTable: t + a int64 + b string + +r1 := Selection[r0] + selections: + a: r0.a + +FillNa[r1] + replacements: + 3 \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_str_repr.txt b/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_str_repr.txt new file mode 100644 index 000000000000..887edd9ee5b9 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_str_repr.txt @@ -0,0 +1,11 @@ +r0 := UnboundTable: t + a int64 + b string + +r1 := Selection[r0] + selections: + b: r0.b + +FillNa[r1] + replacements: + 'foo' \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt b/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt new file mode 100644 index 000000000000..168803538ebe --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt @@ -0,0 +1,2 @@ +DummyTable + foo array \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_format_in_memory_table/repr.txt b/ibis/expr/tests/snapshots/test_format/test_format_in_memory_table/repr.txt new file mode 100644 index 000000000000..4f3fe692b69e --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_format_in_memory_table/repr.txt @@ -0,0 +1,9 @@ +r0 := InMemoryTable + data: + PandasDataFrameProxy: + x y + 0 1 2 + 1 3 4 + 2 5 6 + +Add(Sum(x), Sum(y)): Sum(r0.x) + Sum(r0.y) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_format_multiple_join_with_projection/repr.txt b/ibis/expr/tests/snapshots/test_format/test_format_multiple_join_with_projection/repr.txt new file mode 100644 index 000000000000..057e2d8c8966 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_format_multiple_join_with_projection/repr.txt @@ -0,0 +1,36 @@ +r0 := UnboundTable: three + bar_id string + value2 float64 + +r1 := UnboundTable: one + c int32 + f float64 + foo_id string + bar_id string + +r2 := UnboundTable: two + foo_id string + value1 float64 + +r3 := Selection[r1] + predicates: + r1.f > 0 + +r4 := LeftJoin[r3, r2] r3.foo_id == r2.foo_id + +r5 := Selection[r4] + selections: + c: r3.c + f: r3.f + foo_id: r3.foo_id + bar_id: r3.bar_id + foo_id_right: r2.foo_id + value1: r2.value1 + +r6 := InnerJoin[r5, r0] r3.bar_id == r0.bar_id + +Selection[r6] + selections: + r3 + value1: r2.value1 + value2: r0.value2 \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_format_new_relational_operation/repr.txt b/ibis/expr/tests/snapshots/test_format/test_format_new_relational_operation/repr.txt new file mode 100644 index 000000000000..f058f8b462d4 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_format_new_relational_operation/repr.txt @@ -0,0 +1,30 @@ +r0 := UnboundTable: alltypes + a int8 + b int16 + c int32 + d int64 + e float32 + f float64 + g string + h boolean + i timestamp + j date + k time + +r1 := MyRelation + a int8 + b int16 + c int32 + d int64 + e float32 + f float64 + g string + h boolean + i timestamp + j date + k time + +Selection[r1] + selections: + r1 + a2: r1.a \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_format_projection/repr.txt b/ibis/expr/tests/snapshots/test_format/test_format_projection/repr.txt new file mode 100644 index 000000000000..c982128f1c0e --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_format_projection/repr.txt @@ -0,0 +1,20 @@ +r0 := UnboundTable: alltypes + a int8 + b int16 + c int32 + d int64 + e float32 + f float64 + g string + h boolean + i timestamp + j date + k time + +r1 := Selection[r0] + selections: + c: r0.c + a: r0.a + f: r0.f + +a: r1.a \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_format_table_column/repr.txt b/ibis/expr/tests/snapshots/test_format/test_format_table_column/repr.txt new file mode 100644 index 000000000000..1770e3f85056 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_format_table_column/repr.txt @@ -0,0 +1,14 @@ +r0 := UnboundTable: alltypes + a int8 + b int16 + c int32 + d int64 + e float32 + f float64 + g string + h boolean + i timestamp + j date + k time + +f: r0.f \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_table/repr.txt b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_table/repr.txt new file mode 100644 index 000000000000..cfd72d2fff7c --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_table/repr.txt @@ -0,0 +1,18 @@ +r0 := UnboundTable: airlines + dest string + origin string + arrdelay int32 + +r1 := Aggregation[r0] + metrics: + Mean(arrdelay): Mean(r0.arrdelay) + by: + dest: r0.dest + predicates: + InValues(value=r0.dest, options=['ORD', 'JFK', 'SFO']) + +r2 := Selection[r1] + sort_keys: + desc r1.Mean(arrdelay) + +Limit[r2, n=10] \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt new file mode 100644 index 000000000000..0648ffe1b86d --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt @@ -0,0 +1,27 @@ +r0 := UnboundTable: purchases + region string + kind string + user int64 + amount float64 + +r1 := Aggregation[r0] + metrics: + total: Sum(r0.amount) + by: + region: r0.region + kind: r0.kind + +r2 := Selection[r1] + predicates: + r1.kind == 'foo' + +r3 := Selection[r1] + predicates: + r1.kind == 'bar' + +r4 := InnerJoin[r2, r3] r2.region == r3.region + +Selection[r4] + selections: + r2 + right_total: r3.total \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_named_value_expr_show_name/repr.txt b/ibis/expr/tests/snapshots/test_format/test_named_value_expr_show_name/repr.txt new file mode 100644 index 000000000000..15e8a5a6278b --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_named_value_expr_show_name/repr.txt @@ -0,0 +1,14 @@ +r0 := UnboundTable: alltypes + a int8 + b int16 + c int32 + d int64 + e float32 + f float64 + g string + h boolean + i timestamp + j date + k time + +Multiply(f, 2): r0.f * 2 \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_named_value_expr_show_name/repr2.txt b/ibis/expr/tests/snapshots/test_format/test_named_value_expr_show_name/repr2.txt new file mode 100644 index 000000000000..4df5900a9d11 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_named_value_expr_show_name/repr2.txt @@ -0,0 +1,14 @@ +r0 := UnboundTable: alltypes + a int8 + b int16 + c int32 + d int64 + e float32 + f float64 + g string + h boolean + i timestamp + j date + k time + +baz: r0.f * 2 \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_repr_exact/repr.txt b/ibis/expr/tests/snapshots/test_format/test_repr_exact/repr.txt new file mode 100644 index 000000000000..38e341469f0c --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_repr_exact/repr.txt @@ -0,0 +1,9 @@ +r0 := UnboundTable: t + col int64 + col2 string + col3 float64 + +Selection[r0] + selections: + r0 + col4: StringLength(r0.col2) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_same_column_multiple_aliases/repr.txt b/ibis/expr/tests/snapshots/test_format/test_same_column_multiple_aliases/repr.txt new file mode 100644 index 000000000000..1826aa9d8567 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_same_column_multiple_aliases/repr.txt @@ -0,0 +1,7 @@ +r0 := UnboundTable: t + col int64 + +Selection[r0] + selections: + fakealias1: r0.col + fakealias2: r0.col \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_schema_truncation/repr1.txt b/ibis/expr/tests/snapshots/test_format/test_schema_truncation/repr1.txt new file mode 100644 index 000000000000..2c77398383f6 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_schema_truncation/repr1.txt @@ -0,0 +1,2 @@ +UnboundTable: t + t string \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_schema_truncation/repr8.txt b/ibis/expr/tests/snapshots/test_format/test_schema_truncation/repr8.txt new file mode 100644 index 000000000000..77b63a386216 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_schema_truncation/repr8.txt @@ -0,0 +1,10 @@ +UnboundTable: t + a string + b string + c string + d string + ⋮ + q string + r string + s string + t string \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_schema_truncation/repr_all.txt b/ibis/expr/tests/snapshots/test_format/test_schema_truncation/repr_all.txt new file mode 100644 index 000000000000..a6bef8768240 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_schema_truncation/repr_all.txt @@ -0,0 +1,21 @@ +UnboundTable: t + a string + b string + c string + d string + e string + f string + g string + h string + i string + j string + k string + l string + m string + n string + o string + p string + q string + r string + s string + t string \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_table_count_expr/cnt_repr.txt b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/cnt_repr.txt new file mode 100644 index 000000000000..a85e2bdb5dbb --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/cnt_repr.txt @@ -0,0 +1,5 @@ +r0 := UnboundTable: t1 + a int64 + b float64 + +CountStar(t1): CountStar(r0) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_table_count_expr/join_repr.txt b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/join_repr.txt new file mode 100644 index 000000000000..e63b05c8c635 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/join_repr.txt @@ -0,0 +1,17 @@ +r0 := UnboundTable: t1 + a int64 + b float64 + +r1 := UnboundTable: t2 + a int64 + b float64 + +r2 := InnerJoin[r0, r1] r0.a == r1.a + +r3 := Selection[r2] + selections: + a: r0.a + b: r0.b + b_right: r1.b + +CountStar(): CountStar(r3) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_table_count_expr/union_repr.txt b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/union_repr.txt new file mode 100644 index 000000000000..caab7a357ba4 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/union_repr.txt @@ -0,0 +1,16 @@ +r0 := UnboundTable: t1 + a int64 + b float64 + +r1 := UnboundTable: t2 + a int64 + b float64 + +r2 := Union[r0, r1, distinct=False] + +r3 := Selection[r2] + selections: + a: r2.a + b: r2.b + +CountStar(): CountStar(r3) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_table_type_output/repr.txt b/ibis/expr/tests/snapshots/test_format/test_table_type_output/repr.txt new file mode 100644 index 000000000000..1f92b51ba906 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_table_type_output/repr.txt @@ -0,0 +1,9 @@ +r0 := UnboundTable: foo + job string + dept_id string + year int32 + y float64 + +r1 := SelfReference[r0] + +Equals(dept_id, dept_id): r0.dept_id == r1.dept_id \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt b/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt new file mode 100644 index 000000000000..959d15672b18 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt @@ -0,0 +1,25 @@ +r0 := UnboundTable: right + time2 int32 + value2 float64 + b string + +r1 := UnboundTable: left + time1 int32 + value float64 + a string + +r2 := InnerJoin[r1, r0] r1.a == r0.b + +r3 := InnerJoin[r2, r0] r1.value == r0.value2 + +Selection[r3] + selections: + time1: r2.time1 + value: r2.value + a: r2.a + time2: r2.time2 + value2: r2.value2 + b: r2.b + time2_right: r0.time2 + value2_right: r0.value2 + b_right: r0.b \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_window_group_by/repr.txt b/ibis/expr/tests/snapshots/test_format/test_window_group_by/repr.txt new file mode 100644 index 000000000000..a062b06c486d --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_window_group_by/repr.txt @@ -0,0 +1,5 @@ +r0 := UnboundTable: t + a int64 + b string + +Mean(a): WindowFunction(func=Mean(r0.a), frame=RowsWindowFrame(table=r0, group_by=[r0.b])) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_window_no_group_by/repr.txt b/ibis/expr/tests/snapshots/test_format/test_window_no_group_by/repr.txt new file mode 100644 index 000000000000..b46b19f10b7a --- /dev/null +++ b/ibis/expr/tests/snapshots/test_format/test_window_no_group_by/repr.txt @@ -0,0 +1,5 @@ +r0 := UnboundTable: t + a int64 + b string + +Mean(a): WindowFunction(func=Mean(r0.a), frame=RowsWindowFrame(table=r0, start=WindowBoundary(value=0, preceding=True))) \ No newline at end of file diff --git a/ibis/tests/expr/test_format.py b/ibis/expr/tests/test_format.py similarity index 53% rename from ibis/tests/expr/test_format.py rename to ibis/expr/tests/test_format.py index e2da6217429b..0de6e22b8579 100644 --- a/ibis/tests/expr/test_format.py +++ b/ibis/expr/tests/test_format.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +import string import pytest @@ -9,22 +10,39 @@ import ibis.expr.format import ibis.expr.operations as ops import ibis.legacy.udf.vectorized as udf +from ibis import util from ibis.expr.operations.relations import Projection +# easier to switch implementation if needed +fmt = repr -def test_format_table_column(table): + +@pytest.mark.parametrize("cls", set(ops.Relation.__subclasses__()) - {Projection}) +def test_tables_have_format_rules(cls): + assert cls in ibis.expr.format.fmt.registry + + +@pytest.mark.parametrize("cls", [ops.PhysicalTable, ops.Relation]) +def test_tables_have_format_value_rules(cls): + assert cls in ibis.expr.format.fmt.registry + + +def test_format_table_column(alltypes, snapshot): # GH #507 - result = repr(table.f) + result = fmt(alltypes.f) assert "float64" in result + snapshot.assert_match(result, "repr.txt") -def test_format_projection(table): +def test_format_projection(alltypes, snapshot): # This should produce a ref to the projection - proj = table[["c", "a", "f"]] - repr(proj["a"]) + proj = alltypes[["c", "a", "f"]] + expr = proj["a"] + result = fmt(expr) + snapshot.assert_match(result, "repr.txt") -def test_table_type_output(): +def test_table_type_output(snapshot): foo = ibis.table( [ ("job", "string"), @@ -36,27 +54,28 @@ def test_table_type_output(): ) expr = foo.dept_id == foo.view().dept_id - result = repr(expr) - + result = fmt(expr) assert "SelfReference[r0]" in result assert "UnboundTable: foo" in result + snapshot.assert_match(result, "repr.txt") -def test_aggregate_arg_names(table): +def test_aggregate_arg_names(alltypes, snapshot): # Not sure how to test this *well* - - t = table + t = alltypes by_exprs = [t.g.name("key1"), t.f.round().name("key2")] metrics = [t.c.sum().name("c"), t.d.mean().name("d")] expr = t.group_by(by_exprs).aggregate(metrics) - result = repr(expr) + result = fmt(expr) assert "metrics" in result assert "by" in result + snapshot.assert_match(result, "repr.txt") + -def test_format_multiple_join_with_projection(): +def test_format_multiple_join_with_projection(snapshot): # Star schema with fact table table = ibis.table( [ @@ -84,30 +103,11 @@ def test_format_multiple_join_with_projection(): view = j2[[filtered, table2["value1"], table3["value2"]]] # it works! - repr(view) - - -def test_memoize_database_table(con): - table = con.table("test1") - table2 = con.table("test2") + result = fmt(view) + snapshot.assert_match(result, "repr.txt") - filter_pred = table["f"] > 0 - table3 = table[filter_pred] - join_pred = table3["g"] == table2["key"] - - joined = table2.inner_join(table3, [join_pred]) - - met1 = (table3["f"] - table2["value"]).mean().name("foo") - result = joined.aggregate( - [met1, table3["f"].sum().name("bar")], by=[table3["g"], table2["key"]] - ) - formatted = repr(result) - assert formatted.count("test1") == 1 - assert formatted.count("test2") == 1 - - -def test_memoize_filtered_table(): +def test_memoize_filtered_table(snapshot): airlines = ibis.table( [("dest", "string"), ("origin", "string"), ("arrdelay", "int32")], "airlines", @@ -117,38 +117,28 @@ def test_memoize_filtered_table(): t = airlines[airlines.dest.isin(dests)] delay_filter = t.dest.topk(10, by=t.arrdelay.mean()) - result = repr(delay_filter) + result = fmt(delay_filter) assert result.count("Selection") == 1 + snapshot.assert_match(result, "repr.txt") -def test_memoize_insert_sort_key(con): - table = con.table("airlines") - - t = table["arrdelay", "dest"] - expr = t.group_by("dest").mutate( - dest_avg=t.arrdelay.mean(), dev=t.arrdelay - t.arrdelay.mean() - ) - - worst = expr[expr.dev.notnull()].order_by(ibis.desc("dev")).limit(10) - - result = repr(worst) - - assert result.count("airlines") == 1 - -def test_named_value_expr_show_name(table): - expr = table.f * 2 +def test_named_value_expr_show_name(alltypes, snapshot): + expr = alltypes.f * 2 expr2 = expr.name("baz") # it works! - repr(expr) - - result2 = repr(expr2) + result = fmt(expr) + result2 = fmt(expr2) + assert "baz" not in result assert "baz" in result2 + snapshot.assert_match(result, "repr.txt") + snapshot.assert_match(result2, "repr2.txt") -def test_memoize_filtered_tables_in_join(): + +def test_memoize_filtered_tables_in_join(snapshot): # related: GH #667 purchases = ibis.table( [ @@ -169,20 +159,24 @@ def test_memoize_filtered_tables_in_join(): cond = left.region == right.region joined = left.join(right, cond)[left, right.total.name("right_total")] - result = repr(joined) + result = fmt(joined) # one for each aggregation # joins are shown without the word `predicates` above them # since joins only have predicates as arguments assert result.count("predicates") == 2 + snapshot.assert_match(result, "repr.txt") + -def test_argument_repr_shows_name(): +def test_argument_repr_shows_name(snapshot): t = ibis.table([("fakecolname1", "int64")], name="fakename2") expr = t.fakecolname1.nullif(2) - result = repr(expr) + result = fmt(expr) + assert "fakecolname1" in result assert "fakename2" in result + snapshot.assert_match(result, "repr.txt") def test_scalar_parameter_formatting(): @@ -193,23 +187,24 @@ def test_scalar_parameter_formatting(): assert str(value) == "my_param: $(int64)" -def test_same_column_multiple_aliases(): +def test_same_column_multiple_aliases(snapshot): table = ibis.table([("col", "int64")], name="t") expr = table[table.col.name("fakealias1"), table.col.name("fakealias2")] - result = repr(expr) + result = fmt(expr) assert "UnboundTable: t" in result assert "col int64" in result assert "fakealias1: r0.col" in result assert "fakealias2: r0.col" in result + snapshot.assert_match(result, "repr.txt") def test_scalar_parameter_repr(): value = ibis.param(dt.timestamp).name("value") - assert repr(value) == "value: $(timestamp)" + assert fmt(value) == "value: $(timestamp)" -def test_repr_exact(): +def test_repr_exact(snapshot): # NB: This is the only exact repr test. Do # not add new exact repr tests. New repr tests # should only check for the presence of substrings. @@ -217,21 +212,12 @@ def test_repr_exact(): [("col", "int64"), ("col2", "string"), ("col3", "double")], name="t", ).mutate(col4=lambda t: t.col2.length()) - result = repr(table) - expected = """\ -r0 := UnboundTable: t - col int64 - col2 string - col3 float64 -Selection[r0] - selections: - r0 - col4: StringLength(r0.col2)""" - assert result == expected + result = fmt(table) + snapshot.assert_match(result, "repr.txt") -def test_complex_repr(): +def test_complex_repr(snapshot): t = ( ibis.table(dict(a="int64"), name="t") .filter([lambda t: t.a < 42, lambda t: t.a >= 42]) @@ -240,7 +226,9 @@ def test_complex_repr(): .aggregate(y=lambda t: t.a.sum()) .limit(10) ) - repr(t) + result = fmt(t) + + snapshot.assert_match(result, "repr.txt") def test_value_exprs_repr(): @@ -259,82 +247,107 @@ def test_show_types(monkeypatch): assert "# float64" in repr(expr.sum()) -@pytest.mark.parametrize("cls", set(ops.TableNode.__subclasses__()) - {Projection}) -def test_tables_have_format_rules(cls): - assert cls in ibis.expr.format.fmt_table_op.registry +def test_schema_truncation(monkeypatch, snapshot): + schema = dict(zip(string.ascii_lowercase[:20], ["string"] * 20)) + t = ibis.table(schema, name="t") + monkeypatch.setattr(ibis.options.repr, "table_columns", 0) + with pytest.raises(ValueError): + fmt(t) -@pytest.mark.parametrize("cls", [ops.PhysicalTable, ops.TableNode]) -def test_tables_have_format_value_rules(cls): - assert cls in ibis.expr.format.fmt_value.registry + monkeypatch.setattr(ibis.options.repr, "table_columns", 1) + result = fmt(t) + assert util.VERTICAL_ELLIPSIS not in result + snapshot.assert_match(result, "repr1.txt") + monkeypatch.setattr(ibis.options.repr, "table_columns", 8) + result = fmt(t) + assert util.VERTICAL_ELLIPSIS in result + snapshot.assert_match(result, "repr8.txt") -@pytest.mark.parametrize( - "f", - [ - lambda t1, _: t1.count(), - lambda t1, t2: t1.join(t2, t1.a == t2.a).count(), - lambda t1, t2: ibis.union(t1, t2).count(), - ], -) -def test_table_value_expr(f): + monkeypatch.setattr(ibis.options.repr, "table_columns", 1000) + result = fmt(t) + assert util.VERTICAL_ELLIPSIS not in result + snapshot.assert_match(result, "repr_all.txt") + + +def test_table_count_expr(snapshot): t1 = ibis.table([("a", "int"), ("b", "float")], name="t1") t2 = ibis.table([("a", "int"), ("b", "float")], name="t2") - expr = f(t1, t2) - repr(expr) # smoketest + cnt = t1.count() + join_cnt = t1.join(t2, t1.a == t2.a).count() + union_cnt = ibis.union(t1, t2).count() + + snapshot.assert_match(fmt(cnt), "cnt_repr.txt") + snapshot.assert_match(fmt(join_cnt), "join_repr.txt") + snapshot.assert_match(fmt(union_cnt), "union_repr.txt") -def test_window_no_group_by(): + +def test_window_no_group_by(snapshot): t = ibis.table(dict(a="int64", b="string"), name="t") expr = t.a.mean().over(ibis.window(preceding=0)) - result = repr(expr) + result = fmt(expr) + assert "group_by=[]" not in result + snapshot.assert_match(result, "repr.txt") -def test_window_group_by(): +def test_window_group_by(snapshot): t = ibis.table(dict(a="int64", b="string"), name="t") expr = t.a.mean().over(ibis.window(group_by=t.b)) - result = repr(expr) + result = fmt(expr) assert "start=0" not in result assert "group_by=[r0.b]" in result + snapshot.assert_match(result, "repr.txt") -def test_fillna(): +def test_fillna(snapshot): t = ibis.table(dict(a="int64", b="string"), name="t") expr = t.fillna({"a": 3}) - repr(expr) + snapshot.assert_match(fmt(expr), "fillna_dict_repr.txt") expr = t[["a"]].fillna(3) - repr(expr) + snapshot.assert_match(fmt(expr), "fillna_int_repr.txt") expr = t[["b"]].fillna("foo") - repr(expr) + snapshot.assert_match(fmt(expr), "fillna_str_repr.txt") -def test_asof_join(): - left = ibis.table([("time1", "int32"), ("value", "double")]) - right = ibis.table([("time2", "int32"), ("value2", "double")]) +def test_asof_join(snapshot): + left = ibis.table([("time1", "int32"), ("value", "double")], name="left") + right = ibis.table([("time2", "int32"), ("value2", "double")], name="right") joined = left.asof_join(right, [("time1", "time2")]).inner_join( right, left.value == right.value2 ) - rep = repr(joined) - assert rep.count("InnerJoin") == 1 - assert rep.count("AsOfJoin") == 1 + result = fmt(joined) + assert result.count("InnerJoin") == 1 + assert result.count("AsOfJoin") == 1 -def test_two_inner_joins(): - left = ibis.table([("time1", "int32"), ("value", "double"), ("a", "string")]) - right = ibis.table([("time2", "int32"), ("value2", "double"), ("b", "string")]) + snapshot.assert_match(result, "repr.txt") + + +def test_two_inner_joins(snapshot): + left = ibis.table( + [("time1", "int32"), ("value", "double"), ("a", "string")], name="left" + ) + right = ibis.table( + [("time2", "int32"), ("value2", "double"), ("b", "string")], name="right" + ) joined = left.inner_join(right, left.a == right.b).inner_join( right, left.value == right.value2 ) - rep = repr(joined) - assert rep.count("InnerJoin") == 2 + result = fmt(joined) + assert result.count("InnerJoin") == 2 -def test_destruct_selection(): + snapshot.assert_match(result, "repr.txt") + + +def test_destruct_selection(snapshot): table = ibis.table([("col", "int64")], name="t") @udf.reduction( @@ -345,10 +358,11 @@ def multi_output_udf(v): return v.sum(), v.mean() expr = table.aggregate(multi_output_udf(table["col"]).destructure()) - result = repr(expr) + result = fmt(expr) assert "sum: StructField(ReductionVectorizedUDF" in result assert "mean: StructField(ReductionVectorizedUDF" in result + snapshot.assert_match(result, "repr.txt") @pytest.mark.parametrize( @@ -356,11 +370,59 @@ def multi_output_udf(v): [(42, None, "42"), ("42", None, "'42'"), (42, "double", "42.0")], ) def test_format_literal(literal, typ, output): - assert repr(ibis.literal(literal, type=typ)) == output + expr = ibis.literal(literal, type=typ) + assert fmt(expr) == output -def test_format_dummy_table(): +def test_format_dummy_table(snapshot): t = ops.DummyTable([ibis.array([1], type="array").name("foo")]).to_expr() - result = repr(t) + + result = fmt(t) assert "DummyTable" in result assert "foo array" in result + snapshot.assert_match(result, "repr.txt") + + +def test_format_in_memory_table(snapshot): + t = ibis.memtable([(1, 2), (3, 4), (5, 6)], columns=["x", "y"]) + expr = t.x.sum() + t.y.sum() + + result = fmt(expr) + assert "InMemoryTable" in result + snapshot.assert_match(result, "repr.txt") + + +def test_format_new_relational_operation(alltypes, snapshot): + class MyRelation(ops.Relation): + parent: ops.Relation + kind: str + + @property + def schema(self): + return self.parent.schema + + table = MyRelation(alltypes, kind="foo").to_expr() + expr = table[table, table.a.name("a2")] + result = fmt(expr) + + snapshot.assert_match(result, "repr.txt") + + +def test_format_new_value_operation(alltypes, snapshot): + class Inc(ops.Value): + arg: ops.Value + + @property + def dtype(self): + return self.arg.dtype + + @property + def shape(self): + return self.arg.shape + + expr = Inc(alltypes.a).to_expr().name("incremented") + result = fmt(expr) + last_line = result.splitlines()[-1] + + assert "Inc" in result + assert last_line == "incremented: Inc(r0.a)" diff --git a/ibis/expr/types/core.py b/ibis/expr/types/core.py index da0ee893d7cb..2fa60b3b31f2 100644 --- a/ibis/expr/types/core.py +++ b/ibis/expr/types/core.py @@ -83,9 +83,9 @@ def __hash__(self): return hash((self.__class__, self._arg)) def _repr(self) -> str: - from ibis.expr.format import fmt + from ibis.expr.format import pretty - return fmt(self) + return pretty(self) def equals(self, other): """Return whether this expression is _structurally_ equivalent to `other`. diff --git a/ibis/tests/expr/mocks.py b/ibis/tests/expr/mocks.py index 72be6380e4a6..e1010ccd3864 100644 --- a/ibis/tests/expr/mocks.py +++ b/ibis/tests/expr/mocks.py @@ -14,6 +14,8 @@ from __future__ import annotations +import contextlib + import pytest import sqlalchemy as sa @@ -23,341 +25,7 @@ from ibis.backends.base.sql.alchemy import AlchemyCompiler from ibis.backends.base.sql.alchemy.datatypes import AlchemyType from ibis.expr.schema import Schema - -MOCK_TABLES = { - "alltypes": [ - ("a", "int8"), - ("b", "int16"), - ("c", "int32"), - ("d", "int64"), - ("e", "float32"), - ("f", "float64"), - ("g", "string"), - ("h", "boolean"), - ("i", "timestamp"), - ("j", "date"), - ("k", "time"), - ], - "foo_t": [("key1", "string"), ("key2", "string"), ("value1", "double")], - "bar_t": [("key1", "string"), ("key2", "string"), ("value2", "double")], - "t1": [("key1", "string"), ("key2", "string"), ("value1", "double")], - "t2": [("key1", "string"), ("key2", "string"), ("value2", "double")], - "t3": [("key2", "string"), ("key3", "string"), ("value3", "double")], - "t4": [("key3", "string"), ("value4", "double")], - "bar": [("x", "double"), ("job", "string")], - "foo": [ - ("job", "string"), - ("dept_id", "string"), - ("year", "int32"), - ("y", "double"), - ], - "star1": [ - ("c", "int32"), - ("f", "double"), - ("foo_id", "string"), - ("bar_id", "string"), - ], - "star2": [ - ("foo_id", "string"), - ("value1", "double"), - ("value3", "double"), - ], - "star3": [("bar_id", "string"), ("value2", "double")], - "test1": [("c", "int32"), ("f", "double"), ("g", "string")], - "test2": [("key", "string"), ("value", "double")], - "tpch_region": [ - ("r_regionkey", "int16"), - ("r_name", "string"), - ("r_comment", "string"), - ], - "tpch_nation": [ - ("n_nationkey", "int16"), - ("n_name", "string"), - ("n_regionkey", "int16"), - ("n_comment", "string"), - ], - "tpch_lineitem": [ - ("l_orderkey", "int64"), - ("l_partkey", "int64"), - ("l_suppkey", "int64"), - ("l_linenumber", "int32"), - ("l_quantity", "decimal(12,2)"), - ("l_extendedprice", "decimal(12,2)"), - ("l_discount", "decimal(12,2)"), - ("l_tax", "decimal(12,2)"), - ("l_returnflag", "string"), - ("l_linestatus", "string"), - ("l_shipdate", "string"), - ("l_commitdate", "string"), - ("l_receiptdate", "string"), - ("l_shipinstruct", "string"), - ("l_shipmode", "string"), - ("l_comment", "string"), - ], - "tpch_customer": [ - ("c_custkey", "int64"), - ("c_name", "string"), - ("c_address", "string"), - ("c_nationkey", "int16"), - ("c_phone", "string"), - ("c_acctbal", "decimal"), - ("c_mktsegment", "string"), - ("c_comment", "string"), - ], - "tpch_orders": [ - ("o_orderkey", "int64"), - ("o_custkey", "int64"), - ("o_orderstatus", "string"), - ("o_totalprice", "decimal(12,2)"), - ("o_orderdate", "string"), - ("o_orderpriority", "string"), - ("o_clerk", "string"), - ("o_shippriority", "int32"), - ("o_comment", "string"), - ], - "functional_alltypes": [ - ("id", "int32"), - ("bool_col", "boolean"), - ("tinyint_col", "int8"), - ("smallint_col", "int16"), - ("int_col", "int32"), - ("bigint_col", "int64"), - ("float_col", "float32"), - ("double_col", "float64"), - ("date_string_col", "string"), - ("string_col", "string"), - ("timestamp_col", "timestamp"), - ("year", "int32"), - ("month", "int32"), - ], - "airlines": [ - ("year", "int32"), - ("month", "int32"), - ("day", "int32"), - ("dayofweek", "int32"), - ("dep_time", "int32"), - ("crs_dep_time", "int32"), - ("arr_time", "int32"), - ("crs_arr_time", "int32"), - ("carrier", "string"), - ("flight_num", "int32"), - ("tail_num", "int32"), - ("actual_elapsed_time", "int32"), - ("crs_elapsed_time", "int32"), - ("airtime", "int32"), - ("arrdelay", "int32"), - ("depdelay", "int32"), - ("origin", "string"), - ("dest", "string"), - ("distance", "int32"), - ("taxi_in", "int32"), - ("taxi_out", "int32"), - ("cancelled", "int32"), - ("cancellation_code", "string"), - ("diverted", "int32"), - ("carrier_delay", "int32"), - ("weather_delay", "int32"), - ("nas_delay", "int32"), - ("security_delay", "int32"), - ("late_aircraft_delay", "int32"), - ], - "tpcds_customer": [ - ("c_customer_sk", "int64"), - ("c_customer_id", "string"), - ("c_current_cdemo_sk", "int32"), - ("c_current_hdemo_sk", "int32"), - ("c_current_addr_sk", "int32"), - ("c_first_shipto_date_sk", "int32"), - ("c_first_sales_date_sk", "int32"), - ("c_salutation", "string"), - ("c_first_name", "string"), - ("c_last_name", "string"), - ("c_preferred_cust_flag", "string"), - ("c_birth_day", "int32"), - ("c_birth_month", "int32"), - ("c_birth_year", "int32"), - ("c_birth_country", "string"), - ("c_login", "string"), - ("c_email_address", "string"), - ("c_last_review_date", "string"), - ], - "tpcds_customer_address": [ - ("ca_address_sk", "bigint"), - ("ca_address_id", "string"), - ("ca_street_number", "string"), - ("ca_street_name", "string"), - ("ca_street_type", "string"), - ("ca_suite_number", "string"), - ("ca_city", "string"), - ("ca_county", "string"), - ("ca_state", "string"), - ("ca_zip", "string"), - ("ca_country", "string"), - ("ca_gmt_offset", "decimal(5,2)"), - ("ca_location_type", "string"), - ], - "tpcds_customer_demographics": [ - ("cd_demo_sk", "bigint"), - ("cd_gender", "string"), - ("cd_marital_status", "string"), - ("cd_education_status", "string"), - ("cd_purchase_estimate", "int"), - ("cd_credit_rating", "string"), - ("cd_dep_count", "int"), - ("cd_dep_employed_count", "int"), - ("cd_dep_college_count", "int"), - ], - "tpcds_date_dim": [ - ("d_date_sk", "bigint"), - ("d_date_id", "string"), - ("d_date", "string"), - ("d_month_seq", "int"), - ("d_week_seq", "int"), - ("d_quarter_seq", "int"), - ("d_year", "int"), - ("d_dow", "int"), - ("d_moy", "int"), - ("d_dom", "int"), - ("d_qoy", "int"), - ("d_fy_year", "int"), - ("d_fy_quarter_seq", "int"), - ("d_fy_week_seq", "int"), - ("d_day_name", "string"), - ("d_quarter_name", "string"), - ("d_holiday", "string"), - ("d_weekend", "string"), - ("d_following_holiday", "string"), - ("d_first_dom", "int"), - ("d_last_dom", "int"), - ("d_same_day_ly", "int"), - ("d_same_day_lq", "int"), - ("d_current_day", "string"), - ("d_current_week", "string"), - ("d_current_month", "string"), - ("d_current_quarter", "string"), - ("d_current_year", "string"), - ], - "tpcds_household_demographics": [ - ("hd_demo_sk", "bigint"), - ("hd_income_band_sk", "int"), - ("hd_buy_potential", "string"), - ("hd_dep_count", "int"), - ("hd_vehicle_count", "int"), - ], - "tpcds_item": [ - ("i_item_sk", "bigint"), - ("i_item_id", "string"), - ("i_rec_start_date", "string"), - ("i_rec_end_date", "string"), - ("i_item_desc", "string"), - ("i_current_price", "decimal(7,2)"), - ("i_wholesale_cost", "decimal(7,2)"), - ("i_brand_id", "int"), - ("i_brand", "string"), - ("i_class_id", "int"), - ("i_class", "string"), - ("i_category_id", "int"), - ("i_category", "string"), - ("i_manufact_id", "int"), - ("i_manufact", "string"), - ("i_size", "string"), - ("i_formulation", "string"), - ("i_color", "string"), - ("i_units", "string"), - ("i_container", "string"), - ("i_manager_id", "int"), - ("i_product_name", "string"), - ], - "tpcds_promotion": [ - ("p_promo_sk", "bigint"), - ("p_promo_id", "string"), - ("p_start_date_sk", "int"), - ("p_end_date_sk", "int"), - ("p_item_sk", "int"), - ("p_cost", "decimal(15,2)"), - ("p_response_target", "int"), - ("p_promo_name", "string"), - ("p_channel_dmail", "string"), - ("p_channel_email", "string"), - ("p_channel_catalog", "string"), - ("p_channel_tv", "string"), - ("p_channel_radio", "string"), - ("p_channel_press", "string"), - ("p_channel_event", "string"), - ("p_channel_demo", "string"), - ("p_channel_details", "string"), - ("p_purpose", "string"), - ("p_discount_active", "string"), - ], - "tpcds_store": [ - ("s_store_sk", "bigint"), - ("s_store_id", "string"), - ("s_rec_start_date", "string"), - ("s_rec_end_date", "string"), - ("s_closed_date_sk", "int"), - ("s_store_name", "string"), - ("s_number_employees", "int"), - ("s_floor_space", "int"), - ("s_hours", "string"), - ("s_manager", "string"), - ("s_market_id", "int"), - ("s_geography_class", "string"), - ("s_market_desc", "string"), - ("s_market_manager", "string"), - ("s_division_id", "int"), - ("s_division_name", "string"), - ("s_company_id", "int"), - ("s_company_name", "string"), - ("s_street_number", "string"), - ("s_street_name", "string"), - ("s_street_type", "string"), - ("s_suite_number", "string"), - ("s_city", "string"), - ("s_county", "string"), - ("s_state", "string"), - ("s_zip", "string"), - ("s_country", "string"), - ("s_gmt_offset", "decimal(5,2)"), - ("s_tax_precentage", "decimal(5,2)"), - ], - "tpcds_store_sales": [ - ("ss_sold_time_sk", "bigint"), - ("ss_item_sk", "bigint"), - ("ss_customer_sk", "bigint"), - ("ss_cdemo_sk", "bigint"), - ("ss_hdemo_sk", "bigint"), - ("ss_addr_sk", "bigint"), - ("ss_store_sk", "bigint"), - ("ss_promo_sk", "bigint"), - ("ss_ticket_number", "int"), - ("ss_quantity", "int"), - ("ss_wholesale_cost", "decimal(7,2)"), - ("ss_list_price", "decimal(7,2)"), - ("ss_sales_price", "decimal(7,2)"), - ("ss_ext_discount_amt", "decimal(7,2)"), - ("ss_ext_sales_price", "decimal(7,2)"), - ("ss_ext_wholesale_cost", "decimal(7,2)"), - ("ss_ext_list_price", "decimal(7,2)"), - ("ss_ext_tax", "decimal(7,2)"), - ("ss_coupon_amt", "decimal(7,2)"), - ("ss_net_paid", "decimal(7,2)"), - ("ss_net_paid_inc_tax", "decimal(7,2)"), - ("ss_net_profit", "decimal(7,2)"), - ("ss_sold_date_sk", "bigint"), - ], - "tpcds_time_dim": [ - ("t_time_sk", "bigint"), - ("t_time_id", "string"), - ("t_time", "int"), - ("t_hour", "int"), - ("t_minute", "int"), - ("t_second", "int"), - ("t_am_pm", "string"), - ("t_shift", "string"), - ("t_sub_shift", "string"), - ("t_meal_time", "string"), - ], -} +from ibis.expr.tests.conftest import MOCK_TABLES class MockBackend(BaseSQLBackend): @@ -368,6 +36,7 @@ class MockBackend(BaseSQLBackend): def __init__(self): super().__init__() self.executed_queries = [] + self.sql_query_schemas = {} def do_connect(self): pass @@ -432,6 +101,15 @@ def _load_into_cache(self, *_): def _clean_up_cached_table(self, _): raise NotImplementedError(self.name) + def _get_schema_using_query(self, query): + return self.sql_query_schemas[query] + + @contextlib.contextmanager + def set_query_schema(self, query, schema): + self.sql_query_schemas[query] = schema + yield + self.sql_query_schemas.pop(query, None) + def table_from_schema(name, meta, schema, *, database: str | None = None): # Convert Ibis schema to SQLA table diff --git a/ibis/tests/expr/snapshots/test_format_sql_operations/test_format_sql_query_result/repr.txt b/ibis/tests/expr/snapshots/test_format_sql_operations/test_format_sql_query_result/repr.txt new file mode 100644 index 000000000000..1cd75a4812d8 --- /dev/null +++ b/ibis/tests/expr/snapshots/test_format_sql_operations/test_format_sql_query_result/repr.txt @@ -0,0 +1,43 @@ +r0 := DatabaseTable: airlines + year int32 + month int32 + day int32 + dayofweek int32 + dep_time int32 + crs_dep_time int32 + arr_time int32 + crs_arr_time int32 + carrier string + flight_num int32 + tail_num int32 + actual_elapsed_time int32 + crs_elapsed_time int32 + airtime int32 + arrdelay int32 + depdelay int32 + origin string + dest string + distance int32 + taxi_in int32 + taxi_out int32 + cancelled int32 + cancellation_code string + diverted int32 + carrier_delay int32 + weather_delay int32 + nas_delay int32 + security_delay int32 + late_aircraft_delay int32 + +r1 := SQLStringView[r0]: foo + query: + SELECT carrier, mean(arrdelay) AS avg_arrdelay FROM airlines GROUP BY 1 ORDER … + schema: + carrier string + avg_arrdelay float64 + +Selection[r1] + selections: + carrier: r1.carrier + avg_arrdelay: Round(r1.avg_arrdelay, digits=1) + island: Lowercase(r1.carrier) \ No newline at end of file diff --git a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt new file mode 100644 index 000000000000..6266bb50b1cc --- /dev/null +++ b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt @@ -0,0 +1,22 @@ +r0 := DatabaseTable: test2 + key string + value float64 + +r1 := DatabaseTable: test1 + c int32 + f float64 + g string + +r2 := Selection[r1] + predicates: + r1.f > 0 + +r3 := InnerJoin[r0, r2] r2.g == r0.key + +Aggregation[r3] + metrics: + foo: Mean(r2.f - r0.value) + bar: Sum(r2.f) + by: + g: r2.g + key: r0.key \ No newline at end of file diff --git a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt new file mode 100644 index 000000000000..d5e678285698 --- /dev/null +++ b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt @@ -0,0 +1,51 @@ +r0 := DatabaseTable: airlines + year int32 + month int32 + day int32 + dayofweek int32 + dep_time int32 + crs_dep_time int32 + arr_time int32 + crs_arr_time int32 + carrier string + flight_num int32 + tail_num int32 + actual_elapsed_time int32 + crs_elapsed_time int32 + airtime int32 + arrdelay int32 + depdelay int32 + origin string + dest string + distance int32 + taxi_in int32 + taxi_out int32 + cancelled int32 + cancellation_code string + diverted int32 + carrier_delay int32 + weather_delay int32 + nas_delay int32 + security_delay int32 + late_aircraft_delay int32 + +r1 := Selection[r0] + selections: + arrdelay: r0.arrdelay + dest: r0.dest + +r2 := Selection[r1] + selections: + r1 + dest_avg: WindowFunction(func=Mean(r1.arrdelay), frame=RowsWindowFrame(table=r1, group_by=[r1.dest])) + dev: r1.arrdelay - WindowFunction(func=Mean(r1.arrdelay), frame=RowsWindowFrame(table=r1, group_by=[r1.dest])) + +r3 := Selection[r2] + predicates: + NotNull(r2.dev) + +r4 := Selection[r3] + sort_keys: + desc r3.dev + +Limit[r4, n=10] \ No newline at end of file diff --git a/ibis/tests/expr/test_format_sql_operations.py b/ibis/tests/expr/test_format_sql_operations.py new file mode 100644 index 000000000000..fc1deeabccac --- /dev/null +++ b/ibis/tests/expr/test_format_sql_operations.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import ibis +from ibis import _ + + +def test_format_sql_query_result(con, snapshot): + t = con.table("airlines") + + query = """ + SELECT carrier, mean(arrdelay) AS avg_arrdelay + FROM airlines + GROUP BY 1 + ORDER BY 2 DESC + """ + schema = ibis.schema({"carrier": "string", "avg_arrdelay": "double"}) + + with con.set_query_schema(query, schema): + expr = t.sql(query) + # name is autoincremented so we need to set it manually to make the + # snapshot stable + expr = expr.op().copy(name="foo").to_expr() + + expr = expr.mutate( + island=_.carrier.lower(), + avg_arrdelay=_.avg_arrdelay.round(1), + ) + + snapshot.assert_match(repr(expr), "repr.txt") + + +def test_memoize_database_table(con, snapshot): + table = con.table("test1") + table2 = con.table("test2") + + filter_pred = table["f"] > 0 + table3 = table[filter_pred] + join_pred = table3["g"] == table2["key"] + + joined = table2.inner_join(table3, [join_pred]) + + met1 = (table3["f"] - table2["value"]).mean().name("foo") + expr = joined.aggregate( + [met1, table3["f"].sum().name("bar")], by=[table3["g"], table2["key"]] + ) + + result = repr(expr) + assert result.count("test1") == 1 + assert result.count("test2") == 1 + + snapshot.assert_match(result, "repr.txt") + + +def test_memoize_insert_sort_key(con, snapshot): + table = con.table("airlines") + + t = table["arrdelay", "dest"] + expr = t.group_by("dest").mutate( + dest_avg=t.arrdelay.mean(), dev=t.arrdelay - t.arrdelay.mean() + ) + + worst = expr[expr.dev.notnull()].order_by(ibis.desc("dev")).limit(10) + + result = repr(worst) + assert result.count("airlines") == 1 + + snapshot.assert_match(result, "repr.txt")