Skip to content

Commit

Permalink
fix(polars): bump lower bound to 0.19.8 and clean up a bunch of backc…
Browse files Browse the repository at this point in the history
…ompat code
  • Loading branch information
cpcloud committed Oct 13, 2023
1 parent a0f24e8 commit 462bd17
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 55 deletions.
42 changes: 9 additions & 33 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,7 @@ def literal(op, **_):
value = pl.Series("", value)
typ = dtype_to_polars(dtype)
val = pl.lit(value, dtype=typ)
try:
return val.implode()
except AttributeError: # pragma: no cover
return val.list() # pragma: no cover
return val.implode()
elif dtype.is_struct():
values = [
pl.lit(v, dtype=dtype_to_polars(dtype[k])).alias(k)
Expand Down Expand Up @@ -317,10 +314,7 @@ def dropna(op, **kw):

if op.how == "all":
cols = pl.col(subset) if subset else pl.all()
try:
return lf.filter(~pl.all_horizontal(cols.is_null()))
except AttributeError:
return lf.filter(~pl.all(cols.is_null()))
return lf.filter(~pl.all_horizontal(cols.is_null()))

return lf.drop_nulls(subset)

Expand Down Expand Up @@ -404,19 +398,13 @@ def coalesce(op, **kw):
@translate.register(ops.Least)
def least(op, **kw):
arg = [translate(arg, **kw) for arg in op.arg]
try:
return pl.min_horizontal(arg)
except AttributeError:
return pl.min(arg)
return pl.min_horizontal(arg)


@translate.register(ops.Greatest)
def greatest(op, **kw):
arg = [translate(arg, **kw) for arg in op.arg]
try:
return pl.max_horizontal(arg)
except AttributeError:
return pl.max(arg)
return pl.max_horizontal(arg)


@translate.register(ops.InColumn)
Expand All @@ -430,10 +418,7 @@ def in_column(op, **kw):
def in_values(op, **kw):
value = translate(op.value, **kw)
options = list(map(translate, op.options))
try:
return pl.any_horizontal([value == option for option in options])
except AttributeError:
return pl.any([value == option for option in options])
return pl.any_horizontal([value == option for option in options])


_string_unary = {
Expand All @@ -449,7 +434,7 @@ def in_values(op, **kw):
def string_length(op, **kw):
arg = translate(op.arg, **kw)
typ = dtype_to_polars(op.dtype)
return arg.str.lengths().cast(typ)
return arg.str.len_bytes().cast(typ)


@translate.register(ops.StringUnary)
Expand Down Expand Up @@ -515,10 +500,7 @@ def string_join(op, **kw):
args = [translate(arg, **kw) for arg in op.arg]
_assert_literal(op.sep)
sep = op.sep.value
try:
return pl.concat_str(args, separator=sep)
except TypeError: # pragma: no cover
return pl.concat_str(args, sep=sep) # pragma: no cover
return pl.concat_str(args, separator=sep)


@translate.register(ops.Substring)
Expand Down Expand Up @@ -839,21 +821,15 @@ def timestamp_diff(op, **kw):
@translate.register(ops.ArrayLength)
def array_length(op, **kw):
arg = translate(op.arg, **kw)
try:
return arg.arr.lengths()
except AttributeError:
return arg.list.lengths()
return arg.list.len()


@translate.register(ops.ArrayConcat)
def array_concat(op, **kw):
result, *rest = map(partial(translate, **kw), op.arg)

for arg in rest:
try:
result = result.arr.concat(arg)
except AttributeError:
result = result.list.concat(arg)
result = result.list.concat(arg)

return result

Expand Down
8 changes: 0 additions & 8 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,14 +706,6 @@ def mean_and_std(v):
lambda t, where: t.count(where=where),
lambda t, where: len(t[where]),
id="count_star",
marks=[
pytest.mark.broken(
["polars"],
raises=ComputeError,
reason="polars seems broken for named ungrouped scalar reductions with no filter",
strict=False,
)
],
),
param(
lambda t, where: t.string_col.collect(where=where),
Expand Down
20 changes: 8 additions & 12 deletions ibis/backends/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,23 +1245,19 @@ def test_floating_mod(backend, alltypes, df):
],
),
param(
"float_col",
marks=[
pytest.mark.broken(
"polars",
strict=False,
reason="output type is float64 instead of the expected float32",
raises=AssertionError,
),
pytest.mark.notimpl(["druid"], raises=ZeroDivisionError),
],
"float_col", marks=pytest.mark.notimpl(["druid"], raises=ZeroDivisionError)
),
param(
"double_col", marks=pytest.mark.notimpl(["druid"], raises=ZeroDivisionError)
),
],
)
@pytest.mark.notyet(["duckdb", "mysql", "pyspark", "sqlite"], raises=AssertionError)
@pytest.mark.notyet(["mysql", "pyspark"], raises=AssertionError)
@pytest.mark.notyet(
["duckdb", "sqlite"],
raises=AssertionError,
reason="returns NULL when dividing by zero",
)
@pytest.mark.notyet(["mssql"], raises=sa.exc.OperationalError)
@pytest.mark.notyet(["postgres"], raises=sa.exc.DataError)
@pytest.mark.notyet(["snowflake"], raises=sa.exc.ProgrammingError)
Expand All @@ -1274,7 +1270,7 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator):
expected = df[column].div(denominator)
expected = backend.default_series_rename(expected).astype("float64")

backend.assert_series_equal(result, expected)
backend.assert_series_equal(result.astype("float64"), expected)


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ graphviz = { version = ">=0.16,<1", optional = true }
impyla = { version = ">=0.17,<1", optional = true }
oracledb = { version = ">=1.3.1,<2", optional = true }
packaging = { version = ">=21.3,<24", optional = true }
polars = { version = ">=0.19,<1", optional = true }
polars = { version = ">=0.19.8,<1", optional = true }
psycopg2 = { version = ">=2.8.4,<3", optional = true }
pymssql = { version = ">=2.2.5,<3", optional = true }
pydata-google-auth = { version = ">=1.4.0,<2", optional = true }
Expand Down

0 comments on commit 462bd17

Please sign in to comment.