Skip to content

Commit

Permalink
fix: DataFrame plot was raising when some extra keywords were passed …
Browse files Browse the repository at this point in the history
…to encodings (e.g. `x=alt.X(a, axis=alt.Axis(labelAngle=30))`)
  • Loading branch information
MarcoGorelli committed Sep 20, 2024
1 parent 2b4986a commit 03c1c64
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
23 changes: 15 additions & 8 deletions py-polars/polars/dataframe/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from typing import TYPE_CHECKING, Callable, Dict, Union

from polars.dependencies import altair as alt

if TYPE_CHECKING:
import sys

import altair as alt
from altair.typing import ChannelColor as Color
from altair.typing import ChannelOrder as Order
from altair.typing import ChannelSize as Size
Expand All @@ -25,23 +26,29 @@
else:
from typing_extensions import Unpack

Encodings: TypeAlias = Dict[
str,
Union[X, Y, Color, Order, Size, Tooltip],
]
Encoding: TypeAlias = Union[X, Y, Color, Order, Size, Tooltip]
Encodings: TypeAlias = Dict[str, Encoding]


def _maybe_extract_shorthand(encoding: Encoding) -> str:
if isinstance(encoding, alt.SchemaBase) and hasattr(encoding, "shorthand"):
# e.g. for `alt.X('x:Q', axis=alt.Axis(labelAngle=30))`, return `'x:Q'`
return encoding.shorthand
return encoding


def _add_tooltip(encodings: Encodings, /, **kwargs: Unpack[EncodeKwds]) -> None:
if "tooltip" not in kwargs:
encodings["tooltip"] = [*encodings.values(), *kwargs.values()] # type: ignore[assignment]
encodings["tooltip"] = [
*[_maybe_extract_shorthand(x) for x in encodings.values()],
*[_maybe_extract_shorthand(x) for x in kwargs.values()],
] # type: ignore[assignment]


class DataFramePlot:
"""DataFrame.plot namespace."""

def __init__(self, df: DataFrame) -> None:
import altair as alt

self._chart = alt.Chart(df)

def bar(
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/operations/namespaces/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import altair as alt

import polars as pl


Expand Down Expand Up @@ -66,3 +68,9 @@ def test_empty_dataframe() -> None:

def test_nameless_series() -> None:
pl.Series([1, 2, 3]).plot.kde().to_json()


def test_x_with_axis_18830() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
result = df.plot.line(x=alt.X("a", axis=alt.Axis(labelAngle=-90))).to_dict()
assert result["encoding"]["tooltip"] == [{"field": "a", "type": "quantitative"}]

0 comments on commit 03c1c64

Please sign in to comment.