Skip to content

Commit

Permalink
feat(python): warn on inefficient use of map_elements for temporal …
Browse files Browse the repository at this point in the history
…attribute access
  • Loading branch information
alexander-beedie committed Feb 16, 2024
1 parent a405016 commit e312866
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 21 deletions.
59 changes: 39 additions & 20 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,36 +132,40 @@ class OpNames:
)
)

# python functions that we can map to native expressions
# python attrs/funcs that map to native expressions
_PYTHON_ATTRS_MAP = {
"date": "dt.date()",
"day": "dt.day()",
"hour": "dt.hour()",
"microsecond": "dt.microsecond()",
"minute": "dt.minute()",
"month": "dt.month()",
"second": "dt.second()",
"time": "dt.time()",
"year": "dt.year()",
}
_PYTHON_CASTS_MAP = {"float": "Float64", "int": "Int64", "str": "String"}
_PYTHON_BUILTINS = frozenset(_PYTHON_CASTS_MAP) | {"abs"}
_PYTHON_METHODS_MAP = {
"lower": "str.to_lowercase",
"title": "str.to_titlecase",
"upper": "str.to_uppercase",
"isoweekday": "dt.weekday",
}

_FUNCTION_KINDS: list[dict[str, list[AbstractSet[str]]]] = [
# lambda x: module.func(CONSTANT)
{
"argument_1_opname": [{"LOAD_CONST"}],
"argument_2_opname": [],
"module_opname": [OpNames.LOAD_ATTR],
"attribute_opname": [],
"module_name": [_NUMPY_MODULE_ALIASES],
"attribute_name": [],
"function_name": [_NUMPY_FUNCTIONS],
},
# lambda x: module.func(x)
_MODULE_FUNCTIONS: list[dict[str, list[AbstractSet[str]]]] = [
# lambda x: numpy.func(x)
# lambda x: numpy.func(CONSTANT)
{
"argument_1_opname": [{"LOAD_FAST"}],
"argument_1_opname": [{"LOAD_FAST", "LOAD_CONST"}],
"argument_2_opname": [],
"module_opname": [OpNames.LOAD_ATTR],
"attribute_opname": [],
"module_name": [_NUMPY_MODULE_ALIASES],
"attribute_name": [],
"function_name": [_NUMPY_FUNCTIONS],
},
# lambda x: json.loads(x)
{
"argument_1_opname": [{"LOAD_FAST"}],
"argument_2_opname": [],
Expand All @@ -171,7 +175,7 @@ class OpNames:
"attribute_name": [],
"function_name": [{"loads"}],
},
# lambda x: module.func(x, CONSTANT)
# lambda x: datetime.strptime(x, CONSTANT)
{
"argument_1_opname": [{"LOAD_FAST"}],
"argument_2_opname": [{"LOAD_CONST"}],
Expand All @@ -194,11 +198,9 @@ class OpNames:
]
# In addition to `lambda x: func(x)`, also support cases when a unary operation
# has been applied to `x`, like `lambda x: func(-x)` or `lambda x: func(~x)`.
_FUNCTION_KINDS = [
# Dict entry 1 has incompatible type "str": "object";
# expected "str": "list[AbstractSet[str]]"
_MODULE_FUNCTIONS = [
{**kind, "argument_1_unary_opname": unary} # type: ignore[dict-item]
for kind in _FUNCTION_KINDS
for kind in _MODULE_FUNCTIONS
for unary in [[set(OpNames.UNARY)], []]
]

Expand Down Expand Up @@ -702,12 +704,29 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]:
self._rewrite_functions,
self._rewrite_methods,
self._rewrite_builtins,
self._rewrite_attrs,
)
):
updated_instructions.append(inst)
idx += increment or 1
return updated_instructions

def _rewrite_attrs(self, idx: int, updated_instructions: list[Instruction]) -> int:
"""Replace python attribute lookup with synthetic POLARS_EXPRESSION op."""
if matching_instructions := self._matches(
idx,
opnames=[{"LOAD_FAST"}, {"LOAD_ATTR"}],
argvals=[None, _PYTHON_ATTRS_MAP],
):
inst = matching_instructions[1]
expr_name = _PYTHON_ATTRS_MAP[inst.argval]
synthetic_call = inst._replace(
opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name
)
updated_instructions.extend([matching_instructions[0], synthetic_call])

return len(matching_instructions)

def _rewrite_builtins(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
Expand Down Expand Up @@ -738,7 +757,7 @@ def _rewrite_functions(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
"""Replace function calls with a synthetic POLARS_EXPRESSION op."""
for function_kind in _FUNCTION_KINDS:
for function_kind in _MODULE_FUNCTIONS:
opnames: list[AbstractSet[str]] = [
{"LOAD_GLOBAL", "LOAD_DEREF"},
*function_kind["module_opname"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,19 @@
'pl.col("d").str.to_datetime(format="%Y-%m-%d")',
),
# ---------------------------------------------
# temporal attributes/methods
# ---------------------------------------------
(
"f",
"lambda x: x.isoweekday()",
'pl.col("f").dt.weekday()',
),
(
"f",
"lambda x: x.hour + x.minute + x.second",
'(pl.col("f").dt.hour() + pl.col("f").dt.minute()) + pl.col("f").dt.second()',
),
# ---------------------------------------------
# Bitwise shifts
# ---------------------------------------------
(
Expand Down Expand Up @@ -244,6 +257,11 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
"c": ['{"a": 1}', '{"b": 2}', '{"c": 3}'],
"d": ["2020-01-01", "2020-01-02", "2020-01-03"],
"e": [1.5, 2.4, 3.1],
"f": [
datetime(1999, 12, 31),
datetime(2024, 5, 6),
datetime(2077, 10, 20),
],
}
)
result_frame = df.select(
Expand All @@ -254,7 +272,11 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
x=pl.col(col),
y=pl.col(col).map_elements(eval(func)),
)
assert_frame_equal(result_frame, expected_frame)
assert_frame_equal(
result_frame,
expected_frame,
check_dtype=(".dt." not in suggested_expression),
)


@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")
Expand Down

0 comments on commit e312866

Please sign in to comment.