From 16803e10cacfac3b2d2667731d50b040f92ad90a Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 16 May 2022 08:57:28 -0400 Subject: [PATCH] feat(datafusion): implement trig functions --- ibis/backends/datafusion/compiler.py | 133 +++++++++++++++------------ 1 file changed, 76 insertions(+), 57 deletions(-) diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index df2dd60b8eac..a7f06d20ccba 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -6,7 +6,6 @@ import pyarrow as pa import ibis.common.exceptions as com -import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.types as ir from ibis.backends.datafusion.datatypes import to_pyarrow_type @@ -23,24 +22,24 @@ def expression(expr): @translate.register(ops.Node) -def operation(op, expr): +def operation(op, _): raise com.OperationNotDefinedError(f'No translation rule for {type(op)}') @translate.register(ops.DatabaseTable) -def table(op, expr): +def table(op, _): name, _, client = op.args return client._context.table(name) @translate.register(ops.Alias) -def alias(op, expr): +def alias(op, _): arg = translate(op.arg) return arg.alias(op.name) @translate.register(ops.Literal) -def literal(op, expr): +def literal(op, _): if isinstance(op.value, (set, frozenset)): value = list(op.value) else: @@ -53,14 +52,14 @@ def literal(op, expr): @translate.register(ops.Cast) -def cast(op, expr): +def cast(op, _): arg = translate(op.arg) typ = to_pyarrow_type(op.to) return arg.cast(to=typ) @translate.register(ops.TableColumn) -def column(op, expr): +def column(op, _): table_op = op.table.op() if hasattr(table_op, "name"): @@ -70,7 +69,7 @@ def column(op, expr): @translate.register(ops.SortKey) -def sort_key(op, expr): +def sort_key(op, _): arg = translate(op.expr) return arg.sort(ascending=op.ascending) @@ -128,31 +127,31 @@ def aggregation(op, expr): @translate.register(ops.Not) -def invert(op, expr): +def invert(op, _): arg = translate(op.arg) return ~arg @translate.register(ops.Abs) -def abs(op, expr): +def abs(op, _): arg = translate(op.arg) return df.functions.abs(arg) @translate.register(ops.Ceil) -def ceil(op, expr): +def ceil(op, _): arg = translate(op.arg) return df.functions.ceil(arg).cast(pa.int64()) @translate.register(ops.Floor) -def floor(op, expr): +def floor(op, _): arg = translate(op.arg) return df.functions.floor(arg).cast(pa.int64()) @translate.register(ops.Round) -def round(op, expr): +def round(op, _): arg = translate(op.arg) if op.digits is not None: raise com.UnsupportedOperationError( @@ -162,79 +161,79 @@ def round(op, expr): @translate.register(ops.Ln) -def ln(op, expr): +def ln(op, _): arg = translate(op.arg) return df.functions.ln(arg) @translate.register(ops.Log2) -def log2(op, expr): +def log2(op, _): arg = translate(op.arg) return df.functions.log2(arg) @translate.register(ops.Log10) -def log10(op, expr): +def log10(op, _): arg = translate(op.arg) return df.functions.log10(arg) @translate.register(ops.Sqrt) -def sqrt(op, expr): +def sqrt(op, _): arg = translate(op.arg) return df.functions.sqrt(arg) @translate.register(ops.Strip) -def strip(op, expr): +def strip(op, _): arg = translate(op.arg) return df.functions.trim(arg) @translate.register(ops.LStrip) -def lstrip(op, expr): +def lstrip(op, _): arg = translate(op.arg) return df.functions.ltrim(arg) @translate.register(ops.RStrip) -def rstrip(op, expr): +def rstrip(op, _): arg = translate(op.arg) return df.functions.rtrim(arg) @translate.register(ops.Lowercase) -def lower(op, expr): +def lower(op, _): arg = translate(op.arg) return df.functions.lower(arg) @translate.register(ops.Uppercase) -def upper(op, expr): +def upper(op, _): arg = translate(op.arg) return df.functions.upper(arg) @translate.register(ops.Reverse) -def reverse(op, expr): +def reverse(op, _): arg = translate(op.arg) return df.functions.reverse(arg) @translate.register(ops.StringLength) -def strlen(op, expr): +def strlen(op, _): arg = translate(op.arg) return df.functions.character_length(arg) @translate.register(ops.Capitalize) -def capitalize(op, expr): +def capitalize(op, _): arg = translate(op.arg) return df.functions.initcap(arg) @translate.register(ops.Substring) -def substring(op, expr): +def substring(op, _): arg = translate(op.arg) start = translate(op.start + 1) length = translate(op.length) @@ -242,21 +241,21 @@ def substring(op, expr): @translate.register(ops.RegexExtract) -def regex_extract(op, expr): +def regex_extract(op, _): arg = translate(op.arg) pattern = translate(op.pattern) return df.functions.regexp_match(arg, pattern) @translate.register(ops.Repeat) -def repeat(op, expr): +def repeat(op, _): arg = translate(op.arg) times = translate(op.times) return df.functions.repeat(arg, times) @translate.register(ops.LPad) -def lpad(op, expr): +def lpad(op, _): arg = translate(op.arg) length = translate(op.length) pad = translate(op.pad) @@ -264,7 +263,7 @@ def lpad(op, expr): @translate.register(ops.RPad) -def rpad(op, expr): +def rpad(op, _): arg = translate(op.arg) length = translate(op.length) pad = translate(op.pad) @@ -272,67 +271,67 @@ def rpad(op, expr): @translate.register(ops.GreaterEqual) -def ge(op, expr): +def ge(op, _): return translate(op.left) >= translate(op.right) @translate.register(ops.LessEqual) -def le(op, expr): +def le(op, _): return translate(op.left) <= translate(op.right) @translate.register(ops.Greater) -def gt(op, expr): +def gt(op, _): return translate(op.left) > translate(op.right) @translate.register(ops.Less) -def lt(op, expr): +def lt(op, _): return translate(op.left) < translate(op.right) @translate.register(ops.Equals) -def eq(op, expr): +def eq(op, _): return translate(op.left) == translate(op.right) @translate.register(ops.NotEquals) -def ne(op, expr): +def ne(op, _): return translate(op.left) != translate(op.right) @translate.register(ops.Add) -def add(op, expr): +def add(op, _): return translate(op.left) + translate(op.right) @translate.register(ops.Subtract) -def sub(op, expr): +def sub(op, _): return translate(op.left) - translate(op.right) @translate.register(ops.Multiply) -def mul(op, expr): +def mul(op, _): return translate(op.left) * translate(op.right) @translate.register(ops.Divide) -def div(op, expr): +def div(op, _): return translate(op.left) / translate(op.right) @translate.register(ops.FloorDivide) -def floordiv(op, expr): +def floordiv(op, _): return df.functions.floor(translate(op.left) / translate(op.right)) @translate.register(ops.Modulus) -def mod(op, expr): +def mod(op, _): return translate(op.left) % translate(op.right) @translate.register(ops.Count) -def count(op, expr): +def count(op, _): op_arg = op.arg if isinstance(op_arg, ir.Table): arg = df.literal(1) @@ -342,25 +341,25 @@ def count(op, expr): @translate.register(ops.Sum) -def sum(op, expr): +def sum(op, _): arg = translate(op.arg) return df.functions.sum(arg) @translate.register(ops.Min) -def min(op, expr): +def min(op, _): arg = translate(op.arg) return df.functions.min(arg) @translate.register(ops.Max) -def max(op, expr): +def max(op, _): arg = translate(op.arg) return df.functions.max(arg) @translate.register(ops.Mean) -def mean(op, expr): +def mean(op, _): arg = translate(op.arg) return df.functions.avg(arg) @@ -375,35 +374,55 @@ def _prepare_contains_options(options): @translate.register(ops.ValueList) -def value_list(op, expr): +def value_list(op, _): return list(map(translate, op.values)) @translate.register(ops.Contains) -def contains(op, expr): +def contains(op, _): value = translate(op.value) options = _prepare_contains_options(op.options) return df.functions.in_list(value, options, negated=False) @translate.register(ops.NotContains) -def not_contains(op, expr): +def not_contains(op, _): value = translate(op.value) options = _prepare_contains_options(op.options) return df.functions.in_list(value, options, negated=True) @translate.register(ops.Negate) -def negate(op, expr): - op_arg = op.arg - arg = translate(op_arg) - if op_arg.type() == dt.boolean: - return ~arg - return df.lit(-1) * arg +def negate(op, _): + return df.lit(-1) * translate(op.arg) + + +@translate.register(ops.Acos) +@translate.register(ops.Asin) +@translate.register(ops.Atan) +@translate.register(ops.Cos) +@translate.register(ops.Sin) +@translate.register(ops.Tan) +def trig(op, _): + func_name = op.__class__.__name__.lower() + func = getattr(df.functions, func_name) + return func(translate(op.arg)) + + +@translate.register(ops.Atan2) +def atan2(op, _): + y, x = map(translate, op.args) + return df.functions.atan(y / x) + + +@translate.register(ops.Cot) +def cot(op, _): + x = translate(op.arg) + return df.functions.cos(x) / df.functions.sin(x) @translate.register(ops.ElementWiseVectorizedUDF) -def elementwise_udf(op, expr): +def elementwise_udf(op, _): udf = df.udf( op.func, input_types=list(map(to_pyarrow_type, op.input_type)),