Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Warn on inefficient use of map_elements for additional string functions #14565

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 51 additions & 22 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
Union,
)

from polars.utils.various import re_escape

if TYPE_CHECKING:
from dis import Instruction

Expand Down Expand Up @@ -147,12 +149,17 @@ class OpNames:
_PYTHON_BUILTINS = frozenset(_PYTHON_CASTS_MAP) | {"abs"}
_PYTHON_METHODS_MAP = {
# string
"endswith": "str.ends_with",
"lower": "str.to_lowercase",
"lstrip": "str.strip_chars_start",
"rstrip": "str.strip_chars_end",
"startswith": "str.starts_with",
"strip": "str.strip_chars",
"title": "str.to_titlecase",
"upper": "str.to_uppercase",
# temporal
"isoweekday": "dt.weekday",
"date": "dt.date",
"isoweekday": "dt.weekday",
"time": "dt.time",
}

Expand Down Expand Up @@ -576,7 +583,7 @@ def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str
# But, if e1 << e2 was valid, then e2 must have been positive.
# Hence, the output of 2**e2 can be safely cast to Int64, which
# may be necessary if chaining operations which assume Int64 output.
return f"({e1}*2**{e2}).cast(pl.Int64)"
return f"({e1} * 2**{e2}).cast(pl.Int64)"
elif op == ">>":
# Motivation for the cast is the same as in the '<<' case above.
return f"({e1} / 2**{e2}).cast(pl.Int64)"
Expand Down Expand Up @@ -685,7 +692,7 @@ def _matches(
argvals
Associated argvals that must also match (in same position as opnames).
is_attr
Indicate if the match is expected to represent attribute access.
Indicate if the match represents pure attribute access (cannot be called).
"""
n_required_ops, argvals = len(opnames), argvals or []
idx_offset = idx + n_required_ops
Expand Down Expand Up @@ -744,10 +751,10 @@ def _rewrite_attrs(self, idx: int, updated_instructions: list[Instruction]) -> i
):
inst = matching_instructions[1]
expr_name = _PYTHON_ATTRS_MAP[inst.argval]
synthetic_call = inst._replace(
px = inst._replace(
opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name
)
updated_instructions.extend([matching_instructions[0], synthetic_call])
updated_instructions.extend([matching_instructions[0], px])

return len(matching_instructions)

Expand All @@ -765,15 +772,15 @@ def _rewrite_builtins(
dtype = _PYTHON_CASTS_MAP[argval]
argval = f"cast(pl.{dtype})"

synthetic_call = inst1._replace(
px = inst1._replace(
opname="POLARS_EXPRESSION",
argval=argval,
argrepr=argval,
offset=inst2.offset,
)
# POLARS_EXPRESSION is mapped as a unary op, so switch instruction order
operand = inst2._replace(offset=inst1.offset)
updated_instructions.extend((operand, synthetic_call))
updated_instructions.extend((operand, px))

return len(matching_instructions)

Expand Down Expand Up @@ -818,22 +825,24 @@ def _rewrite_functions(
return 0
else:
expr_name = inst2.argval
synthetic_call = inst1._replace(

px = inst1._replace(
opname="POLARS_EXPRESSION",
argval=expr_name,
argrepr=expr_name,
offset=inst3.offset,
)

# POLARS_EXPRESSION is mapped as a unary op, so switch instruction order
operand = inst3._replace(offset=inst1.offset)
updated_instructions.extend(
(
operand,
matching_instructions[3 + attribute_count],
synthetic_call,
px,
)
if function_kind["argument_1_unary_opname"]
else (operand, synthetic_call)
else (operand, px)
)
return len(matching_instructions)

Expand All @@ -843,20 +852,40 @@ def _rewrite_methods(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
"""Replace python method calls with synthetic POLARS_EXPRESSION op."""
if matching_instructions := self._matches(
idx,
opnames=[
OpNames.LOAD_ATTR if _MIN_PY312 else {"LOAD_METHOD"},
OpNames.CALL,
],
argvals=[_PYTHON_METHODS_MAP],
LOAD_METHOD = OpNames.LOAD_ATTR if _MIN_PY312 else {"LOAD_METHOD"}
if matching_instructions := (
# method call with one basic arg, eg: "s.endswith('!')"
self._matches(
idx,
opnames=[LOAD_METHOD, {"LOAD_CONST"}, OpNames.CALL],
argvals=[_PYTHON_METHODS_MAP],
)
or
# method call with no arg, eg: "s.lower()"
self._matches(
idx,
opnames=[LOAD_METHOD, OpNames.CALL],
argvals=[_PYTHON_METHODS_MAP],
)
):
inst = matching_instructions[0]
expr_name = _PYTHON_METHODS_MAP[inst.argval]
synthetic_call = inst._replace(
opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name
)
updated_instructions.append(synthetic_call)
expr = _PYTHON_METHODS_MAP[inst.argval]

if matching_instructions[1].opname == "LOAD_CONST":
param_value = matching_instructions[1].argval
if isinstance(param_value, tuple) and expr in (
"str.starts_with",
"str.ends_with",
):
starts, ends = ("^", "") if "starts" in expr else ("", "$")
rx = "|".join(re_escape(v) for v in param_value)
q = '"' if "'" in param_value else "'"
expr = f"str.contains(r{q}{starts}({rx}){ends}{q})"
else:
expr += f"({param_value!r})"

px = inst._replace(opname="POLARS_EXPRESSION", argval=expr, argrepr=expr)
updated_instructions.append(px)

return len(matching_instructions)

Expand Down
8 changes: 8 additions & 0 deletions py-polars/polars/utils/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,3 +562,11 @@ def parse_percentiles(
at_or_above_50_percentiles = [0.5, *at_or_above_50_percentiles]

return [*sub_50_percentiles, *at_or_above_50_percentiles]


def re_escape(s: str) -> str:
"""Escape a string for use in a Polars (Rust) regex."""
# note: almost the same as the standard python 're.escape' function, but
# escapes _only_ those metachars with meaning to the rust regex crate
re_rust_metachars = r"\\?()|\[\]{}^$#&~.+*-"
return re.sub(f"([{re_rust_metachars}])", r"\\\1", s)
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,29 @@
'(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)',
),
# ---------------------------------------------
# string expr: case/cast ops
# string exprs
# ---------------------------------------------
("b", "lambda x: str(x).title()", 'pl.col("b").cast(pl.String).str.to_titlecase()'),
(
"b",
'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()',
'(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()',
),
(
"b",
"lambda x: x.strip().startswith('#')",
"""pl.col("b").str.strip_chars().str.starts_with('#')""",
),
(
"b",
"""lambda x: x.rstrip().endswith(('!','#','?','"'))""",
"""pl.col("b").str.strip_chars_end().str.contains(r'(!|\\#|\\?|")$')""",
),
(
"b",
"""lambda x: x.lstrip().startswith(('!','#','?',"'"))""",
"""pl.col("b").str.strip_chars_start().str.contains(r"^(!|\\#|\\?|')")""",
),
# ---------------------------------------------
# json expr: load/extract
# ---------------------------------------------
Expand Down Expand Up @@ -186,12 +201,12 @@
(
"a",
"lambda x: (3 << (32-x)) & 3",
'(3*2**(32 - pl.col("a"))).cast(pl.Int64) & 3',
'(3 * 2**(32 - pl.col("a"))).cast(pl.Int64) & 3',
),
(
"a",
"lambda x: (x << 32) & 3",
'(pl.col("a")*2**32).cast(pl.Int64) & 3',
'(pl.col("a") * 2**32).cast(pl.Int64) & 3',
),
(
"a",
Expand Down
Loading