From f79236118f4896cbeaf12b7c7c328f9651fcd0b4 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 20 Sep 2024 23:08:47 +0200 Subject: [PATCH] fix: DataFrame plot was raising when some extra keywords were passed to encodings (e.g. `x=alt.X(a, axis=alt.Axis(labelAngle=30))`) --- py-polars/polars/dataframe/plotting.py | 23 ++++++++++++------- .../unit/operations/namespaces/test_plot.py | 8 +++++++ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/py-polars/polars/dataframe/plotting.py b/py-polars/polars/dataframe/plotting.py index ac787afd3882..eed8fbc62faf 100644 --- a/py-polars/polars/dataframe/plotting.py +++ b/py-polars/polars/dataframe/plotting.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING, Callable, Dict, Union +from polars.dependencies import altair as alt + if TYPE_CHECKING: import sys - import altair as alt from altair.typing import ChannelColor as Color from altair.typing import ChannelOrder as Order from altair.typing import ChannelSize as Size @@ -25,23 +26,29 @@ else: from typing_extensions import Unpack - Encodings: TypeAlias = Dict[ - str, - Union[X, Y, Color, Order, Size, Tooltip], - ] + Encoding: TypeAlias = Union[X, Y, Color, Order, Size, Tooltip] + Encodings: TypeAlias = Dict[str, Encoding] + + +def _maybe_extract_shorthand(encoding: Encoding) -> Encoding: + if isinstance(encoding, alt.SchemaBase) and hasattr(encoding, "shorthand"): + # e.g. for `alt.X('x:Q', axis=alt.Axis(labelAngle=30))`, return `'x:Q'` + return encoding.shorthand + return encoding def _add_tooltip(encodings: Encodings, /, **kwargs: Unpack[EncodeKwds]) -> None: if "tooltip" not in kwargs: - encodings["tooltip"] = [*encodings.values(), *kwargs.values()] # type: ignore[assignment] + encodings["tooltip"] = [ + *[_maybe_extract_shorthand(x) for x in encodings.values()], + *[_maybe_extract_shorthand(x) for x in kwargs.values()], # type: ignore[arg-type] + ] # type: ignore[assignment] class DataFramePlot: """DataFrame.plot namespace.""" def __init__(self, df: DataFrame) -> None: - import altair as alt - self._chart = alt.Chart(df) def bar( diff --git a/py-polars/tests/unit/operations/namespaces/test_plot.py b/py-polars/tests/unit/operations/namespaces/test_plot.py index 5a4c1c21a596..789f2a0974d8 100644 --- a/py-polars/tests/unit/operations/namespaces/test_plot.py +++ b/py-polars/tests/unit/operations/namespaces/test_plot.py @@ -1,3 +1,5 @@ +import altair as alt + import polars as pl @@ -66,3 +68,9 @@ def test_empty_dataframe() -> None: def test_nameless_series() -> None: pl.Series([1, 2, 3]).plot.kde().to_json() + + +def test_x_with_axis_18830() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + result = df.plot.line(x=alt.X("a", axis=alt.Axis(labelAngle=-90))).to_dict() + assert result["encoding"]["tooltip"] == [{"field": "a", "type": "quantitative"}]