Skip to content

Commit

Permalink
fix(polars): various polars enhancements
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored and cpcloud committed Nov 9, 2023
1 parent 2bffa5a commit 5948dd6
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 78 deletions.
151 changes: 79 additions & 72 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

import calendar
import functools
import math
import operator
from collections.abc import Mapping
from functools import partial
from functools import partial, reduce, singledispatch
from math import isnan

import numpy as np
import pandas as pd
Expand All @@ -19,16 +19,30 @@
from ibis.expr.operations.udf import InputType


def _assert_literal(op):
def _expr_method(expr, op, methods):
for m in methods:
if hasattr(expr, m):
return getattr(expr, m)
raise com.TranslationError(
f"Failed to translate {op}; expected expression method(s) not found:\n{methods!r}"
)


def _literal_value(op, nan_as_none=False):
# TODO(kszucs): broadcast and apply UDF on two columns using concat_list
# TODO(kszucs): better error message
if not isinstance(op, ops.Literal):
if op is None:
return None
elif not isinstance(op, ops.Literal):
raise com.UnsupportedArgumentError(
f"Polars does not support columnar argument {op.name}"
)
else:
value = op.value
return None if nan_as_none and isnan(value) else value


@singledispatch
def translate(expr, *, ctx):
    """Fallback for the single-dispatch translator.

    Implementations for specific types are attached with
    ``translate.register``; this base case is reached only when no registered
    implementation matches the type of ``expr``.

    Raises
    ------
    NotImplementedError
        Always, carrying ``expr`` as its argument.
    """
    # No registered handler matched: report the offending object itself.
    raise NotImplementedError(expr)

Expand Down Expand Up @@ -188,7 +202,7 @@ def selection(op, **kw):

if op.predicates:
predicates = map(partial(translate, **kw), op.predicates)
predicate = functools.reduce(operator.and_, predicates)
predicate = reduce(operator.and_, predicates)
lf = lf.filter(predicate)

selections = []
Expand Down Expand Up @@ -247,7 +261,7 @@ def aggregation(op, **kw):

if op.predicates:
lf = lf.filter(
functools.reduce(
reduce(
operator.and_,
map(partial(translate, **kw), op.predicates),
)
Expand Down Expand Up @@ -329,8 +343,7 @@ def fillna(op, **kw):
if isinstance(op.replacements, Mapping):
value = op.replacements.get(name)
else:
_assert_literal(op.replacements)
value = op.replacements.value
value = _literal_value(op.replacements)

if value is not None:
if dtype.is_floating():
Expand Down Expand Up @@ -427,6 +440,7 @@ def in_values(op, **kw):
ops.RStrip: "strip_chars_end",
ops.Lowercase: "to_lowercase",
ops.Uppercase: "to_uppercase",
ops.Capitalize: "to_titlecase",
}


Expand All @@ -448,12 +462,6 @@ def string_unary(op, **kw):
return method()


@translate.register(ops.Capitalize)
def captalize(op, **kw):
arg = translate(op.arg, **kw)
return arg.map_elements(lambda x: x.capitalize())


@translate.register(ops.Reverse)
def reverse(op, **kw):
arg = translate(op.arg, **kw)
Expand All @@ -463,8 +471,8 @@ def reverse(op, **kw):
@translate.register(ops.StringSplit)
def string_split(op, **kw):
    """Translate ``ops.StringSplit`` by splitting on a literal delimiter."""
    column = translate(op.arg, **kw)
    # The delimiter must be a literal; _literal_value rejects column operands.
    separator = _literal_value(op.delimiter)
    return column.str.split(by=separator)


@translate.register(ops.StringReplace)
Expand All @@ -478,15 +486,15 @@ def string_replace(op, **kw):
@translate.register(ops.StartsWith)
def string_startswith(op, **kw):
    """Translate ``ops.StartsWith`` using a literal prefix."""
    column = translate(op.arg, **kw)
    prefix = _literal_value(op.start)
    return column.str.starts_with(prefix)


@translate.register(ops.EndsWith)
def string_endswith(op, **kw):
    """Translate ``ops.EndsWith`` using a literal suffix."""
    column = translate(op.arg, **kw)
    suffix = _literal_value(op.end)
    return column.str.ends_with(suffix)


@translate.register(ops.StringConcat)
Expand All @@ -498,82 +506,84 @@ def string_concat(op, **kw):
@translate.register(ops.StringJoin)
def string_join(op, **kw):
    """Translate ``ops.StringJoin``: concatenate the operands with a literal separator."""
    # Translate every operand first, then resolve the separator literal.
    pieces = list(map(partial(translate, **kw), op.arg))
    separator = _literal_value(op.sep)
    return pl.concat_str(pieces, separator=separator)


@translate.register(ops.Substring)
def string_substring(op, **kw):
    """Translate ``ops.Substring`` into a string slice.

    Both ``start`` and ``length`` must be literals (or absent, in which case
    ``None`` is forwarded).
    """
    column = translate(op.arg, **kw)
    start = _literal_value(op.start)
    span = _literal_value(op.length)
    return column.str.slice(offset=start, length=span)


@translate.register(ops.StringContains)
def string_contains(op, **kw):
    """Translate ``ops.StringContains`` as a plain substring test."""
    column = translate(op.haystack, **kw)
    needle = _literal_value(op.needle)
    # literal=True: the needle is passed as-is rather than as a regex pattern.
    return column.str.contains(pattern=needle, literal=True)


@translate.register(ops.RegexSearch)
def regex_search(op, **kw):
    """Translate ``ops.RegexSearch`` as a regular-expression containment test."""
    column = translate(op.arg, **kw)
    regex = _literal_value(op.pattern)
    # literal=False: the pattern is handed over as a regular expression.
    return column.str.contains(pattern=regex, literal=False)


@translate.register(ops.RegexExtract)
def regex_extract(op, **kw):
    """Translate ``ops.RegexExtract``: pull one capture group out of a match.

    Both the pattern and the group index must be literals.
    """
    column = translate(op.arg, **kw)
    regex = _literal_value(op.pattern)
    group = _literal_value(op.index)
    return column.str.extract(pattern=regex, group_index=group)


@translate.register(ops.RegexReplace)
def regex_replace(op, **kw):
    """Translate ``ops.RegexReplace`` into ``str.replace_all``.

    Unlike most string translations here, the pattern and replacement are
    translated as expressions rather than unwrapped as literals.
    """
    column = translate(op.arg, **kw)
    regex = translate(op.pattern, **kw)
    substitute = translate(op.replacement, **kw)
    return column.str.replace_all(pattern=regex, value=substitute)


@translate.register(ops.LPad)
def lpad(op, **kw):
    """Translate ``ops.LPad`` (pad on the left up to a literal width).

    The padding method is looked up by name: ``pad_start`` is tried first,
    then ``rjust``, so whichever the string namespace provides is used.
    """
    string_ns = translate(op.arg, **kw).str
    pad_left = _expr_method(string_ns, "lpad", ["pad_start", "rjust"])
    width = _literal_value(op.length)
    fill = _literal_value(op.pad)
    return pad_left(width, fill)


@translate.register(ops.RPad)
def rpad(op, **kw):
    """Translate ``ops.RPad`` (pad on the right up to a literal width).

    The padding method is looked up by name: ``pad_end`` is tried first,
    then ``ljust``, so whichever the string namespace provides is used.
    """
    string_ns = translate(op.arg, **kw).str
    pad_right = _expr_method(string_ns, "rpad", ["pad_end", "ljust"])
    width = _literal_value(op.length)
    fill = _literal_value(op.pad)
    return pad_right(width, fill)


@translate.register(ops.StrRight)
def str_right(op, **kw):
    """Translate ``ops.StrRight``: keep the last ``nchars`` characters.

    ``nchars`` must be a literal. A count of zero yields an empty string.
    """
    arg = translate(op.arg, **kw)
    nchars = _literal_value(op.nchars)
    if nchars == 0:
        # -0 == 0, so slice(-nchars, None) would start at offset 0 and return
        # the entire string; take a zero-length slice instead.
        return arg.str.slice(0, 0)
    return arg.str.slice(-nchars, None)


@translate.register(ops.Round)
def round(op, **kw):
    """Translate ``ops.Round`` and cast the result to the operation's dtype.

    A missing ``digits`` operand (or a literal zero) rounds to zero decimals.
    """
    column = translate(op.arg, **kw)
    target_type = dtype_to_polars(op.dtype)
    places = _literal_value(op.digits)
    if not places:
        # Covers both an absent operand (None) and an explicit 0.
        places = 0
    return column.round(places).cast(target_type)


@translate.register(ops.Radians)
Expand All @@ -592,38 +602,33 @@ def degrees(op, **kw):
def clip(op, **kw):
    """Translate a clip operation using literal lower and/or upper bounds.

    From polars 0.19.12 on, ``clip`` is called with both bounds even when one
    of them is ``None``; on older versions a one-sided clip goes through
    ``clip_min`` / ``clip_max``.

    Raises
    ------
    com.TranslationError
        If neither bound is given.
    """
    arg = translate(op.arg, **kw)

    lo = _literal_value(op.lower)
    hi = _literal_value(op.upper)
    has_lo = lo is not None
    has_hi = hi is not None

    if vparse(pl.__version__) >= vparse("0.19.12"):
        # Newer polars: a single call handles one- and two-sided clipping.
        if has_lo or has_hi:
            return arg.clip(lo, hi)
    else:
        if has_lo and has_hi:
            return arg.clip(lo, hi)
        if has_lo:
            return arg.clip_min(lo)
        if has_hi:
            return arg.clip_max(hi)

    raise com.TranslationError("No lower or upper bound specified")


@translate.register(ops.Log)
def log(op, **kw):
    """Translate ``ops.Log`` with an optional literal base.

    When the operation carries no base, ``_literal_value`` yields ``None``;
    in that case the logarithm is taken with the backend's default base
    instead of forwarding ``base=None``.
    """
    arg = translate(op.arg, **kw)
    base = _literal_value(op.base)
    if base is None:
        # NOTE(review): polars documents the default base of Expr.log as e
        # (natural logarithm) -- confirm this matches the intended semantics.
        return arg.log()
    return arg.log(base=base)


@translate.register(ops.Repeat)
def repeat(op, **kw):
    """Translate ``ops.Repeat``: concatenate the string with itself ``times`` times.

    ``times`` must be a literal. A non-positive count yields an empty string
    (mirroring Python's ``"x" * 0 == ""``).
    """
    arg = translate(op.arg, **kw)
    n_times = _literal_value(op.times)
    if n_times <= 0:
        # [arg] * n_times would be an empty list, leaving concat_str with
        # nothing to concatenate; a zero-length slice gives the empty string.
        return arg.str.slice(0, 0)
    return pl.concat_str([arg] * n_times, separator="")


@translate.register(ops.Sign)
Expand Down Expand Up @@ -732,8 +737,8 @@ def timestamp_now(op, **_):
@translate.register(ops.Strftime)
def strftime(op, **kw):
    """Translate ``ops.Strftime`` using a literal format string."""
    column = translate(op.arg, **kw)
    pattern = _literal_value(op.format_str)
    return column.dt.strftime(format=pattern)


@translate.register(ops.Date)
Expand Down Expand Up @@ -833,8 +838,10 @@ def interval_from_integer(op, **kw):
@translate.register(ops.StringToTimestamp)
def string_to_timestamp(op, **kw):
    """Translate ``ops.StringToTimestamp``: parse strings into ``pl.Datetime``.

    The format string must be a literal.
    """
    column = translate(op.arg, **kw)
    pattern = _literal_value(op.format_str)
    return column.str.strptime(dtype=pl.Datetime, format=pattern)


@translate.register(ops.TimestampDiff)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/polars/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def dtype_to_polars(dtype):

@dtype_to_polars.register(dt.Decimal)
def from_ibis_decimal(dtype):
    """Build a ``pl.Decimal`` carrying the precision and scale of ``dtype``."""
    # Keyword arguments keep precision and scale unambiguous at the call site.
    params = {"precision": dtype.precision, "scale": dtype.scale}
    return pl.Decimal(**params)


@dtype_to_polars.register(dt.Timestamp)
Expand Down
5 changes: 0 additions & 5 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,11 +668,6 @@ def uses_java_re(t):
["pyspark"],
raises=com.OperationNotDefinedError,
),
pytest.mark.broken(
["polars"],
raises=AttributeError,
reason="'NoneType' object has no attribute 'name'",
),
pytest.mark.broken(
["mssql"],
reason="substr requires 3 arguments",
Expand Down

0 comments on commit 5948dd6

Please sign in to comment.