diff --git a/altair/__init__.py b/altair/__init__.py index d4e20f02f..1d6b138ea 100644 --- a/altair/__init__.py +++ b/altair/__init__.py @@ -589,6 +589,7 @@ "YOffsetDatum", "YOffsetValue", "YValue", + "agg", "api", "binding", "binding_checkbox", @@ -607,6 +608,7 @@ "default_data_transformer", "display", "expr", + "field", "graticule", "hconcat", "jupyter", diff --git a/altair/vegalite/__init__.py b/altair/vegalite/__init__.py index 8fa78644e..d4582cb14 100644 --- a/altair/vegalite/__init__.py +++ b/altair/vegalite/__init__.py @@ -1,2 +1,4 @@ # ruff: noqa: F403 from .v5 import * +from .v5._api_rfc import agg as agg +from .v5._api_rfc import field as field diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py new file mode 100644 index 000000000..6a100387c --- /dev/null +++ b/altair/vegalite/v5/_api_rfc.py @@ -0,0 +1,440 @@ +""" +Request for comment on additions to `api.py`. + +Ideally these would be introduced *after* cleaning up the top-level namespace. + +Actual runtime dependencies: +- altair.utils.core +- altair.utils.schemapi + +The rest are to define aliases only. +""" + +# mypy: ignore-errors +from __future__ import annotations + +import datetime as dt +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING + +from altair.utils.core import TYPECODE_MAP as _TYPE_CODE +from altair.utils.core import parse_shorthand as _parse +from altair.utils.schemapi import SchemaBase, Undefined +from altair.vegalite.v5.schema import channels +from altair.vegalite.v5.schema.core import ( + FieldEqualPredicate, + FieldGTEPredicate, + FieldGTPredicate, + FieldLTEPredicate, + FieldLTPredicate, + FieldOneOfPredicate, + FieldRangePredicate, + FieldValidPredicate, +) + +if TYPE_CHECKING: + import sys + from typing import Any, Literal + + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + from narwhals.typing import IntoDataFrame + + from altair.utils.schemapi import Optional + from altair.vegalite.v5.api import Parameter, _FieldEqualType + from altair.vegalite.v5.schema._typing import ( + AggregateOp_T, + BinnedTimeUnit_T, + Map, + MultiTimeUnit_T, + OneOrSeq, + SingleTimeUnit_T, + Temporal, + Type_T, + ) + from altair.vegalite.v5.schema.core import Predicate, TimeUnitParams + + +__all__ = ["agg", "field"] + +_Type: TypeAlias = 'Type_T | Literal["O", "N", "Q", "T", "G"] | None' +_TimeUnit: TypeAlias = "Optional[TimeUnitParams | SingleTimeUnit_T | MultiTimeUnit_T | BinnedTimeUnit_T | Map]" +_Range: TypeAlias = "Parameter | SchemaBase | Sequence[float | Temporal | Parameter | SchemaBase | Map | None] | Map" +_Value: TypeAlias = "str | float | Temporal | Parameter | SchemaBase | Map" +_OneOf: TypeAlias = "str | bool | float | Temporal | SchemaBase | Map" + + +_ENCODINGS = frozenset( + ( + "ordinal", + "O", + "nominal", + "N", + "quantitative", + "Q", + "temporal", + "T", + "geojson", + "G", + None, + ) +) + + +def _parse_aggregate( + aggregate: AggregateOp_T, name: str | None, encode_type: _Type, / +) -> dict[str, Any]: + if encode_type in _ENCODINGS: + enc = f":{_TYPE_CODE.get(s, s)}" if (s := encode_type) else "" + return _parse(f"{aggregate}({name or ''}){enc}") + else: + msg = ( + f"Expected a short/long-form encoding type, but got {encode_type!r}.\n\n" + f"Try passing one of the following to `type`:\n" + f"{', '.join(sorted(f'{e!r}' for e in _ENCODINGS))}." + ) + raise TypeError(msg) + + +def _one_of_flatten( + values: tuple[_OneOf, ...] | tuple[Sequence[_OneOf]] | tuple[Any, ...], / +) -> Sequence[_OneOf]: + if ( + len(values) == 1 + and not isinstance( + values[0], (str, bool, float, int, dt.date, Mapping, SchemaBase) + ) + and isinstance(values[0], Sequence) + ): + return values[0] + elif len(values) > 1: + return values + else: + msg = ( + f"Expected `values` to be either a single `Sequence` " + f"or used variadically, but got: {values!r}." + ) + raise TypeError(msg) + + +def _one_of_variance(val_1: Any, *rest: _OneOf) -> Sequence[Any]: + # Required that all elements are the same type + tp = type(val_1) + if all(isinstance(v, tp) for v in rest): + return (val_1, *rest) + else: + msg = ( + f"Expected all `values` to be of the same type, but got:\n" + f"{tuple(f'{type(v).__name__}' for v in (val_1, *rest))!r}" + ) + raise TypeError(msg) + + +class agg: + """ + Utility class providing autocomplete for shorthand. + + Functional alternative to shorthand mini-language. + """ + + def __new__( # type: ignore[misc] + cls, shorthand: dict[str, Any] | str, /, data: IntoDataFrame | None = None + ) -> dict[str, Any]: + return _parse(shorthand=shorthand, data=data) + + @classmethod + def argmin( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("argmin", col_name, type) + + @classmethod + def argmax( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("argmax", col_name, type) + + @classmethod + def average( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("average", col_name, type) + + @classmethod + def count(cls, col_name: str | None = None, /, type: _Type = "Q") -> dict[str, Any]: + return _parse_aggregate("count", col_name, type) + + @classmethod + def distinct( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("distinct", col_name, type) + + @classmethod + def max(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: + return _parse_aggregate("max", col_name, type) + + @classmethod + def mean(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: + return _parse_aggregate("mean", col_name, type) + + @classmethod + def median( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("median", col_name, type) + + @classmethod + def min(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: + return _parse_aggregate("min", col_name, type) + + @classmethod + def missing( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("missing", col_name, type) + + @classmethod + def product( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("product", col_name, type) + + @classmethod + def q1(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: + return _parse_aggregate("q1", col_name, type) + + @classmethod + def q3(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: + return _parse_aggregate("q3", col_name, type) + + @classmethod + def ci0(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: + return _parse_aggregate("ci0", col_name, type) + + @classmethod + def ci1(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: + return _parse_aggregate("ci1", col_name, type) + + @classmethod + def stderr( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("stderr", col_name, type) + + @classmethod + def stdev( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("stdev", col_name, type) + + @classmethod + def stdevp( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("stdevp", col_name, type) + + @classmethod + def sum(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: + return _parse_aggregate("sum", col_name, type) + + @classmethod + def valid( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("valid", col_name, type) + + @classmethod + def values( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("values", col_name, type) + + @classmethod + def variance( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("variance", col_name, type) + + @classmethod + def variancep( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("variancep", col_name, type) + + @classmethod + def exponential( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("exponential", col_name, type) + + @classmethod + def exponentialb( + cls, col_name: str | None = None, /, type: _Type = None + ) -> dict[str, Any]: + return _parse_aggregate("exponentialb", col_name, type) + + +class field: + """ + Utility class for field predicates and shorthand parsing. + + Examples + -------- + >>> field("Origin") + {'field': 'Origin'} + + >>> field("Origin:N") + {'field': 'Origin', 'type': 'nominal'} + + >>> field.one_of("Origin", "Japan", "Europe") + FieldOneOfPredicate({ + field: 'Origin', + oneOf: ('Japan', 'Europe') + }) + """ + + def __new__( # type: ignore[misc] + cls, shorthand: dict[str, Any] | str, /, data: IntoDataFrame | None = None + ) -> dict[str, Any]: + return _parse(shorthand=shorthand, data=data) + + @classmethod + def one_of( + cls, field: str, /, *values: OneOrSeq[_OneOf], timeUnit: _TimeUnit = Undefined + ) -> Predicate: + seq = _one_of_flatten(values) + one_of = _one_of_variance(*seq) + return FieldOneOfPredicate(field=field, oneOf=one_of, timeUnit=timeUnit) + + @classmethod + def eq( + cls, field: str, value: _FieldEqualType, /, *, timeUnit: _TimeUnit = Undefined + ) -> Predicate: + if value is None: + return cls.valid(field, False, timeUnit=timeUnit) + return FieldEqualPredicate(field=field, equal=value, timeUnit=timeUnit) + + @classmethod + def lt( + cls, field: str, value: _Value, /, *, timeUnit: _TimeUnit = Undefined + ) -> Predicate: + return FieldLTPredicate(field=field, lt=value, timeUnit=timeUnit) + + @classmethod + def lte( + cls, field: str, value: _Value, /, *, timeUnit: _TimeUnit = Undefined + ) -> Predicate: + return FieldLTEPredicate(field=field, lte=value, timeUnit=timeUnit) + + @classmethod + def gt( + cls, field: str, value: _Value, /, *, timeUnit: _TimeUnit = Undefined + ) -> Predicate: + return FieldGTPredicate(field=field, gt=value, timeUnit=timeUnit) + + @classmethod + def gte( + cls, field: str, value: _Value, /, *, timeUnit: _TimeUnit = Undefined + ) -> Predicate: + return FieldGTEPredicate(field=field, gte=value, timeUnit=timeUnit) + + @classmethod + def valid( + cls, field: str, value: bool, /, *, timeUnit: _TimeUnit = Undefined + ) -> Predicate: + return FieldValidPredicate(field=field, valid=value, timeUnit=timeUnit) + + @classmethod + def range( + cls, field: str, value: _Range, /, *, timeUnit: _TimeUnit = Undefined + ) -> Predicate: + return FieldRangePredicate(field=field, range=value, timeUnit=timeUnit) + + +# NOTE: Ignore everything below # +# ----------------------------- # + + +class field_into: + """ + Return wrapper for `agg`, `field` shorthand dicts. + + Idea + ---- + Rather than something like:: + + op_1 = alt.X(alt.agg.min("Cost", "Q")).scale(None) + + You could chain entirely from the agg:: + + # the internal unpacking will avoid double-checking the shorthand + op_2 = alt.agg.min("Cost", "Q").x().scale(None) + + Optionally, use the chained constructor:: + + op_2_1 = alt.agg.min("Cost", "Q").x(scale=None) + + + """ + + def __init__(self, arg: Map, /) -> None: + self._arg: Map = arg + + def angle(self, *args: Any, **kwds: Any) -> channels.Angle: ... + def color(self, *args: Any, **kwds: Any) -> channels.Color: ... + def column(self, *args: Any, **kwds: Any) -> channels.Column: ... + def description(self, *args: Any, **kwds: Any) -> channels.Description: ... + def detail(self, *args: Any, **kwds: Any) -> channels.Detail: ... + def facet(self, *args: Any, **kwds: Any) -> channels.Facet: ... + def fill(self, *args: Any, **kwds: Any) -> channels.Fill: ... + def fill_opacity(self, *args: Any, **kwds: Any) -> channels.FillOpacity: ... + def href(self, *args: Any, **kwds: Any) -> channels.Href: ... + def key(self, *args: Any, **kwds: Any) -> channels.Key: ... + def latitude(self, *args: Any, **kwds: Any) -> channels.Latitude: ... + def latitude2(self, *args: Any, **kwds: Any) -> channels.Latitude2: ... + def longitude(self, *args: Any, **kwds: Any) -> channels.Longitude: ... + def longitude2(self, *args: Any, **kwds: Any) -> channels.Longitude2: ... + def opacity(self, *args: Any, **kwds: Any) -> channels.Opacity: ... + def order(self, *args: Any, **kwds: Any) -> channels.Order: ... + def radius(self, *args: Any, **kwds: Any) -> channels.Radius: ... + def radius2(self, *args: Any, **kwds: Any) -> channels.Radius2: ... + def row(self, *args: Any, **kwds: Any) -> channels.Row: ... + def shape(self, *args: Any, **kwds: Any) -> channels.Shape: ... + def size(self, *args: Any, **kwds: Any) -> channels.Size: ... + def stroke(self, *args: Any, **kwds: Any) -> channels.Stroke: ... + def stroke_dash(self, *args: Any, **kwds: Any) -> channels.StrokeDash: ... + def stroke_opacity(self, *args: Any, **kwds: Any) -> channels.StrokeOpacity: ... + def stroke_width(self, *args: Any, **kwds: Any) -> channels.StrokeWidth: ... + def text(self, *args: Any, **kwds: Any) -> channels.Text: ... + def theta(self, *args: Any, **kwds: Any) -> channels.Theta: ... + def theta2(self, *args: Any, **kwds: Any) -> channels.Theta2: ... + def tooltip(self, *args: Any, **kwds: Any) -> channels.Tooltip: ... + def url(self, *args: Any, **kwds: Any) -> channels.Url: ... + def x(self, *args: Any, **kwds: Any) -> channels.X: + return channels.X(*args, **self._arg, **kwds) + + def x2(self, *args: Any, **kwds: Any) -> channels.X2: + return channels.X2(*args, **self._arg, **kwds) + + def x_error(self, *args: Any, **kwds: Any) -> channels.XError: ... + def x_error2(self, *args: Any, **kwds: Any) -> channels.XError2: ... + def x_offset(self, *args: Any, **kwds: Any) -> channels.XOffset: ... + def y(self, *args: Any, **kwds: Any) -> channels.Y: + return channels.Y(*args, **self._arg, **kwds) + + def y2(self, *args: Any, **kwds: Any) -> channels.Y2: + return channels.Y2(*args, **self._arg, **kwds) + + def y_error(self, *args: Any, **kwds: Any) -> channels.YError: ... + def y_error2(self, *args: Any, **kwds: Any) -> channels.YError2: ... + def y_offset(self, *args: Any, **kwds: Any) -> channels.YOffset: ... + + +def example_field(): + field_out = agg.q1() + wrapped = field_into(field_out).x() # noqa: F841 + some_field = field_into(agg.min("Cost", "Q")) + beeeee = some_field.x().scale(None).impute(None).axis(None) # noqa: F841 diff --git a/altair/vegalite/v5/schema/channels.py b/altair/vegalite/v5/schema/channels.py index 3562bc83e..15022278d 100644 --- a/altair/vegalite/v5/schema/channels.py +++ b/altair/vegalite/v5/schema/channels.py @@ -188,7 +188,7 @@ def to_dict( if shorthand is Undefined: parsed = {} - elif isinstance(shorthand, str): + elif isinstance(shorthand, (str, dict)): data: nw.DataFrame | Any = context.get("data", None) parsed = parse_shorthand(shorthand, data=data) type_required = "type" in self._kwds # type: ignore[attr-defined] diff --git a/tests/vegalite/v5/test__api_rfc.py b/tests/vegalite/v5/test__api_rfc.py new file mode 100644 index 000000000..de4f7f9af --- /dev/null +++ b/tests/vegalite/v5/test__api_rfc.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +# ruff: noqa: F401 +import re +from typing import TYPE_CHECKING + +import pytest + +import altair as alt +from altair.utils.core import INV_TYPECODE_MAP, TYPECODE_MAP +from altair.vegalite.v5._api_rfc import _Type, agg, field + +if TYPE_CHECKING: + from altair.vegalite.v5.schema._typing import AggregateOp_T + + +def test_agg_type_invalid() -> None: + with pytest.raises( + TypeError, match=re.compile(r"'bogus'.+Try.+'quantitative'", re.DOTALL) + ): + agg.count(type="bogus") # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "enc_type", + [ + "quantitative", + "ordinal", + "temporal", + "nominal", + "geojson", + "O", + "N", + "Q", + "T", + "G", + None, + ], +) +@pytest.mark.parametrize("col_name", ["column_1", None]) +@pytest.mark.parametrize( + "method_name", + [ + "argmax", + "argmin", + "average", + "count", + "distinct", + "max", + "mean", + "median", + "min", + "missing", + "product", + "q1", + "q3", + "ci0", + "ci1", + "stderr", + "stdev", + "stdevp", + "sum", + "valid", + "values", + "variance", + "variancep", + "exponential", + "exponentialb", + ], +) +def test_agg_methods(method_name: AggregateOp_T, col_name: str | None, enc_type: _Type): + actual = getattr(agg, method_name)(col_name, enc_type) + assert isinstance(actual, dict) + assert actual["aggregate"] == method_name + if col_name: + assert actual["field"] == col_name + if enc_type: + assert actual["type"] == INV_TYPECODE_MAP.get(enc_type, enc_type) + + +def test_field_one_of_covariant() -> None: + with pytest.raises(TypeError, match=re.compile(r"Expected.+same type", re.DOTALL)): + field.one_of("field 1", 5, 6, 7, "nineteen", 8000.4) + + +def test_field_one_of_variadic(): + args = "A", "B", "C", "D", "E" + assert field.one_of("field_1", *args) == field.one_of("field_1", args) + + +def test_field_wrap(): + comp = field.eq("field 1", 10) + assert isinstance(comp, alt.SelectionPredicateComposition) + + +def test_field_compose(): + from vega_datasets import data + + cars_select = field.one_of("Origin", "Japan", "Europe") | field.range( + "Miles_per_Gallon", (25, 40) + ) + assert isinstance(cars_select, alt.SelectionPredicateComposition) + + source = data.cars() + chart = ( + alt.Chart(source) + .mark_point() + .encode( + x="Horsepower", + y="Miles_per_Gallon", + color=alt.condition(cars_select, alt.value("red"), alt.value("grey")), + ) + ) + chart.to_dict() diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 70c3980e4..4453e64bb 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -132,7 +132,7 @@ def to_dict( if shorthand is Undefined: parsed = {} - elif isinstance(shorthand, str): + elif isinstance(shorthand, (str, dict)): data: nw.DataFrame | Any = context.get("data", None) parsed = parse_shorthand(shorthand, data=data) type_required = "type" in self._kwds # type: ignore[attr-defined]