Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Warn on inefficient use of map_elements for temporal attributes/methods #14529

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 69 additions & 26 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,36 +132,43 @@ class OpNames:
)
)

# python functions that we can map to native expressions
# python attrs/funcs that map to native expressions
# Python attribute name -> native expression suffix. Unlike the entries in
# `_PYTHON_METHODS_MAP`, these values already include the call parentheses.
# NOTE(review): "date" is listed both here and in `_PYTHON_METHODS_MAP`; on
# `datetime` objects it is a method, not an attribute -- confirm that the
# attribute entry is intended.
_PYTHON_ATTRS_MAP = {
"date": "dt.date()",
"day": "dt.day()",
"hour": "dt.hour()",
"microsecond": "dt.microsecond()",
"minute": "dt.minute()",
"month": "dt.month()",
"second": "dt.second()",
"year": "dt.year()",
}
# Python cast builtin name -> dtype name.
_PYTHON_CASTS_MAP = {"float": "Float64", "int": "Int64", "str": "String"}
# The cast builtins above plus `abs`.
_PYTHON_BUILTINS = frozenset(_PYTHON_CASTS_MAP) | {"abs"}
# Python method name -> native expression method path (no call parentheses).
_PYTHON_METHODS_MAP = {
# string
"lower": "str.to_lowercase",
"title": "str.to_titlecase",
"upper": "str.to_uppercase",
# temporal
"isoweekday": "dt.weekday",
"date": "dt.date",
"time": "dt.time",
}

_FUNCTION_KINDS: list[dict[str, list[AbstractSet[str]]]] = [
# lambda x: module.func(CONSTANT)
{
"argument_1_opname": [{"LOAD_CONST"}],
"argument_2_opname": [],
"module_opname": [OpNames.LOAD_ATTR],
"attribute_opname": [],
"module_name": [_NUMPY_MODULE_ALIASES],
"attribute_name": [],
"function_name": [_NUMPY_FUNCTIONS],
},
# lambda x: module.func(x)
_MODULE_FUNCTIONS: list[dict[str, list[AbstractSet[str]]]] = [
# lambda x: numpy.func(x)
# lambda x: numpy.func(CONSTANT)
{
"argument_1_opname": [{"LOAD_FAST"}],
"argument_1_opname": [{"LOAD_FAST", "LOAD_CONST"}],
"argument_2_opname": [],
"module_opname": [OpNames.LOAD_ATTR],
"attribute_opname": [],
"module_name": [_NUMPY_MODULE_ALIASES],
"attribute_name": [],
"function_name": [_NUMPY_FUNCTIONS],
},
# lambda x: json.loads(x)
{
"argument_1_opname": [{"LOAD_FAST"}],
"argument_2_opname": [],
Expand All @@ -171,7 +178,7 @@ class OpNames:
"attribute_name": [],
"function_name": [{"loads"}],
},
# lambda x: module.func(x, CONSTANT)
# lambda x: datetime.strptime(x, CONSTANT)
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
{
"argument_1_opname": [{"LOAD_FAST"}],
"argument_2_opname": [{"LOAD_CONST"}],
Expand All @@ -194,13 +201,12 @@ class OpNames:
]
# In addition to `lambda x: func(x)`, also support cases when a unary operation
# has been applied to `x`, like `lambda x: func(-x)` or `lambda x: func(~x)`.
_FUNCTION_KINDS = [
# Dict entry 1 has incompatible type "str": "object";
# expected "str": "list[AbstractSet[str]]"
_MODULE_FUNCTIONS = [
{**kind, "argument_1_unary_opname": unary} # type: ignore[dict-item]
for kind in _FUNCTION_KINDS
for kind in _MODULE_FUNCTIONS
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
for unary in [[set(OpNames.UNARY)], []]
]
# Matches text of the form `pl.col("NAME") & pl.col("NAME").<rest>`. Group 1 is
# the column name; the `\1` back-reference requires both sides to name the same
# column. Group 2 captures everything after the second column's dot.
_RE_IMPLICIT_BOOL = re.compile(r'pl\.col\("([^"]*)"\) & pl\.col\("\1"\)\.(.+)')


def _get_all_caller_variables() -> dict[str, Any]:
Expand Down Expand Up @@ -252,6 +258,12 @@ def __init__(self, function: Callable[[Any], Any], map_target: MapTarget):
instructions=original_instructions,
)

def _omit_implicit_bool(self, expr: str) -> str:
    """
    Strip redundant self-guards from a suggested expression string.

    A guard such as `pl.col("d") & pl.col("d").dt.date()` is collapsed to
    `pl.col("d").dt.date()`. Substitution is repeated until the pattern no
    longer matches, so stacked guards are removed as well.
    """
    n_replaced = 1
    while n_replaced:
        # `subn` reports how many replacements were made; zero means done.
        expr, n_replaced = _RE_IMPLICIT_BOOL.subn(r'pl.col("\1").\2', expr)
    return expr

@staticmethod
def _get_param_name(function: Callable[[Any], Any]) -> str | None:
"""Return single function parameter name."""
Expand Down Expand Up @@ -415,11 +427,13 @@ def to_expression(self, col: str) -> str | None:
# constant value (e.g. `lambda x: CONST + 123`), so we don't want to warn
if "pl.col(" not in polars_expr:
return None
elif self._map_target == "series":
target_name = self._get_target_name(col, polars_expr)
return polars_expr.replace(f'pl.col("{col}")', target_name)
else:
return polars_expr
polars_expr = self._omit_implicit_bool(polars_expr)
if self._map_target == "series":
target_name = self._get_target_name(col, polars_expr)
return polars_expr.replace(f'pl.col("{col}")', target_name)
else:
return polars_expr

def warn(
self,
Expand Down Expand Up @@ -656,7 +670,8 @@ def _matches(
idx: int,
*,
opnames: list[AbstractSet[str]],
argvals: list[AbstractSet[Any] | dict[Any, Any]] | None,
argvals: list[AbstractSet[Any] | dict[Any, Any] | None] | None,
is_attr: bool = False,
) -> list[Instruction]:
"""
Check if a sequence of Instructions matches the specified ops/argvals.
Expand All @@ -669,9 +684,19 @@ def _matches(
The full opname sequence that defines a match.
argvals
Associated argvals that must also match (in same position as opnames).
is_attr
Indicate if the match is expected to represent attribute access.
"""
n_required_ops, argvals = len(opnames), argvals or []
instructions = self._instructions[idx : idx + n_required_ops]
idx_offset = idx + n_required_ops
if (
is_attr
and (trailing_inst := self._instructions[idx_offset : idx_offset + 1])
and trailing_inst[0].opname in OpNames.CALL # not pure attr if called
):
return []

instructions = self._instructions[idx:idx_offset]
if len(instructions) == n_required_ops and all(
inst.opname in match_opnames
and (match_argval is None or inst.argval in match_argval)
Expand Down Expand Up @@ -702,12 +727,30 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]:
self._rewrite_functions,
self._rewrite_methods,
self._rewrite_builtins,
self._rewrite_attrs,
)
):
updated_instructions.append(inst)
idx += increment or 1
return updated_instructions

def _rewrite_attrs(self, idx: int, updated_instructions: list[Instruction]) -> int:
    """
    Rewrite a plain attribute lookup as a synthetic POLARS_EXPRESSION op.

    Looks at `idx` for a `LOAD_FAST` followed by a `LOAD_ATTR` whose attribute
    name is a key of `_PYTHON_ATTRS_MAP`. On a match, both instructions are
    appended to `updated_instructions`, with the attribute load swapped for a
    copy that carries the mapped native expression.

    Returns the length of the match result (zero when there is no match).
    """
    matched = self._matches(
        idx,
        opnames=[{"LOAD_FAST"}, {"LOAD_ATTR"}],
        argvals=[None, _PYTHON_ATTRS_MAP],
        is_attr=True,
    )
    if matched:
        load_inst, attr_inst = matched[0], matched[1]
        native_expr = _PYTHON_ATTRS_MAP[attr_inst.argval]
        polars_inst = attr_inst._replace(
            opname="POLARS_EXPRESSION",
            argval=native_expr,
            argrepr=native_expr,
        )
        updated_instructions.extend([load_inst, polars_inst])

    return len(matched)

def _rewrite_builtins(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
Expand Down Expand Up @@ -738,7 +781,7 @@ def _rewrite_functions(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
"""Replace function calls with a synthetic POLARS_EXPRESSION op."""
for function_kind in _FUNCTION_KINDS:
for function_kind in _MODULE_FUNCTIONS:
opnames: list[AbstractSet[str]] = [
{"LOAD_GLOBAL", "LOAD_DEREF"},
*function_kind["module_opname"],
Expand Down
89 changes: 49 additions & 40 deletions py-polars/tests/unit/datatypes/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@

import polars as pl
from polars.datatypes import DATETIME_DTYPES, DTYPE_TEMPORAL_UNITS, TEMPORAL_DTYPES
from polars.exceptions import ComputeError, TimeZoneAwareConstructorWarning
from polars.exceptions import (
ComputeError,
PolarsInefficientMapWarning,
TimeZoneAwareConstructorWarning,
)
from polars.testing import (
assert_frame_equal,
assert_series_equal,
Expand Down Expand Up @@ -947,45 +951,50 @@ def test_temporal_dtypes_map_elements(
)
const_dtm = datetime(2010, 9, 12)

assert_frame_equal(
df.with_columns(
[
# don't actually do any of this; native expressions are MUCH faster ;)
pl.col("timestamp")
.map_elements(lambda x: const_dtm, skip_nulls=skip_nulls)
.alias("const_dtm"),
pl.col("timestamp")
.map_elements(lambda x: x and x.date(), skip_nulls=skip_nulls)
.alias("date"),
pl.col("timestamp")
.map_elements(lambda x: x and x.time(), skip_nulls=skip_nulls)
.alias("time"),
]
),
pl.DataFrame(
[
(
datetime(2010, 9, 12, 10, 19, 54),
datetime(2010, 9, 12, 0, 0),
date(2010, 9, 12),
time(10, 19, 54),
),
(None, expected_value, None, None),
(
datetime(2009, 2, 13, 23, 31, 30),
datetime(2010, 9, 12, 0, 0),
date(2009, 2, 13),
time(23, 31, 30),
),
],
schema={
"timestamp": pl.Datetime("ms"),
"const_dtm": pl.Datetime("us"),
"date": pl.Date,
"time": pl.Time,
},
),
)
with pytest.warns(
PolarsInefficientMapWarning,
match=r"(?s)Replace this expression.*lambda x:",
):
assert_frame_equal(
df.with_columns(
[
# don't actually do this; native expressions are MUCH faster ;)
pl.col("timestamp")
.map_elements(lambda x: const_dtm, skip_nulls=skip_nulls)
.alias("const_dtm"),
# note: the two expressions below now trigger a PolarsInefficientMapWarning
pl.col("timestamp")
.map_elements(lambda x: x and x.date(), skip_nulls=skip_nulls)
.alias("date"),
pl.col("timestamp")
.map_elements(lambda x: x and x.time(), skip_nulls=skip_nulls)
.alias("time"),
]
),
pl.DataFrame(
[
(
datetime(2010, 9, 12, 10, 19, 54),
datetime(2010, 9, 12, 0, 0),
date(2010, 9, 12),
time(10, 19, 54),
),
(None, expected_value, None, None),
(
datetime(2009, 2, 13, 23, 31, 30),
datetime(2010, 9, 12, 0, 0),
date(2009, 2, 13),
time(23, 31, 30),
),
],
schema={
"timestamp": pl.Datetime("ms"),
"const_dtm": pl.Datetime("us"),
"date": pl.Date,
"time": pl.Time,
},
),
)


def test_timelike_init() -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,19 @@
'pl.col("d").str.to_datetime(format="%Y-%m-%d")',
),
# ---------------------------------------------
# temporal attributes/methods
# ---------------------------------------------
(
"f",
"lambda x: x.isoweekday()",
'pl.col("f").dt.weekday()',
),
(
"f",
"lambda x: x.hour + x.minute + x.second",
'(pl.col("f").dt.hour() + pl.col("f").dt.minute()) + pl.col("f").dt.second()',
),
# ---------------------------------------------
# Bitwise shifts
# ---------------------------------------------
(
Expand Down Expand Up @@ -244,6 +257,11 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
"c": ['{"a": 1}', '{"b": 2}', '{"c": 3}'],
"d": ["2020-01-01", "2020-01-02", "2020-01-03"],
"e": [1.5, 2.4, 3.1],
"f": [
datetime(1999, 12, 31),
datetime(2024, 5, 6),
datetime(2077, 10, 20),
],
}
)
result_frame = df.select(
Expand All @@ -254,7 +272,11 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
x=pl.col(col),
y=pl.col(col).map_elements(eval(func)),
)
assert_frame_equal(result_frame, expected_frame)
assert_frame_equal(
result_frame,
expected_frame,
check_dtype=(".dt." not in suggested_expression),
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
)


@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")
Expand Down Expand Up @@ -411,6 +433,15 @@ def test_expr_exact_warning_message() -> None:
assert len(warnings) == 1


def test_omit_implicit_bool() -> None:
    """Repeated `x and ...` guards collapse into a single column expression."""
    suggestion = BytecodeParser(
        function=lambda x: x and x and x.date(),
        map_target="expr",
    ).to_expression("d")

    assert suggestion == 'pl.col("d").dt.date()'


def test_partial_functions_13523() -> None:
def plus(value, amount: int): # type: ignore[no-untyped-def]
return value + amount
Expand Down