Skip to content

Commit

Permalink
feat(python): warn on inefficient use of map_elements for temporal …
Browse files Browse the repository at this point in the history
…attribute access
  • Loading branch information
alexander-beedie committed Feb 16, 2024
1 parent a405016 commit 07e3b74
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 23 deletions.
74 changes: 52 additions & 22 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,36 +132,42 @@ class OpNames:
)
)

# python functions that we can map to native expressions
# python attrs/funcs that map to native expressions
_PYTHON_ATTRS_MAP = {
"date": "dt.date()",
"day": "dt.day()",
"hour": "dt.hour()",
"microsecond": "dt.microsecond()",
"minute": "dt.minute()",
"month": "dt.month()",
"second": "dt.second()",
"year": "dt.year()",
}
_PYTHON_CASTS_MAP = {"float": "Float64", "int": "Int64", "str": "String"}
_PYTHON_BUILTINS = frozenset(_PYTHON_CASTS_MAP) | {"abs"}
_PYTHON_METHODS_MAP = {
# string
"lower": "str.to_lowercase",
"title": "str.to_titlecase",
"upper": "str.to_uppercase",
# temporal
"isoweekday": "dt.weekday",
"time": "dt.time",
}

_FUNCTION_KINDS: list[dict[str, list[AbstractSet[str]]]] = [
# lambda x: module.func(CONSTANT)
{
"argument_1_opname": [{"LOAD_CONST"}],
"argument_2_opname": [],
"module_opname": [OpNames.LOAD_ATTR],
"attribute_opname": [],
"module_name": [_NUMPY_MODULE_ALIASES],
"attribute_name": [],
"function_name": [_NUMPY_FUNCTIONS],
},
# lambda x: module.func(x)
_MODULE_FUNCTIONS: list[dict[str, list[AbstractSet[str]]]] = [
# lambda x: numpy.func(x)
# lambda x: numpy.func(CONSTANT)
{
"argument_1_opname": [{"LOAD_FAST"}],
"argument_1_opname": [{"LOAD_FAST", "LOAD_CONST"}],
"argument_2_opname": [],
"module_opname": [OpNames.LOAD_ATTR],
"attribute_opname": [],
"module_name": [_NUMPY_MODULE_ALIASES],
"attribute_name": [],
"function_name": [_NUMPY_FUNCTIONS],
},
# lambda x: json.loads(x)
{
"argument_1_opname": [{"LOAD_FAST"}],
"argument_2_opname": [],
Expand All @@ -171,7 +177,7 @@ class OpNames:
"attribute_name": [],
"function_name": [{"loads"}],
},
# lambda x: module.func(x, CONSTANT)
# lambda x: datetime.strptime(x, CONSTANT)
{
"argument_1_opname": [{"LOAD_FAST"}],
"argument_2_opname": [{"LOAD_CONST"}],
Expand All @@ -194,11 +200,9 @@ class OpNames:
]
# In addition to `lambda x: func(x)`, also support cases when a unary operation
# has been applied to `x`, like `lambda x: func(-x)` or `lambda x: func(~x)`.
_FUNCTION_KINDS = [
# Dict entry 1 has incompatible type "str": "object";
# expected "str": "list[AbstractSet[str]]"
_MODULE_FUNCTIONS = [
{**kind, "argument_1_unary_opname": unary} # type: ignore[dict-item]
for kind in _FUNCTION_KINDS
for kind in _MODULE_FUNCTIONS
for unary in [[set(OpNames.UNARY)], []]
]

Expand Down Expand Up @@ -656,7 +660,8 @@ def _matches(
idx: int,
*,
opnames: list[AbstractSet[str]],
argvals: list[AbstractSet[Any] | dict[Any, Any]] | None,
argvals: list[AbstractSet[Any] | dict[Any, Any] | None] | None,
is_attr: bool = False,
) -> list[Instruction]:
"""
Check if a sequence of Instructions matches the specified ops/argvals.
Expand All @@ -669,9 +674,16 @@ def _matches(
The full opname sequence that defines a match.
argvals
Associated argvals that must also match (in same position as opnames).
is_attr
Indicate if the match is expected to represent attribute access.
"""
n_required_ops, argvals = len(opnames), argvals or []
instructions = self._instructions[idx : idx + n_required_ops]
idx_offset = idx + n_required_ops
if is_attr and (trailing_inst := self._instructions[idx_offset:1]):
if trailing_inst[0].opname == "CALL":
return []

instructions = self._instructions[idx:idx_offset]
if len(instructions) == n_required_ops and all(
inst.opname in match_opnames
and (match_argval is None or inst.argval in match_argval)
Expand Down Expand Up @@ -702,12 +714,30 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]:
self._rewrite_functions,
self._rewrite_methods,
self._rewrite_builtins,
self._rewrite_attrs,
)
):
updated_instructions.append(inst)
idx += increment or 1
return updated_instructions

def _rewrite_attrs(self, idx: int, updated_instructions: list[Instruction]) -> int:
"""Replace python attribute lookup with synthetic POLARS_EXPRESSION op."""
if matching_instructions := self._matches(
idx,
opnames=[{"LOAD_FAST"}, {"LOAD_ATTR"}],
argvals=[None, _PYTHON_ATTRS_MAP],
is_attr=True,
):
inst = matching_instructions[1]
expr_name = _PYTHON_ATTRS_MAP[inst.argval]
synthetic_call = inst._replace(
opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name
)
updated_instructions.extend([matching_instructions[0], synthetic_call])

return len(matching_instructions)

def _rewrite_builtins(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
Expand Down Expand Up @@ -738,7 +768,7 @@ def _rewrite_functions(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
"""Replace function calls with a synthetic POLARS_EXPRESSION op."""
for function_kind in _FUNCTION_KINDS:
for function_kind in _MODULE_FUNCTIONS:
opnames: list[AbstractSet[str]] = [
{"LOAD_GLOBAL", "LOAD_DEREF"},
*function_kind["module_opname"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,19 @@
'pl.col("d").str.to_datetime(format="%Y-%m-%d")',
),
# ---------------------------------------------
# temporal attributes/methods
# ---------------------------------------------
(
"f",
"lambda x: x.isoweekday()",
'pl.col("f").dt.weekday()',
),
(
"f",
"lambda x: x.hour + x.minute + x.second",
'(pl.col("f").dt.hour() + pl.col("f").dt.minute()) + pl.col("f").dt.second()',
),
# ---------------------------------------------
# Bitwise shifts
# ---------------------------------------------
(
Expand Down Expand Up @@ -244,6 +257,11 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
"c": ['{"a": 1}', '{"b": 2}', '{"c": 3}'],
"d": ["2020-01-01", "2020-01-02", "2020-01-03"],
"e": [1.5, 2.4, 3.1],
"f": [
datetime(1999, 12, 31),
datetime(2024, 5, 6),
datetime(2077, 10, 20),
],
}
)
result_frame = df.select(
Expand All @@ -254,7 +272,11 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
x=pl.col(col),
y=pl.col(col).map_elements(eval(func)),
)
assert_frame_equal(result_frame, expected_frame)
assert_frame_equal(
result_frame,
expected_frame,
check_dtype=(".dt." not in suggested_expression),
)


@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")
Expand Down

0 comments on commit 07e3b74

Please sign in to comment.