Skip to content

Commit

Permalink
refactor(ir): remove unnecessary complexity introduced by variadic an…
Browse files Browse the repository at this point in the history
…notation
  • Loading branch information
kszucs authored and cpcloud committed Nov 28, 2022
1 parent 3cd764f commit 698314b
Show file tree
Hide file tree
Showing 22 changed files with 148 additions and 218 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _varargs_call(sa_func, t, args):

def varargs(sa_func):
def formatter(t, op):
return _varargs_call(sa_func, t, op.args)
return _varargs_call(sa_func, t, op.arg)

return formatter

Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/base/sql/registry/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def cast(translator, op):

def varargs(func_name):
def varargs_formatter(translator, op):
return helpers.format_call(translator, func_name, *op.args)
return helpers.format_call(translator, func_name, *op.arg)

return varargs_formatter

Expand Down Expand Up @@ -210,7 +210,7 @@ def hash(translator, op):


def concat(translator, op):
joined_args = ', '.join(map(translator.translate, op.args))
joined_args = ', '.join(map(translator.translate, op.arg))
return f"concat({joined_args})"


Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def _string_join(translator, op):


def _string_concat(translator, op):
args_formatted = ", ".join(map(translator.translate, op.args))
args_formatted = ", ".join(map(translator.translate, op.arg))
return f"arrayStringConcat([{args_formatted}])"


Expand Down
21 changes: 21 additions & 0 deletions ibis/backends/dask/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from ibis.backends.pandas.execution import constants
from ibis.backends.pandas.execution.generic import (
_execute_binary_op_impl,
coalesce,
compute_row_reduction,
execute_between,
execute_cast_series_array,
execute_cast_series_generic,
Expand Down Expand Up @@ -493,3 +495,22 @@ def execute_simple_case_series(op, value, whens, thens, otherwise, **kwargs):
otherwise = np.nan
raw = np.select([value == when for when in whens], thens, otherwise)
return wrap_case_result(raw, op.to_expr())


@execute_node.register(ops.Greatest, tuple)
def execute_node_greatest_list(op, values, **kwargs):
values = [execute(arg, **kwargs) for arg in values]
return compute_row_reduction(np.maximum.reduce, values, axis=0)


@execute_node.register(ops.Least, tuple)
def execute_node_least_list(op, values, **kwargs):
values = [execute(arg, **kwargs) for arg in values]
return compute_row_reduction(np.minimum.reduce, values, axis=0)


@execute_node.register(ops.Coalesce, tuple)
def execute_node_coalesce(op, values, **kwargs):
# TODO: this is slow
values = [execute(arg, **kwargs) for arg in values]
return compute_row_reduction(coalesce, values)
5 changes: 4 additions & 1 deletion ibis/backends/dask/execution/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dask.dataframe as dd
import dask.dataframe.groupby as ddgb
import numpy as np
import pandas as pd

import ibis.expr.operations as ops
from ibis.backends.dask.dispatch import execute_node
Expand Down Expand Up @@ -83,7 +84,9 @@ def vectorize_object(op, arg, *args, **kwargs):


@execute_node.register(
ops.Log, dd.Series, (dd.Series, numbers.Real, decimal.Decimal, type(None))
ops.Log,
dd.Series,
(dd.Series, pd.Series, numbers.Real, decimal.Decimal, type(None)),
)
def execute_series_log_with_base(op, data, base, **kwargs):
if data.dtype == np.dtype(np.object_):
Expand Down
7 changes: 7 additions & 0 deletions ibis/backends/dask/execution/strings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import itertools
import operator

import dask.dataframe as dd
import dask.dataframe.groupby as ddgb
Expand Down Expand Up @@ -203,6 +204,12 @@ def iterate(value, start_iter=start.items(), end_iter=end.items()):
return data.map(iterate)


@execute_node.register(ops.StringConcat, tuple)
def execute_node_string_concat(op, values, **kwargs):
values = [execute(arg, **kwargs) for arg in values]
return functools.reduce(operator.add, values)


@execute_node.register(ops.StringSQLLike, ddgb.SeriesGroupBy, str, str)
def execute_string_like_series_groupby_string(op, data, pattern, escape, **kwargs):
return execute_string_like_series_string(
Expand Down
5 changes: 3 additions & 2 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,13 +436,14 @@ def elementwise_udf(op):

@translate.register(ops.StringConcat)
def string_concat(op):
return df.functions.concat(*map(translate, op.args))
return df.functions.concat(*map(translate, op.arg))


@translate.register(ops.RegexExtract)
def regex_extract(op):
arg = translate(op.arg)
pattern = translate(ops.StringConcat("(", op.pattern, ")"))
concat = ops.StringConcat(("(", op.pattern, ")"))
pattern = translate(concat)
if (index := getattr(op.index, "value", None)) is None:
raise ValueError(
"re_extract `index` expressions must be literals. "
Expand Down
22 changes: 13 additions & 9 deletions ibis/backends/pandas/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,9 +979,10 @@ def execute_alias(op, data, **kwargs):
return data


@execute_node.register(ops.StringConcat, [object])
def execute_node_string_concat(op, *args, **kwargs):
return functools.reduce(operator.add, args)
@execute_node.register(ops.StringConcat, tuple)
def execute_node_string_concat(op, values, **kwargs):
values = [execute(arg, **kwargs) for arg in values]
return functools.reduce(operator.add, values)


@execute_node.register(ops.StringJoin, collections.abc.Sequence)
Expand Down Expand Up @@ -1215,19 +1216,22 @@ def compute_row_reduction(func, values, **kwargs):
return pd.Series(raw).squeeze()


@execute_node.register(ops.Greatest, [object])
def execute_node_greatest_list(op, *values, **kwargs):
@execute_node.register(ops.Greatest, tuple)
def execute_node_greatest_list(op, values, **kwargs):
values = [execute(arg, **kwargs) for arg in values]
return compute_row_reduction(np.maximum.reduce, values, axis=0)


@execute_node.register(ops.Least, [object])
def execute_node_least_list(op, *values, **kwargs):
@execute_node.register(ops.Least, tuple)
def execute_node_least_list(op, values, **kwargs):
values = [execute(arg, **kwargs) for arg in values]
return compute_row_reduction(np.minimum.reduce, values, axis=0)


@execute_node.register(ops.Coalesce, [object])
def execute_node_coalesce(op, *values, **kwargs):
@execute_node.register(ops.Coalesce, tuple)
def execute_node_coalesce(op, values, **kwargs):
# TODO: this is slow
values = [execute(arg, **kwargs) for arg in values]
return compute_row_reduction(coalesce, values)


Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,19 +314,19 @@ def searched_case(op):

@translate.register(ops.Coalesce)
def coalesce(op):
arg = list(map(translate, op.args))
arg = list(map(translate, op.arg))
return pl.coalesce(arg)


@translate.register(ops.Least)
def least(op):
arg = [translate(arg) for arg in op.args]
arg = [translate(arg) for arg in op.arg]
return pl.min(arg)


@translate.register(ops.Greatest)
def greatest(op):
arg = [translate(arg) for arg in op.args]
arg = [translate(arg) for arg in op.arg]
return pl.max(arg)


Expand Down Expand Up @@ -421,7 +421,7 @@ def string_endswith(op):

@translate.register(ops.StringConcat)
def string_concat(op):
args = [translate(arg) for arg in op.args]
args = [translate(arg) for arg in op.arg]
return pl.concat_str(args)


Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ def compile_arbitrary(t, op, **kwargs):
@compiles(ops.Coalesce)
def compile_coalesce(t, op, **kwargs):
kwargs["raw"] = False # override to force column literals
src_columns = [t.translate(col, **kwargs) for col in op.args]
src_columns = [t.translate(col, **kwargs) for col in op.arg]
if len(src_columns) == 1:
return src_columns[0]
else:
Expand All @@ -669,7 +669,7 @@ def compile_coalesce(t, op, **kwargs):
@compiles(ops.Greatest)
def compile_greatest(t, op, **kwargs):
kwargs["raw"] = False # override to force column literals
src_columns = [t.translate(col, **kwargs) for col in op.args]
src_columns = [t.translate(col, **kwargs) for col in op.arg]
if len(src_columns) == 1:
return src_columns[0]
else:
Expand All @@ -679,7 +679,7 @@ def compile_greatest(t, op, **kwargs):
@compiles(ops.Least)
def compile_least(t, op, **kwargs):
kwargs["raw"] = False # override to force column literals
src_columns = [t.translate(col, **kwargs) for col in op.args]
src_columns = [t.translate(col, **kwargs) for col in op.arg]
if len(src_columns) == 1:
return src_columns[0]
else:
Expand Down Expand Up @@ -1019,7 +1019,7 @@ def compile_string_split(t, op, **kwargs):
@compiles(ops.StringConcat)
def compile_string_concat(t, op, **kwargs):
kwargs["raw"] = False # override to force column literals
src_columns = [t.translate(arg, **kwargs) for arg in op.args]
src_columns = [t.translate(arg, **kwargs) for arg in op.arg]
return F.concat(*src_columns)


Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sqlite/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def _string_join(t, op):


def _string_concat(t, op):
return functools.reduce(operator.add, map(t.translate, op.args))
return functools.reduce(operator.add, map(t.translate, op.arg))


def _date_from_ymd(t, op):
Expand Down
Loading

0 comments on commit 698314b

Please sign in to comment.