Skip to content

Commit

Permalink
fix(clip): preserve nulls when clipping
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and gforsyth committed Sep 25, 2023
1 parent 3e1de97 commit c12dfa4
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 29 deletions.
17 changes: 14 additions & 3 deletions ibis/backends/base/sql/compiler/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,12 +368,23 @@ def _rewrite_string_contains(op):

@rewrites(ops.Clip)
def _rewrite_clip(op):
arg = ops.Cast(op.arg, op.dtype)
dtype = op.dtype
arg = ops.Cast(op.arg, dtype)

arg_is_null = ops.IsNull(arg)

if (upper := op.upper) is not None:
arg = ops.Least((arg, ops.Cast(upper, op.dtype)))
clipped_lower = ops.Least((arg, ops.Cast(upper, dtype)))
if dtype.nullable:
arg = ops.Where(arg_is_null, arg, clipped_lower)
else:
arg = clipped_lower

if (lower := op.lower) is not None:
arg = ops.Greatest((arg, ops.Cast(lower, op.dtype)))
clipped_upper = ops.Greatest((arg, ops.Cast(lower, dtype)))
if dtype.nullable:
arg = ops.Where(arg_is_null, arg, clipped_upper)
else:
arg = clipped_upper

return arg
4 changes: 2 additions & 2 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,10 +761,10 @@ def _struct_column(op, **kw):
def _clip(op, **kw):
arg = translate_val(op.arg, **kw)
if (upper := op.upper) is not None:
arg = f"least({translate_val(upper, **kw)}, {arg})"
arg = f"if(isNull({arg}), NULL, least({translate_val(upper, **kw)}, {arg}))"

if (lower := op.lower) is not None:
arg = f"greatest({translate_val(lower, **kw)}, {arg})"
arg = f"if(isNull({arg}), NULL, greatest({translate_val(lower, **kw)}, {arg}))"

return arg

Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/flink/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,11 @@ def _clip(translator: ExprTranslator, op: ops.Node) -> str:

if op.upper is not None:
upper = translator.translate(op.upper)
arg = f"IF({arg} > {upper}, {upper}, {arg})"
arg = f"IF({arg} > {upper} AND {arg} IS NOT NULL, {upper}, {arg})"

if op.lower is not None:
lower = translator.translate(op.lower)
arg = f"IF({arg} < {lower}, {lower}, {arg})"
arg = f"IF({arg} < {lower} AND {arg} IS NOT NULL, {lower}, {arg})"

return f"CAST({arg} AS {_to_pyflink_types[type(op.dtype)]!s})"

Expand Down
24 changes: 15 additions & 9 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,16 +629,22 @@ def degrees(op, **kw):
def clip(op, **kw):
arg = translate(op.arg, **kw)

if op.lower is not None and op.upper is not None:
_assert_literal(op.lower)
def clipper(arg, expr):
return pl.when(arg.is_null()).then(arg).otherwise(expr)

lower = op.lower
upper = op.upper

if lower is not None and upper is not None:
_assert_literal(lower)
_assert_literal(upper)
return clipper(arg, arg.clip(lower.value, upper.value))
elif lower is not None:
_assert_literal(lower)
return clipper(arg, arg.clip_min(lower.value))
elif upper is not None:
_assert_literal(op.upper)
return arg.clip(op.lower.value, op.upper.value)
elif op.lower is not None:
_assert_literal(op.lower)
return arg.clip_min(op.lower.value)
elif op.upper is not None:
_assert_literal(op.upper)
return arg.clip_max(op.upper.value)
return clipper(arg, arg.clip_max(upper.value))
else:
raise com.TranslationError("No lower or upper bound specified")

Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,11 +775,11 @@ def compile_clip(t, op, **kwargs):

def column_min(value, limit):
"""Return values greater than or equal to `limit`."""
return F.when(value < limit, limit).otherwise(value)
return F.when((value < limit) & ~F.isnull(value), limit).otherwise(value)

def column_max(value, limit):
"""Return values less than or equal to `limit`."""
return F.when(value > limit, limit).otherwise(value)
return F.when((value > limit) & ~F.isnull(value), limit).otherwise(value)

def clip(column, lower_value, upper_value):
return column_max(column_min(column, F.lit(lower_value)), F.lit(upper_value))
Expand Down
27 changes: 21 additions & 6 deletions ibis/backends/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,25 +1364,40 @@ def test_random(con):
@pytest.mark.parametrize(
("ibis_func", "pandas_func"),
[
(lambda x: x.clip(lower=0), lambda x: x.clip(lower=0)),
(lambda x: x.clip(lower=0.0), lambda x: x.clip(lower=0.0)),
(lambda x: x.clip(upper=0), lambda x: x.clip(upper=0)),
pytest.param(
param(lambda x: x.clip(lower=0), lambda x: x.clip(lower=0), id="lower-int"),
param(
lambda x: x.clip(lower=0.0), lambda x: x.clip(lower=0.0), id="lower-float"
),
param(lambda x: x.clip(upper=0), lambda x: x.clip(upper=0), id="upper-int"),
param(
lambda x: x.clip(lower=x - 1, upper=x + 1),
lambda x: x.clip(lower=x - 1, upper=x + 1),
marks=pytest.mark.notimpl(
"polars",
raises=com.UnsupportedArgumentError,
reason="Polars does not support columnar argument Subtract(int_col, 1)",
),
id="lower-upper-expr",
),
(
param(
lambda x: x.clip(lower=0, upper=1),
lambda x: x.clip(lower=0, upper=1),
id="lower-upper-int",
),
(
param(
lambda x: x.clip(lower=0, upper=1.0),
lambda x: x.clip(lower=0, upper=1.0),
id="lower-upper-float",
),
param(
lambda x: x.nullif(1).clip(lower=0),
lambda x: x.where(x != 1).clip(lower=0),
id="null-lower",
),
param(
lambda x: x.nullif(1).clip(upper=0),
lambda x: x.where(x != 1).clip(upper=0),
id="null-upper",
),
],
)
Expand Down
15 changes: 10 additions & 5 deletions ibis/expr/types/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def clip(
) -> NumericValue:
"""Trim values outside of `lower` and `upper` bounds.
`NULL` values are preserved and are not replaced with bounds.
Parameters
----------
lower
Expand All @@ -198,20 +200,23 @@ def clip(
--------
>>> import ibis
>>> ibis.options.interactive = True
>>> t = ibis.memtable({"values": range(8)})
>>> t = ibis.memtable(
... {"values": [None, 2, 3, None, 5, None, None, 8]},
... schema=dict(values="int"),
... )
>>> t.values.clip(lower=3, upper=6)
┏━━━━━━━━━━━━━━━━━━━━┓
┃ Clip(values, 3, 6) ┃
┡━━━━━━━━━━━━━━━━━━━━┩
│ int64 │
├────────────────────┤
│ NULL │
│ 3 │
│ 3 │
│ 3 │
│ 3 │
│ 4 │
│ NULL │
│ 5 │
│ 6 │
│ NULL │
│ NULL │
│ 6 │
└────────────────────┘
"""
Expand Down

1 comment on commit c12dfa4

@ibis-squawk-bot
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 3.

Benchmark suite Current: c12dfa4 Previous: b37804a Ratio
ibis/tests/benchmarks/test_benchmarks.py::test_compile[medium-bigquery] 38.59023856264532 iter/sec (stddev: 0.02522526440080695) 603.5295570002766 iter/sec (stddev: 0.000035398257191080955) 15.64
ibis/tests/benchmarks/test_benchmarks.py::test_compile[small-bigquery] 1058.5437059295434 iter/sec (stddev: 0.00020878136335229994) 10602.847663506958 iter/sec (stddev: 0.00005834248372791223) 10.02
ibis/tests/benchmarks/test_benchmarks.py::test_compile[large-bigquery] 16.982299020809773 iter/sec (stddev: 0.04448098340867773) 146.4406797920701 iter/sec (stddev: 0.00006932716190392247) 8.62
ibis/tests/benchmarks/test_benchmarks.py::test_compile_with_drops[bigquery] 26.344669970659027 iter/sec (stddev: 0.003728085605853006) 130.02477424437095 iter/sec (stddev: 0.000247510166318326) 4.94
ibis/tests/benchmarks/test_benchmarks.py::test_compile[medium-druid] 3.733620931361614 iter/sec (stddev: 1.7571268460453324) 184.37177656676133 iter/sec (stddev: 0.00046351400859641047) 49.38

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.