From fe7e77ea546846d1e279ab569eaca1628b49153a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Jul 2024 09:46:36 +0100 Subject: [PATCH 01/19] feat(RFC): Adds `agg`, `field` utility classes `field` proposed in https://github.com/vega/altair/issues/3239#issuecomment-2233593364 `agg` was developed during https://github.com/vega/altair/pull/3427#issuecomment-2228454241 as a solution to part of https://github.com/vega/altair/discussions/3476 --- altair/vegalite/v5/_api_rfc.py | 344 +++++++++++++++++++++++++++++++++ 1 file changed, 344 insertions(+) create mode 100644 altair/vegalite/v5/_api_rfc.py diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py new file mode 100644 index 000000000..236192ae8 --- /dev/null +++ b/altair/vegalite/v5/_api_rfc.py @@ -0,0 +1,344 @@ +""" +Request for comment on additions to `api.py`. + +Ideally these would be introduced *after* cleaning up the top-level namespace. + +Actual runtime dependencies: +- altair.utils.core +- altair.utils.schemapi + +The rest are to define aliases only. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Union + +from typing_extensions import TypeAlias + +from altair.utils.core import TYPECODE_MAP as _TYPE_CODE +from altair.utils.core import parse_shorthand as _parse +from altair.utils.schemapi import Optional, SchemaBase, Undefined +from altair.vegalite.v5.api import Parameter, SelectionPredicateComposition +from altair.vegalite.v5.schema._typing import ( + BinnedTimeUnit_T, + MultiTimeUnit_T, + SingleTimeUnit_T, + Type_T, +) +from altair.vegalite.v5.schema.core import ( + FieldEqualPredicate, + FieldGTEPredicate, + FieldGTPredicate, + FieldLTEPredicate, + FieldLTPredicate, + FieldOneOfPredicate, + FieldRangePredicate, + FieldValidPredicate, +) + +if TYPE_CHECKING: + from altair.utils.core import DataFrameLike + from altair.vegalite.v5.schema._typing import AggregateOp_T + from altair.vegalite.v5.schema.core import Predicate + +__all__ = ["agg", "field"] + +EncodeType: TypeAlias = Union[Type_T, Literal["O", "N", "Q", "T", "G"], None] +AnyTimeUnit: TypeAlias = Union[MultiTimeUnit_T, BinnedTimeUnit_T, SingleTimeUnit_T] +TimeUnitType: TypeAlias = Optional[Union[Dict[str, Any], SchemaBase, AnyTimeUnit]] +RangeType: TypeAlias = Union[ + Dict[str, Any], + Parameter, + SchemaBase, + Sequence[Union[Dict[str, Any], None, float, Parameter, SchemaBase]], +] +ValueType: TypeAlias = Union[str, bool, float, Dict[str, Any], Parameter, SchemaBase] + + +_ENCODINGS = frozenset( + ( + "ordinal", + "O", + "nominal", + "N", + "quantitative", + "Q", + "temporal", + "T", + "geojson", + "G", + None, + ) +) + + +def _parse_aggregate( + aggregate: AggregateOp_T, name: str | None, encode_type: EncodeType, / +) -> dict[str, Any]: + if encode_type in _ENCODINGS: + enc = f":{_TYPE_CODE.get(s, s)}" if (s := encode_type) else "" + return _parse(f"{aggregate}({name or ''}){enc}") + else: + msg = ( + f"Expected a short/long-form encoding type, but got {encode_type!r}.\n\n" + f"Try passing one of the following to `type`:\n" + f"{', '.join(sorted(f'{e!r}' for e in _ENCODINGS))}." + ) + raise TypeError(msg) + + +def _wrap_composition(predicate: Predicate, /) -> SelectionPredicateComposition: + return SelectionPredicateComposition(predicate.to_dict()) + + +class agg: + """Utility class providing autocomplete for shorthand. + + Functional alternative to shorthand mini-language. + """ + + def __new__( # type: ignore[misc] + cls, shorthand: dict[str, Any] | str, /, data: DataFrameLike | None = None + ) -> dict[str, Any]: + return _parse(shorthand=shorthand, data=data) + + @classmethod + def argmin( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("argmin", col_name, type) + + @classmethod + def argmax( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("argmax", col_name, type) + + @classmethod + def average( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("average", col_name, type) + + @classmethod + def count( + cls, col_name: str | None = None, /, type: EncodeType = "Q" + ) -> dict[str, Any]: + return _parse_aggregate("count", col_name, type) + + @classmethod + def distinct( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("distinct", col_name, type) + + @classmethod + def max( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("max", col_name, type) + + @classmethod + def mean( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("mean", col_name, type) + + @classmethod + def median( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("median", col_name, type) + + @classmethod + def min( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("min", col_name, type) + + @classmethod + def missing( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("missing", col_name, type) + + @classmethod + def product( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("product", col_name, type) + + @classmethod + def q1( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("q1", col_name, type) + + @classmethod + def q3( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("q3", col_name, type) + + @classmethod + def ci0( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("ci0", col_name, type) + + @classmethod + def ci1( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("ci1", col_name, type) + + @classmethod + def stderr( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("stderr", col_name, type) + + @classmethod + def stdev( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("stdev", col_name, type) + + @classmethod + def stdevp( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("stdevp", col_name, type) + + @classmethod + def sum( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("sum", col_name, type) + + @classmethod + def valid( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("valid", col_name, type) + + @classmethod + def values( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("values", col_name, type) + + @classmethod + def variance( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("variance", col_name, type) + + @classmethod + def variancep( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("variancep", col_name, type) + + @classmethod + def exponential( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("exponential", col_name, type) + + @classmethod + def exponentialb( + cls, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("exponentialb", col_name, type) + + +class field: + """Utility class for field predicates and shorthand parsing. + + Examples + -------- + >>> field("Origin") + {'field': 'Origin'} + + >>> field("Origin:N") + {'field': 'Origin', 'type': 'nominal'} + + >>> field.one_of("Origin", "Japan", "Europe") + SelectionPredicateComposition({'field': 'Origin', 'oneOf': ['Japan', 'Europe']}) + """ + + def __new__( # type: ignore[misc] + cls, shorthand: dict[str, Any] | str, /, data: DataFrameLike | None = None + ) -> dict[str, Any]: + return _parse(shorthand=shorthand, data=data) + + @classmethod + def one_of( + cls, + field: str, + /, + *values: bool | float | dict[str, Any] | SchemaBase, + timeUnit: TimeUnitType = Undefined, + ) -> SelectionPredicateComposition: + tp: type[Any] = type(values[0]) + if all(isinstance(v, tp) for v in values): + vals: Sequence[Any] = values + p = FieldOneOfPredicate(field=field, oneOf=vals, timeUnit=timeUnit) + return _wrap_composition(p) + else: + msg = ( + f"Expected all `values` to be of the same type, but got:\n" + f"{tuple(f"{type(v).__name__}" for v in values)!r}" + ) + raise TypeError(msg) + + @classmethod + def eq( + cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined + ) -> SelectionPredicateComposition: + p = FieldEqualPredicate(field=field, equal=value, timeUnit=timeUnit) + return _wrap_composition(p) + + @classmethod + def lt( + cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined + ) -> SelectionPredicateComposition: + p = FieldLTPredicate(field=field, lt=value, timeUnit=timeUnit) + return _wrap_composition(p) + + @classmethod + def lte( + cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined + ) -> SelectionPredicateComposition: + p = FieldLTEPredicate(field=field, lte=value, timeUnit=timeUnit) + return _wrap_composition(p) + + @classmethod + def gt( + cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined + ) -> SelectionPredicateComposition: + p = FieldGTPredicate(field=field, gt=value, timeUnit=timeUnit) + return _wrap_composition(p) + + @classmethod + def gte( + cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined + ) -> SelectionPredicateComposition: + p = FieldGTEPredicate(field=field, gte=value, timeUnit=timeUnit) + return _wrap_composition(p) + + @classmethod + def valid( + cls, field: str, value: bool, /, *, timeUnit: TimeUnitType = Undefined + ) -> SelectionPredicateComposition: + p = FieldValidPredicate(field=field, valid=value, timeUnit=timeUnit) + return _wrap_composition(p) + + @classmethod + def range( + cls, field: str, value: RangeType, /, *, timeUnit: TimeUnitType = Undefined + ) -> SelectionPredicateComposition: + p = FieldRangePredicate(field=field, range=value, timeUnit=timeUnit) + return _wrap_composition(p) From c418066476d5d1c4c2de120cde249557b09e61eb Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Jul 2024 09:47:06 +0100 Subject: [PATCH 02/19] test: add tests for `agg`, `field` --- tests/vegalite/v5/test__api_rfc.py | 116 +++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 tests/vegalite/v5/test__api_rfc.py diff --git a/tests/vegalite/v5/test__api_rfc.py b/tests/vegalite/v5/test__api_rfc.py new file mode 100644 index 000000000..703bb8b25 --- /dev/null +++ b/tests/vegalite/v5/test__api_rfc.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +# ruff: noqa: F401 +import re +import pytest + +import altair as alt + +from altair.vegalite.v5._api_rfc import agg, field, EncodeType +from altair.utils.core import TYPECODE_MAP, INV_TYPECODE_MAP + +if TYPE_CHECKING: + from altair.vegalite.v5.schema._typing import AggregateOp_T + + +def test_fail_shorthand() -> None: + with pytest.raises( + TypeError, match=re.compile(r"'bogus'.+Try.+'quantitative'", re.DOTALL) + ): + agg.count(type="bogus") # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "enc_type", + [ + "quantitative", + "ordinal", + "temporal", + "nominal", + "geojson", + "O", + "N", + "Q", + "T", + "G", + None, + ], +) +@pytest.mark.parametrize("col_name", ["column_1", None]) +@pytest.mark.parametrize( + "method_name", + [ + "argmax", + "argmin", + "average", + "count", + "distinct", + "max", + "mean", + "median", + "min", + "missing", + "product", + "q1", + "q3", + "ci0", + "ci1", + "stderr", + "stdev", + "stdevp", + "sum", + "valid", + "values", + "variance", + "variancep", + "exponential", + "exponentialb", + ], +) +def test_passing_shorthand( + method_name: AggregateOp_T, col_name: str | None, enc_type: EncodeType +): + actual = getattr(agg, method_name)(col_name, enc_type) + assert isinstance(actual, dict) + assert actual["aggregate"] == method_name + if col_name: + assert actual["field"] == col_name + if enc_type: + assert actual["type"] == INV_TYPECODE_MAP.get(enc_type, enc_type) + + +def test_fail_one_of() -> None: + with pytest.raises(TypeError, match=re.compile(r"Expected.+same type", re.DOTALL)): + field.one_of("field 1", 5, 6, 7, "nineteen", 8000.4) # type: ignore[arg-type] + + +def test_examples_field() -> None: + print(field("Origin")) + + +def test_compose_field(): + comp = field.eq("field 1", 10) + assert isinstance(comp, alt.SelectionPredicateComposition) + + +def test_compose_predicates(): + from vega_datasets import data + + cars_select = field.one_of("Origin", "Japan", "Europe") | field.range( + "Miles_per_Gallon", (25, 40) + ) + assert isinstance(cars_select, alt.SelectionPredicateComposition) + + source = data.cars() + chart = ( + alt.Chart(source) + .mark_point() + .encode( + x="Horsepower", + y="Miles_per_Gallon", + color=alt.condition(cars_select, alt.value("red"), alt.value("grey")), + ) + ) + chart.to_dict() From c92ede0052d2a80042466d29a85a09a3ce9089cf Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Jul 2024 09:48:30 +0100 Subject: [PATCH 03/19] style(ruff): Format docstrings --- altair/vegalite/v5/_api_rfc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index 236192ae8..4d83d386b 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -93,7 +93,8 @@ def _wrap_composition(predicate: Predicate, /) -> SelectionPredicateComposition: class agg: - """Utility class providing autocomplete for shorthand. + """ + Utility class providing autocomplete for shorthand. Functional alternative to shorthand mini-language. """ @@ -255,7 +256,8 @@ def exponentialb( class field: - """Utility class for field predicates and shorthand parsing. + """ + Utility class for field predicates and shorthand parsing. Examples -------- From 662a13773a9c8a0862f4439b132ee585dfae0840 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Jul 2024 11:31:33 +0100 Subject: [PATCH 04/19] fix(typing): Add missing `str` for `field.one_of` --- altair/vegalite/v5/_api_rfc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index 4d83d386b..03be208d2 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -281,7 +281,7 @@ def one_of( cls, field: str, /, - *values: bool | float | dict[str, Any] | SchemaBase, + *values: str | bool | float | dict[str, Any] | SchemaBase, timeUnit: TimeUnitType = Undefined, ) -> SelectionPredicateComposition: tp: type[Any] = type(values[0]) From 35360b7e659893a09ff9abe34c34e362881d8d1a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Jul 2024 11:35:16 +0100 Subject: [PATCH 05/19] refactor(typing): Add `OneOfType` alias --- altair/vegalite/v5/_api_rfc.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index 03be208d2..b7fe02079 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -54,6 +54,7 @@ Sequence[Union[Dict[str, Any], None, float, Parameter, SchemaBase]], ] ValueType: TypeAlias = Union[str, bool, float, Dict[str, Any], Parameter, SchemaBase] +OneOfType: TypeAlias = Union[str, bool, float, Dict[str, Any], SchemaBase] _ENCODINGS = frozenset( @@ -278,11 +279,7 @@ def __new__( # type: ignore[misc] @classmethod def one_of( - cls, - field: str, - /, - *values: str | bool | float | dict[str, Any] | SchemaBase, - timeUnit: TimeUnitType = Undefined, + cls, field: str, /, *values: OneOfType, timeUnit: TimeUnitType = Undefined ) -> SelectionPredicateComposition: tp: type[Any] = type(values[0]) if all(isinstance(v, tp) for v in values): From 7a4451ad68039a7d105b41ecd1aa709d3a6eb413 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Jul 2024 12:31:08 +0100 Subject: [PATCH 06/19] feat: Support non-variadic input to `field.one_of` https://github.com/vega/altair/issues/3239#issuecomment-2234057768 --- altair/vegalite/v5/_api_rfc.py | 55 ++++++++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index b7fe02079..3c62ad57a 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -12,7 +12,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, Literal, Mapping, Sequence, Union from typing_extensions import TypeAlias @@ -93,6 +93,38 @@ def _wrap_composition(predicate: Predicate, /) -> SelectionPredicateComposition: return SelectionPredicateComposition(predicate.to_dict()) +def _one_of_flatten( + values: tuple[OneOfType, ...] | tuple[Sequence[OneOfType]] | tuple[Any, ...], / +) -> Sequence[OneOfType]: + if ( + len(values) == 1 + and not isinstance(values[0], (str, bool, float, int, Mapping, SchemaBase)) + and isinstance(values[0], Sequence) + ): + return values[0] + elif len(values) > 1: + return values + else: + msg = ( + f"Expected `values` to be either a single `Sequence` " + f"or used variadically, but got: {values!r}." + ) + raise TypeError(msg) + + +def _one_of_variance(val_1: Any, *rest: OneOfType) -> Sequence[Any]: + # Required that all elements are the same type + tp = type(val_1) + if all(isinstance(v, tp) for v in rest): + return (val_1, *rest) + else: + msg = ( + f"Expected all `values` to be of the same type, but got:\n" + f"{tuple(f'{type(v).__name__}' for v in (val_1, *rest))!r}" + ) + raise TypeError(msg) + + class agg: """ Utility class providing autocomplete for shorthand. @@ -279,19 +311,16 @@ def __new__( # type: ignore[misc] @classmethod def one_of( - cls, field: str, /, *values: OneOfType, timeUnit: TimeUnitType = Undefined + cls, + field: str, + /, + *values: OneOfType | Sequence[OneOfType], + timeUnit: TimeUnitType = Undefined, ) -> SelectionPredicateComposition: - tp: type[Any] = type(values[0]) - if all(isinstance(v, tp) for v in values): - vals: Sequence[Any] = values - p = FieldOneOfPredicate(field=field, oneOf=vals, timeUnit=timeUnit) - return _wrap_composition(p) - else: - msg = ( - f"Expected all `values` to be of the same type, but got:\n" - f"{tuple(f"{type(v).__name__}" for v in values)!r}" - ) - raise TypeError(msg) + seq = _one_of_flatten(values) + one_of = _one_of_variance(*seq) + p = FieldOneOfPredicate(field=field, oneOf=one_of, timeUnit=timeUnit) + return _wrap_composition(p) @classmethod def eq( From 951f2a150db2b97c69b235c7b46b290adce54233 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Jul 2024 12:45:49 +0100 Subject: [PATCH 07/19] test: Update/ add tests for `agg`, `field` Mostly renaming, but also added `test_field_one_of_variadic` --- tests/vegalite/v5/test__api_rfc.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/vegalite/v5/test__api_rfc.py b/tests/vegalite/v5/test__api_rfc.py index 703bb8b25..8042fd98d 100644 --- a/tests/vegalite/v5/test__api_rfc.py +++ b/tests/vegalite/v5/test__api_rfc.py @@ -15,7 +15,7 @@ from altair.vegalite.v5.schema._typing import AggregateOp_T -def test_fail_shorthand() -> None: +def test_agg_type_invalid() -> None: with pytest.raises( TypeError, match=re.compile(r"'bogus'.+Try.+'quantitative'", re.DOTALL) ): @@ -69,7 +69,7 @@ def test_fail_shorthand() -> None: "exponentialb", ], ) -def test_passing_shorthand( +def test_agg_methods( method_name: AggregateOp_T, col_name: str | None, enc_type: EncodeType ): actual = getattr(agg, method_name)(col_name, enc_type) @@ -81,21 +81,22 @@ def test_passing_shorthand( assert actual["type"] == INV_TYPECODE_MAP.get(enc_type, enc_type) -def test_fail_one_of() -> None: +def test_field_one_of_covariant() -> None: with pytest.raises(TypeError, match=re.compile(r"Expected.+same type", re.DOTALL)): - field.one_of("field 1", 5, 6, 7, "nineteen", 8000.4) # type: ignore[arg-type] + field.one_of("field 1", 5, 6, 7, "nineteen", 8000.4) -def test_examples_field() -> None: - print(field("Origin")) +def test_field_one_of_variadic(): + args = "A", "B", "C", "D", "E" + assert field.one_of("field_1", *args) == field.one_of("field_1", args) -def test_compose_field(): +def test_field_wrap(): comp = field.eq("field 1", 10) assert isinstance(comp, alt.SelectionPredicateComposition) -def test_compose_predicates(): +def test_field_compose(): from vega_datasets import data cars_select = field.one_of("Origin", "Japan", "Europe") | field.range( From 766d1bf793ce7ddb55469ec765f2bafae74d4a05 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Jul 2024 15:56:05 +0100 Subject: [PATCH 08/19] chore: Add `agg`, `field` to `__all__` Purely for visual demonstration --- altair/__init__.py | 2 ++ altair/vegalite/__init__.py | 1 + 2 files changed, 3 insertions(+) diff --git a/altair/__init__.py b/altair/__init__.py index d6c03f48a..ed54e4606 100644 --- a/altair/__init__.py +++ b/altair/__init__.py @@ -591,6 +591,7 @@ "YOffsetDatum", "YOffsetValue", "YValue", + "agg", "api", "binding", "binding_checkbox", @@ -609,6 +610,7 @@ "default_data_transformer", "display", "expr", + "field", "graticule", "hconcat", "is_chart_type", diff --git a/altair/vegalite/__init__.py b/altair/vegalite/__init__.py index 8fa78644e..7833afcee 100644 --- a/altair/vegalite/__init__.py +++ b/altair/vegalite/__init__.py @@ -1,2 +1,3 @@ # ruff: noqa: F403 from .v5 import * +from .v5._api_rfc import agg as agg, field as field From 5929a5e550f8747a296c801fb1979fd59e1fa3db Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Jul 2024 17:08:24 +0100 Subject: [PATCH 09/19] fix: Support already parsed shorthand in `FieldChannelMixin` Ensures the same output from ```py alt.Y("count()") alt.Y(alt.agg.count()) ``` --- altair/vegalite/v5/schema/channels.py | 2 +- tools/generate_schema_wrapper.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/altair/vegalite/v5/schema/channels.py b/altair/vegalite/v5/schema/channels.py index 530af95ee..0d14bc3ac 100644 --- a/altair/vegalite/v5/schema/channels.py +++ b/altair/vegalite/v5/schema/channels.py @@ -169,7 +169,7 @@ def to_dict( if shorthand is Undefined: parsed = {} - elif isinstance(shorthand, str): + elif isinstance(shorthand, (str, dict)): parsed = parse_shorthand(shorthand, data=context.get("data", None)) type_required = "type" in self._kwds # type: ignore[attr-defined] type_in_shorthand = "type" in parsed diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 094cd66cc..27bddb525 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -99,7 +99,7 @@ def to_dict( if shorthand is Undefined: parsed = {} - elif isinstance(shorthand, str): + elif isinstance(shorthand, (str, dict)): parsed = parse_shorthand(shorthand, data=context.get("data", None)) type_required = "type" in self._kwds # type: ignore[attr-defined] type_in_shorthand = "type" in parsed From 208ac8beb771464e69cc46db7e353486530a62ad Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 30 Jul 2024 16:57:13 +0100 Subject: [PATCH 10/19] wip --- altair/vegalite/v5/_api_rfc.py | 266 +++++++++++++++++++++++++++++++++ 1 file changed, 266 insertions(+) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index 3c62ad57a..49a0a3826 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -41,6 +41,7 @@ from altair.utils.core import DataFrameLike from altair.vegalite.v5.schema._typing import AggregateOp_T from altair.vegalite.v5.schema.core import Predicate + from altair.vegalite.v5.schema import channels __all__ = ["agg", "field"] @@ -125,6 +126,138 @@ def _one_of_variance(val_1: Any, *rest: OneOfType) -> Sequence[Any]: raise TypeError(msg) +class _FieldMeta(type): + def __new__( # type: ignore[misc] + cls, shorthand: dict[str, Any] | str, /, data: DataFrameLike | None = None + ) -> dict[str, Any]: + return _parse(shorthand=shorthand, data=data) + + def argmin( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("argmin", col_name, type) + + def argmax( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("argmax", col_name, type) + + def average( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("average", col_name, type) + + def count( + self, col_name: str | None = None, /, type: EncodeType = "Q" + ) -> dict[str, Any]: + return _parse_aggregate("count", col_name, type) + + def distinct( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("distinct", col_name, type) + + def max( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("max", col_name, type) + + def mean( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("mean", col_name, type) + + def median( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("median", col_name, type) + + def min( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("min", col_name, type) + + def missing( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("missing", col_name, type) + + def product( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("product", col_name, type) + + def q1( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("q1", col_name, type) + + def q3( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("q3", col_name, type) + + def ci0( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("ci0", col_name, type) + + def ci1( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("ci1", col_name, type) + + def stderr( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("stderr", col_name, type) + + def stdev( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("stdev", col_name, type) + + def stdevp( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("stdevp", col_name, type) + + def sum( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("sum", col_name, type) + + def valid( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("valid", col_name, type) + + def values( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("values", col_name, type) + + def variance( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("variance", col_name, type) + + def variancep( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("variancep", col_name, type) + + def exponential( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("exponential", col_name, type) + + def exponentialb( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("exponentialb", col_name, type) + + class agg: """ Utility class providing autocomplete for shorthand. @@ -370,3 +503,136 @@ def range( ) -> SelectionPredicateComposition: p = FieldRangePredicate(field=field, range=value, timeUnit=timeUnit) return _wrap_composition(p) + + +class field2(metaclass=_FieldMeta): + def angle(self, *args: Any, **kwds: Any) -> channels.Angle: ... + def color(self, *args: Any, **kwds: Any) -> channels.Color: ... + def column(self, *args: Any, **kwds: Any) -> Any: ... + def description(self, *args: Any, **kwds: Any) -> Any: ... + def detail(self, *args: Any, **kwds: Any) -> Any: ... + def facet(self, *args: Any, **kwds: Any) -> Any: ... + def fill(self, *args: Any, **kwds: Any) -> Any: ... + def fill_opacity(self, *args: Any, **kwds: Any) -> Any: ... + def href(self, *args: Any, **kwds: Any) -> Any: ... + def key(self, *args: Any, **kwds: Any) -> Any: ... + def latitude(self, *args: Any, **kwds: Any) -> Any: ... + def latitude_2(self, *args: Any, **kwds: Any) -> Any: ... + def longitude(self, *args: Any, **kwds: Any) -> Any: ... + def longitude_2(self, *args: Any, **kwds: Any) -> Any: ... + def opacity(self, *args: Any, **kwds: Any) -> Any: ... + def order(self, *args: Any, **kwds: Any) -> Any: ... + def radius(self, *args: Any, **kwds: Any) -> Any: ... + def radius_2(self, *args: Any, **kwds: Any) -> Any: ... + def row(self, *args: Any, **kwds: Any) -> Any: ... + def shape(self, *args: Any, **kwds: Any) -> Any: ... + def size(self, *args: Any, **kwds: Any) -> Any: ... + def stroke(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_dash(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_opacity(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_width(self, *args: Any, **kwds: Any) -> Any: ... + def text(self, *args: Any, **kwds: Any) -> Any: ... + def theta(self, *args: Any, **kwds: Any) -> Any: ... + def theta_2(self, *args: Any, **kwds: Any) -> Any: ... + def tooltip(self, *args: Any, **kwds: Any) -> Any: ... + def url(self, *args: Any, **kwds: Any) -> Any: ... + def x(self, *args: Any, **kwds: Any) -> channels.X: ... + def x_2(self, *args: Any, **kwds: Any) -> channels.X2: ... + def x_error(self, *args: Any, **kwds: Any) -> channels.XError: ... + def x_error_2(self, *args: Any, **kwds: Any) -> channels.XError2: ... + def x_offset(self, *args: Any, **kwds: Any) -> channels.XOffset: ... + def y(self, *args: Any, **kwds: Any) -> channels.Y: ... + def y_2(self, *args: Any, **kwds: Any) -> channels.Y2: ... + def y_error(self, *args: Any, **kwds: Any) -> channels.YError: ... + def y_error_2(self, *args: Any, **kwds: Any) -> channels.YError2: ... + def y_offset(self, *args: Any, **kwds: Any) -> channels.YOffset: ... + @property + def value(self): + return _ChannelValueNamespace(self) + + @property + def datum(self): + return _ChannelDatumNamespace(self) + + +class _ChannelValueNamespace: + def __init__(self, arg: Any, /) -> None: + self._arg = arg + + def angle(self, *args: Any, **kwds: Any) -> Any: ... + def color(self, *args: Any, **kwds: Any) -> Any: ... + def description(self, *args: Any, **kwds: Any) -> Any: ... + def fill(self, *args: Any, **kwds: Any) -> Any: ... + def fill_opacity(self, *args: Any, **kwds: Any) -> Any: ... + def href(self, *args: Any, **kwds: Any) -> Any: ... + def latitude_2(self, *args: Any, **kwds: Any) -> Any: ... + def longitude_2(self, *args: Any, **kwds: Any) -> Any: ... + def opacity(self, *args: Any, **kwds: Any) -> Any: ... + def order(self, *args: Any, **kwds: Any) -> Any: ... + def radius(self, *args: Any, **kwds: Any) -> Any: ... + def radius_2(self, *args: Any, **kwds: Any) -> Any: ... + def shape(self, *args: Any, **kwds: Any) -> Any: ... + def size(self, *args: Any, **kwds: Any) -> Any: ... + def stroke(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_dash(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_opacity(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_width(self, *args: Any, **kwds: Any) -> Any: ... + def text(self, *args: Any, **kwds: Any) -> Any: ... + def theta(self, *args: Any, **kwds: Any) -> Any: ... + def theta_2(self, *args: Any, **kwds: Any) -> Any: ... + def tooltip(self, *args: Any, **kwds: Any) -> Any: ... + def url(self, *args: Any, **kwds: Any) -> Any: ... + def x(self, *args: Any, **kwds: Any) -> channels.XValue: ... + def x_2(self, *args: Any, **kwds: Any) -> channels.X2Value: ... + def x_error(self, *args: Any, **kwds: Any) -> channels.XErrorValue: ... + def x_error_2(self, *args: Any, **kwds: Any) -> channels.XError2Value: ... + def x_offset(self, *args: Any, **kwds: Any) -> channels.XOffsetValue: ... + def y(self, *args: Any, **kwds: Any) -> channels.YValue: ... + def y_2(self, *args: Any, **kwds: Any) -> channels.Y2Value: ... + def y_error(self, *args: Any, **kwds: Any) -> channels.YErrorValue: ... + def y_error_2(self, *args: Any, **kwds: Any) -> channels.YError2Value: ... + def y_offset(self, *args: Any, **kwds: Any) -> channels.YOffsetValue: ... + + +class _ChannelDatumNamespace: + def __init__(self, arg: Any, /) -> None: + self._arg = arg + + def angle(self, *args: Any, **kwds: Any) -> Any: ... + def color(self, *args: Any, **kwds: Any) -> Any: ... + def fill(self, *args: Any, **kwds: Any) -> Any: ... + def fill_opacity(self, *args: Any, **kwds: Any) -> Any: ... + def latitude(self, *args: Any, **kwds: Any) -> Any: ... + def latitude_2(self, *args: Any, **kwds: Any) -> Any: ... + def longitude(self, *args: Any, **kwds: Any) -> Any: ... + def longitude_2(self, *args: Any, **kwds: Any) -> Any: ... + def opacity(self, *args: Any, **kwds: Any) -> Any: ... + def radius(self, *args: Any, **kwds: Any) -> Any: ... + def radius_2(self, *args: Any, **kwds: Any) -> Any: ... + def shape(self, *args: Any, **kwds: Any) -> Any: ... + def size(self, *args: Any, **kwds: Any) -> Any: ... + def stroke(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_dash(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_opacity(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_width(self, *args: Any, **kwds: Any) -> Any: ... + def text(self, *args: Any, **kwds: Any) -> Any: ... + def theta(self, *args: Any, **kwds: Any) -> Any: ... + def theta_2(self, *args: Any, **kwds: Any) -> Any: ... + def x(self, *args: Any, **kwds: Any) -> channels.XDatum: ... + def x_2(self, *args: Any, **kwds: Any) -> channels.X2Datum: ... + def x_offset(self, *args: Any, **kwds: Any) -> channels.XOffsetDatum: ... + def y(self, *args: Any, **kwds: Any) -> channels.YDatum: ... + def y_2(self, *args: Any, **kwds: Any) -> channels.Y2Datum: ... + def y_offset(self, *args: Any, **kwds: Any) -> channels.YOffsetDatum: ... + + +abcd = field2.q1() +efg = field2.angle() +# some_field = field2.min("Cost", "Q") +some_field = field2("Cost:Q") +beeeee = some_field.x() +cee = some_field.value.x() +deee = some_field.datum.x() + + +ffff = field2("Cost:Q").x_2().bandPosition(4.7) From 07dbcc2ca90fd7c677e4150c5e4cfbed433b2e5b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 31 Jul 2024 19:02:28 +0100 Subject: [PATCH 11/19] fix: botched commit --- altair/vegalite/v5/_api_rfc.py | 252 +++++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index 3c62ad57a..90fc38777 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -10,6 +10,7 @@ The rest are to define aliases only. """ +# mypy: ignore-errors from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict, Literal, Mapping, Sequence, Union @@ -25,6 +26,7 @@ MultiTimeUnit_T, SingleTimeUnit_T, Type_T, + Map, ) from altair.vegalite.v5.schema.core import ( FieldEqualPredicate, @@ -36,12 +38,14 @@ FieldRangePredicate, FieldValidPredicate, ) +from altair.vegalite.v5.schema import channels if TYPE_CHECKING: from altair.utils.core import DataFrameLike from altair.vegalite.v5.schema._typing import AggregateOp_T from altair.vegalite.v5.schema.core import Predicate + __all__ = ["agg", "field"] EncodeType: TypeAlias = Union[Type_T, Literal["O", "N", "Q", "T", "G"], None] @@ -125,6 +129,138 @@ def _one_of_variance(val_1: Any, *rest: OneOfType) -> Sequence[Any]: raise TypeError(msg) +class _FieldMeta(type): + def __new__( # type: ignore[misc] + cls, shorthand: dict[str, Any] | str, /, data: DataFrameLike | None = None + ) -> dict[str, Any]: + return _parse(shorthand=shorthand, data=data) + + def argmin( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("argmin", col_name, type) + + def argmax( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("argmax", col_name, type) + + def average( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("average", col_name, type) + + def count( + self, col_name: str | None = None, /, type: EncodeType = "Q" + ) -> dict[str, Any]: + return _parse_aggregate("count", col_name, type) + + def distinct( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("distinct", col_name, type) + + def max( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("max", col_name, type) + + def mean( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("mean", col_name, type) + + def median( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("median", col_name, type) + + def min( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("min", col_name, type) + + def missing( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("missing", col_name, type) + + def product( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("product", col_name, type) + + def q1( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("q1", col_name, type) + + def q3( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("q3", col_name, type) + + def ci0( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("ci0", col_name, type) + + def ci1( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("ci1", col_name, type) + + def stderr( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("stderr", col_name, type) + + def stdev( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("stdev", col_name, type) + + def stdevp( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("stdevp", col_name, type) + + def sum( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("sum", col_name, type) + + def valid( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("valid", col_name, type) + + def values( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("values", col_name, type) + + def variance( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("variance", col_name, type) + + def variancep( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("variancep", col_name, type) + + def exponential( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("exponential", col_name, type) + + def exponentialb( + self, col_name: str | None = None, /, type: EncodeType = None + ) -> dict[str, Any]: + return _parse_aggregate("exponentialb", col_name, type) + + class agg: """ Utility class providing autocomplete for shorthand. @@ -370,3 +506,119 @@ def range( ) -> SelectionPredicateComposition: p = FieldRangePredicate(field=field, range=value, timeUnit=timeUnit) return _wrap_composition(p) + + +class field_into: + def __init__(self, arg: Map, /) -> None: + self._arg: Map = arg + + def angle(self, *args: Any, **kwds: Any) -> channels.Angle: ... + def color(self, *args: Any, **kwds: Any) -> channels.Color: ... + def column(self, *args: Any, **kwds: Any) -> Any: ... + def description(self, *args: Any, **kwds: Any) -> Any: ... + def detail(self, *args: Any, **kwds: Any) -> Any: ... + def facet(self, *args: Any, **kwds: Any) -> Any: ... + def fill(self, *args: Any, **kwds: Any) -> Any: ... + def fill_opacity(self, *args: Any, **kwds: Any) -> Any: ... + def href(self, *args: Any, **kwds: Any) -> Any: ... + def key(self, *args: Any, **kwds: Any) -> Any: ... + def latitude(self, *args: Any, **kwds: Any) -> Any: ... + def latitude2(self, *args: Any, **kwds: Any) -> Any: ... + def longitude(self, *args: Any, **kwds: Any) -> Any: ... + def longitude2(self, *args: Any, **kwds: Any) -> Any: ... + def opacity(self, *args: Any, **kwds: Any) -> Any: ... + def order(self, *args: Any, **kwds: Any) -> Any: ... + def radius(self, *args: Any, **kwds: Any) -> Any: ... + def radius2(self, *args: Any, **kwds: Any) -> Any: ... + def row(self, *args: Any, **kwds: Any) -> Any: ... + def shape(self, *args: Any, **kwds: Any) -> Any: ... + def size(self, *args: Any, **kwds: Any) -> Any: ... + def stroke(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_dash(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_opacity(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_width(self, *args: Any, **kwds: Any) -> Any: ... + def text(self, *args: Any, **kwds: Any) -> Any: ... + def theta(self, *args: Any, **kwds: Any) -> Any: ... + def theta2(self, *args: Any, **kwds: Any) -> Any: ... + def tooltip(self, *args: Any, **kwds: Any) -> Any: ... + def url(self, *args: Any, **kwds: Any) -> Any: ... + def x(self, *args: Any, **kwds: Any) -> channels.X: + return channels.X(*args, **self._arg, **kwds) + + def x2(self, *args: Any, **kwds: Any) -> channels.X2: + return channels.X2(*args, **self._arg, **kwds) + + def x_error(self, *args: Any, **kwds: Any) -> channels.XError: ... + def x_error2(self, *args: Any, **kwds: Any) -> channels.XError2: ... + def x_offset(self, *args: Any, **kwds: Any) -> channels.XOffset: ... + def y(self, *args: Any, **kwds: Any) -> channels.Y: + return channels.Y(*args, **self._arg, **kwds) + + def y2(self, *args: Any, **kwds: Any) -> channels.Y2: + return channels.Y2(*args, **self._arg, **kwds) + + def y_error(self, *args: Any, **kwds: Any) -> channels.YError: ... + def y_error2(self, *args: Any, **kwds: Any) -> channels.YError2: ... + def y_offset(self, *args: Any, **kwds: Any) -> channels.YOffset: ... + @property + def value(self) -> _ChannelValueNamespace: + return _ChannelValueNamespace(self._arg) + + +class _ChannelValueNamespace: + def __init__(self, arg: Map, /) -> None: + self._arg: Map = arg + + def angle(self, *args: Any, **kwds: Any) -> Any: ... + def color(self, *args: Any, **kwds: Any) -> Any: ... + def description(self, *args: Any, **kwds: Any) -> Any: ... + def fill(self, *args: Any, **kwds: Any) -> Any: ... + def fill_opacity(self, *args: Any, **kwds: Any) -> Any: ... + def href(self, *args: Any, **kwds: Any) -> Any: ... + def latitude2(self, *args: Any, **kwds: Any) -> Any: ... + def longitude2(self, *args: Any, **kwds: Any) -> Any: ... + def opacity(self, *args: Any, **kwds: Any) -> Any: ... + def order(self, *args: Any, **kwds: Any) -> Any: ... + def radius(self, *args: Any, **kwds: Any) -> Any: ... + def radius2(self, *args: Any, **kwds: Any) -> Any: ... + def shape(self, *args: Any, **kwds: Any) -> Any: ... + def size(self, *args: Any, **kwds: Any) -> Any: ... + def stroke(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_dash(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_opacity(self, *args: Any, **kwds: Any) -> Any: ... + def stroke_width(self, *args: Any, **kwds: Any) -> Any: ... + def text(self, *args: Any, **kwds: Any) -> Any: ... + def theta(self, *args: Any, **kwds: Any) -> Any: ... + def theta2(self, *args: Any, **kwds: Any) -> Any: ... + def tooltip(self, *args: Any, **kwds: Any) -> Any: ... + def url(self, *args: Any, **kwds: Any) -> Any: ... + def x(self, *args: Any, **kwds: Any) -> channels.XValue: + return channels.XValue(*args, **self._arg, **kwds) + + def x2(self, *args: Any, **kwds: Any) -> channels.X2Value: + return channels.X2Value(*args, **self._arg, **kwds) + + def x_error(self, *args: Any, **kwds: Any) -> channels.XErrorValue: ... + def x_error2(self, *args: Any, **kwds: Any) -> channels.XError2Value: ... + def x_offset(self, *args: Any, **kwds: Any) -> channels.XOffsetValue: ... + def y(self, *args: Any, **kwds: Any) -> channels.YValue: + return channels.YValue(*args, **self._arg, **kwds) + + def y2(self, *args: Any, **kwds: Any) -> channels.Y2Value: + return channels.Y2Value(*args, **self._arg, **kwds) + + def y_error(self, *args: Any, **kwds: Any) -> channels.YErrorValue: ... + def y_error2(self, *args: Any, **kwds: Any) -> channels.YError2Value: ... + def y_offset(self, *args: Any, **kwds: Any) -> channels.YOffsetValue: ... + + +# field_out = agg.q1() +# wrapped = field_into(field_out).value.x() +## efg = field2.angle() +## some_field = field2.min("Cost", "Q") +# some_field = field_into(field("Cost:Q")) +# beeeee = some_field.x() +# cee = some_field.value.x() +# +# +# ffff = field_into(field("Cost:Q")).x2().bandPosition(4.7) From 83c3cfe6c925f3cf6e4e5374c11da90e472e18e7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 31 Jul 2024 19:33:45 +0100 Subject: [PATCH 12/19] revert: Remove more of unfinished commit --- altair/vegalite/v5/_api_rfc.py | 135 +-------------------------------- 1 file changed, 3 insertions(+), 132 deletions(-) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index 90fc38777..fee4c8c8e 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -129,138 +129,6 @@ def _one_of_variance(val_1: Any, *rest: OneOfType) -> Sequence[Any]: raise TypeError(msg) -class _FieldMeta(type): - def __new__( # type: ignore[misc] - cls, shorthand: dict[str, Any] | str, /, data: DataFrameLike | None = None - ) -> dict[str, Any]: - return _parse(shorthand=shorthand, data=data) - - def argmin( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("argmin", col_name, type) - - def argmax( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("argmax", col_name, type) - - def average( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("average", col_name, type) - - def count( - self, col_name: str | None = None, /, type: EncodeType = "Q" - ) -> dict[str, Any]: - return _parse_aggregate("count", col_name, type) - - def distinct( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("distinct", col_name, type) - - def max( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("max", col_name, type) - - def mean( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("mean", col_name, type) - - def median( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("median", col_name, type) - - def min( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("min", col_name, type) - - def missing( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("missing", col_name, type) - - def product( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("product", col_name, type) - - def q1( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("q1", col_name, type) - - def q3( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("q3", col_name, type) - - def ci0( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("ci0", col_name, type) - - def ci1( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("ci1", col_name, type) - - def stderr( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("stderr", col_name, type) - - def stdev( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("stdev", col_name, type) - - def stdevp( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("stdevp", col_name, type) - - def sum( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("sum", col_name, type) - - def valid( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("valid", col_name, type) - - def values( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("values", col_name, type) - - def variance( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("variance", col_name, type) - - def variancep( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("variancep", col_name, type) - - def exponential( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("exponential", col_name, type) - - def exponentialb( - self, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: - return _parse_aggregate("exponentialb", col_name, type) - - class agg: """ Utility class providing autocomplete for shorthand. @@ -508,6 +376,9 @@ def range( return _wrap_composition(p) +# NOTE: Ignore everything below + + class field_into: def __init__(self, arg: Map, /) -> None: self._arg: Map = arg From 23405ba112a36735889811d0a3ab7eb08df8ebd5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 1 Aug 2024 14:54:15 +0100 Subject: [PATCH 13/19] revert: Remove _ChannelValueNamespace --- altair/vegalite/v5/_api_rfc.py | 50 ---------------------------------- 1 file changed, 50 deletions(-) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index fee4c8c8e..5d4b5f438 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -431,56 +431,6 @@ def y2(self, *args: Any, **kwds: Any) -> channels.Y2: def y_error(self, *args: Any, **kwds: Any) -> channels.YError: ... def y_error2(self, *args: Any, **kwds: Any) -> channels.YError2: ... def y_offset(self, *args: Any, **kwds: Any) -> channels.YOffset: ... - @property - def value(self) -> _ChannelValueNamespace: - return _ChannelValueNamespace(self._arg) - - -class _ChannelValueNamespace: - def __init__(self, arg: Map, /) -> None: - self._arg: Map = arg - - def angle(self, *args: Any, **kwds: Any) -> Any: ... - def color(self, *args: Any, **kwds: Any) -> Any: ... - def description(self, *args: Any, **kwds: Any) -> Any: ... - def fill(self, *args: Any, **kwds: Any) -> Any: ... - def fill_opacity(self, *args: Any, **kwds: Any) -> Any: ... - def href(self, *args: Any, **kwds: Any) -> Any: ... - def latitude2(self, *args: Any, **kwds: Any) -> Any: ... - def longitude2(self, *args: Any, **kwds: Any) -> Any: ... - def opacity(self, *args: Any, **kwds: Any) -> Any: ... - def order(self, *args: Any, **kwds: Any) -> Any: ... - def radius(self, *args: Any, **kwds: Any) -> Any: ... - def radius2(self, *args: Any, **kwds: Any) -> Any: ... - def shape(self, *args: Any, **kwds: Any) -> Any: ... - def size(self, *args: Any, **kwds: Any) -> Any: ... - def stroke(self, *args: Any, **kwds: Any) -> Any: ... - def stroke_dash(self, *args: Any, **kwds: Any) -> Any: ... - def stroke_opacity(self, *args: Any, **kwds: Any) -> Any: ... - def stroke_width(self, *args: Any, **kwds: Any) -> Any: ... - def text(self, *args: Any, **kwds: Any) -> Any: ... - def theta(self, *args: Any, **kwds: Any) -> Any: ... - def theta2(self, *args: Any, **kwds: Any) -> Any: ... - def tooltip(self, *args: Any, **kwds: Any) -> Any: ... - def url(self, *args: Any, **kwds: Any) -> Any: ... - def x(self, *args: Any, **kwds: Any) -> channels.XValue: - return channels.XValue(*args, **self._arg, **kwds) - - def x2(self, *args: Any, **kwds: Any) -> channels.X2Value: - return channels.X2Value(*args, **self._arg, **kwds) - - def x_error(self, *args: Any, **kwds: Any) -> channels.XErrorValue: ... - def x_error2(self, *args: Any, **kwds: Any) -> channels.XError2Value: ... - def x_offset(self, *args: Any, **kwds: Any) -> channels.XOffsetValue: ... - def y(self, *args: Any, **kwds: Any) -> channels.YValue: - return channels.YValue(*args, **self._arg, **kwds) - - def y2(self, *args: Any, **kwds: Any) -> channels.Y2Value: - return channels.Y2Value(*args, **self._arg, **kwds) - - def y_error(self, *args: Any, **kwds: Any) -> channels.YErrorValue: ... - def y_error2(self, *args: Any, **kwds: Any) -> channels.YError2Value: ... - def y_offset(self, *args: Any, **kwds: Any) -> channels.YOffsetValue: ... # field_out = agg.q1() From d4cf757fe75cc3c59612cbf0438b6a90331ccdc2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 1 Aug 2024 14:55:56 +0100 Subject: [PATCH 14/19] refactor(typing): Annotate `field_into` return types --- altair/vegalite/v5/_api_rfc.py | 56 +++++++++++++++++----------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index 5d4b5f438..cad6af094 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -385,34 +385,34 @@ def __init__(self, arg: Map, /) -> None: def angle(self, *args: Any, **kwds: Any) -> channels.Angle: ... def color(self, *args: Any, **kwds: Any) -> channels.Color: ... - def column(self, *args: Any, **kwds: Any) -> Any: ... - def description(self, *args: Any, **kwds: Any) -> Any: ... - def detail(self, *args: Any, **kwds: Any) -> Any: ... - def facet(self, *args: Any, **kwds: Any) -> Any: ... - def fill(self, *args: Any, **kwds: Any) -> Any: ... - def fill_opacity(self, *args: Any, **kwds: Any) -> Any: ... - def href(self, *args: Any, **kwds: Any) -> Any: ... - def key(self, *args: Any, **kwds: Any) -> Any: ... - def latitude(self, *args: Any, **kwds: Any) -> Any: ... - def latitude2(self, *args: Any, **kwds: Any) -> Any: ... - def longitude(self, *args: Any, **kwds: Any) -> Any: ... - def longitude2(self, *args: Any, **kwds: Any) -> Any: ... - def opacity(self, *args: Any, **kwds: Any) -> Any: ... - def order(self, *args: Any, **kwds: Any) -> Any: ... - def radius(self, *args: Any, **kwds: Any) -> Any: ... - def radius2(self, *args: Any, **kwds: Any) -> Any: ... - def row(self, *args: Any, **kwds: Any) -> Any: ... - def shape(self, *args: Any, **kwds: Any) -> Any: ... - def size(self, *args: Any, **kwds: Any) -> Any: ... - def stroke(self, *args: Any, **kwds: Any) -> Any: ... - def stroke_dash(self, *args: Any, **kwds: Any) -> Any: ... - def stroke_opacity(self, *args: Any, **kwds: Any) -> Any: ... - def stroke_width(self, *args: Any, **kwds: Any) -> Any: ... - def text(self, *args: Any, **kwds: Any) -> Any: ... - def theta(self, *args: Any, **kwds: Any) -> Any: ... - def theta2(self, *args: Any, **kwds: Any) -> Any: ... - def tooltip(self, *args: Any, **kwds: Any) -> Any: ... - def url(self, *args: Any, **kwds: Any) -> Any: ... + def column(self, *args: Any, **kwds: Any) -> channels.Column: ... + def description(self, *args: Any, **kwds: Any) -> channels.Description: ... + def detail(self, *args: Any, **kwds: Any) -> channels.Detail: ... + def facet(self, *args: Any, **kwds: Any) -> channels.Facet: ... + def fill(self, *args: Any, **kwds: Any) -> channels.Fill: ... + def fill_opacity(self, *args: Any, **kwds: Any) -> channels.FillOpacity: ... + def href(self, *args: Any, **kwds: Any) -> channels.Href: ... + def key(self, *args: Any, **kwds: Any) -> channels.Key: ... + def latitude(self, *args: Any, **kwds: Any) -> channels.Latitude: ... + def latitude2(self, *args: Any, **kwds: Any) -> channels.Latitude2: ... + def longitude(self, *args: Any, **kwds: Any) -> channels.Longitude: ... + def longitude2(self, *args: Any, **kwds: Any) -> channels.Longitude2: ... + def opacity(self, *args: Any, **kwds: Any) -> channels.Opacity: ... + def order(self, *args: Any, **kwds: Any) -> channels.Order: ... + def radius(self, *args: Any, **kwds: Any) -> channels.Radius: ... + def radius2(self, *args: Any, **kwds: Any) -> channels.Radius2: ... + def row(self, *args: Any, **kwds: Any) -> channels.Row: ... + def shape(self, *args: Any, **kwds: Any) -> channels.Shape: ... + def size(self, *args: Any, **kwds: Any) -> channels.Size: ... + def stroke(self, *args: Any, **kwds: Any) -> channels.Stroke: ... + def stroke_dash(self, *args: Any, **kwds: Any) -> channels.StrokeDash: ... + def stroke_opacity(self, *args: Any, **kwds: Any) -> channels.StrokeOpacity: ... + def stroke_width(self, *args: Any, **kwds: Any) -> channels.StrokeWidth: ... + def text(self, *args: Any, **kwds: Any) -> channels.Text: ... + def theta(self, *args: Any, **kwds: Any) -> channels.Theta: ... + def theta2(self, *args: Any, **kwds: Any) -> channels.Theta2: ... + def tooltip(self, *args: Any, **kwds: Any) -> channels.Tooltip: ... + def url(self, *args: Any, **kwds: Any) -> channels.Url: ... def x(self, *args: Any, **kwds: Any) -> channels.X: return channels.X(*args, **self._arg, **kwds) From e449796b22b72b39db8696cc41be9b8ce07130d0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 1 Aug 2024 19:07:53 +0100 Subject: [PATCH 15/19] chore: tidy up wip `field_into` --- altair/vegalite/v5/_api_rfc.py | 39 ++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index cad6af094..50b4aa404 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -376,10 +376,32 @@ def range( return _wrap_composition(p) -# NOTE: Ignore everything below +# NOTE: Ignore everything below # +# ----------------------------- # class field_into: + """ + Return wrapper for `agg`, `field` shorthand dicts. + + Idea + ---- + Rather than something like:: + + op_1 = alt.X(alt.agg.min("Cost", "Q")).scale(None) + + You could chain entirely from the agg:: + + # the internal unpacking will avoid double-checking the shorthand + op_2 = alt.agg.min("Cost", "Q").x().scale(None) + + Optionally, use the chained constructor:: + + op_2_1 = alt.agg.min("Cost", "Q").x(scale=None) + + + """ + def __init__(self, arg: Map, /) -> None: self._arg: Map = arg @@ -433,13 +455,8 @@ def y_error2(self, *args: Any, **kwds: Any) -> channels.YError2: ... def y_offset(self, *args: Any, **kwds: Any) -> channels.YOffset: ... -# field_out = agg.q1() -# wrapped = field_into(field_out).value.x() -## efg = field2.angle() -## some_field = field2.min("Cost", "Q") -# some_field = field_into(field("Cost:Q")) -# beeeee = some_field.x() -# cee = some_field.value.x() -# -# -# ffff = field_into(field("Cost:Q")).x2().bandPosition(4.7) +def example_field(): + field_out = agg.q1() + wrapped = field_into(field_out).x() # noqa: F841 + some_field = field_into(agg.min("Cost", "Q")) + beeeee = some_field.x().scale(None).impute(None).axis(None) # noqa: F841 From a8e1ea14475b6a787f79467942451edc633ff9e0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 4 Nov 2024 14:14:45 +0000 Subject: [PATCH 16/19] refactor(ruff): Lint for `3.9` --- altair/vegalite/v5/_api_rfc.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index 62d73584b..ea7173016 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -13,7 +13,8 @@ # mypy: ignore-errors from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Literal, Mapping, Sequence, Union +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, Literal, Union from typing_extensions import TypeAlias from altair.utils.core import TYPECODE_MAP as _TYPE_CODE @@ -49,15 +50,15 @@ EncodeType: TypeAlias = Union[Type_T, Literal["O", "N", "Q", "T", "G"], None] AnyTimeUnit: TypeAlias = Union[MultiTimeUnit_T, BinnedTimeUnit_T, SingleTimeUnit_T] -TimeUnitType: TypeAlias = Optional[Union[Dict[str, Any], SchemaBase, AnyTimeUnit]] +TimeUnitType: TypeAlias = Optional[Union[dict[str, Any], SchemaBase, AnyTimeUnit]] RangeType: TypeAlias = Union[ - Dict[str, Any], + dict[str, Any], Parameter, SchemaBase, - Sequence[Union[Dict[str, Any], None, float, Parameter, SchemaBase]], + Sequence[Union[dict[str, Any], None, float, Parameter, SchemaBase]], ] -ValueType: TypeAlias = Union[str, bool, float, Dict[str, Any], Parameter, SchemaBase] -OneOfType: TypeAlias = Union[str, bool, float, Dict[str, Any], SchemaBase] +ValueType: TypeAlias = Union[str, bool, float, dict[str, Any], Parameter, SchemaBase] +OneOfType: TypeAlias = Union[str, bool, float, dict[str, Any], SchemaBase] _ENCODINGS = frozenset( From f890feb4ad3dfe3df6d3d8fe395b1ca99b4989fd Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 4 Nov 2024 14:31:38 +0000 Subject: [PATCH 17/19] refactor: Factor out `SelectionPredicateComposition` Utilizes #3668 --- altair/vegalite/v5/_api_rfc.py | 51 ++++++++++++++-------------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index ea7173016..81a7b46a5 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -20,7 +20,7 @@ from altair.utils.core import TYPECODE_MAP as _TYPE_CODE from altair.utils.core import parse_shorthand as _parse from altair.utils.schemapi import Optional, SchemaBase, Undefined -from altair.vegalite.v5.api import Parameter, SelectionPredicateComposition +from altair.vegalite.v5.api import Parameter from altair.vegalite.v5.schema import channels from altair.vegalite.v5.schema._typing import ( BinnedTimeUnit_T, @@ -93,10 +93,6 @@ def _parse_aggregate( raise TypeError(msg) -def _wrap_composition(predicate: Predicate, /) -> SelectionPredicateComposition: - return SelectionPredicateComposition(predicate.to_dict()) - - def _one_of_flatten( values: tuple[OneOfType, ...] | tuple[Sequence[OneOfType]] | tuple[Any, ...], / ) -> Sequence[OneOfType]: @@ -305,7 +301,10 @@ class field: {'field': 'Origin', 'type': 'nominal'} >>> field.one_of("Origin", "Japan", "Europe") - SelectionPredicateComposition({'field': 'Origin', 'oneOf': ['Japan', 'Europe']}) + FieldOneOfPredicate({ + field: 'Origin', + oneOf: ('Japan', 'Europe') + }) """ def __new__( # type: ignore[misc] @@ -320,60 +319,52 @@ def one_of( /, *values: OneOfType | Sequence[OneOfType], timeUnit: TimeUnitType = Undefined, - ) -> SelectionPredicateComposition: + ) -> Predicate: seq = _one_of_flatten(values) one_of = _one_of_variance(*seq) - p = FieldOneOfPredicate(field=field, oneOf=one_of, timeUnit=timeUnit) - return _wrap_composition(p) + return FieldOneOfPredicate(field=field, oneOf=one_of, timeUnit=timeUnit) @classmethod def eq( cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined - ) -> SelectionPredicateComposition: - p = FieldEqualPredicate(field=field, equal=value, timeUnit=timeUnit) - return _wrap_composition(p) + ) -> Predicate: + return FieldEqualPredicate(field=field, equal=value, timeUnit=timeUnit) @classmethod def lt( cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined - ) -> SelectionPredicateComposition: - p = FieldLTPredicate(field=field, lt=value, timeUnit=timeUnit) - return _wrap_composition(p) + ) -> Predicate: + return FieldLTPredicate(field=field, lt=value, timeUnit=timeUnit) @classmethod def lte( cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined - ) -> SelectionPredicateComposition: - p = FieldLTEPredicate(field=field, lte=value, timeUnit=timeUnit) - return _wrap_composition(p) + ) -> Predicate: + return FieldLTEPredicate(field=field, lte=value, timeUnit=timeUnit) @classmethod def gt( cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined - ) -> SelectionPredicateComposition: - p = FieldGTPredicate(field=field, gt=value, timeUnit=timeUnit) - return _wrap_composition(p) + ) -> Predicate: + return FieldGTPredicate(field=field, gt=value, timeUnit=timeUnit) @classmethod def gte( cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined - ) -> SelectionPredicateComposition: - p = FieldGTEPredicate(field=field, gte=value, timeUnit=timeUnit) - return _wrap_composition(p) + ) -> Predicate: + return FieldGTEPredicate(field=field, gte=value, timeUnit=timeUnit) @classmethod def valid( cls, field: str, value: bool, /, *, timeUnit: TimeUnitType = Undefined - ) -> SelectionPredicateComposition: - p = FieldValidPredicate(field=field, valid=value, timeUnit=timeUnit) - return _wrap_composition(p) + ) -> Predicate: + return FieldValidPredicate(field=field, valid=value, timeUnit=timeUnit) @classmethod def range( cls, field: str, value: RangeType, /, *, timeUnit: TimeUnitType = Undefined - ) -> SelectionPredicateComposition: - p = FieldRangePredicate(field=field, range=value, timeUnit=timeUnit) - return _wrap_composition(p) + ) -> Predicate: + return FieldRangePredicate(field=field, range=value, timeUnit=timeUnit) # NOTE: Ignore everything below # From 106d76f08640860b7af7fe3933f989cf334285f4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 4 Nov 2024 15:07:48 +0000 Subject: [PATCH 18/19] refactor(typing): Update & rename aliases - When I originally wrote this, not all of these were available - Overall, trying to reduce verbosity --- altair/vegalite/v5/_api_rfc.py | 162 +++++++++++++---------------- tests/vegalite/v5/test__api_rfc.py | 6 +- 2 files changed, 77 insertions(+), 91 deletions(-) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index 81a7b46a5..ea4ac1277 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -13,22 +13,14 @@ # mypy: ignore-errors from __future__ import annotations +import datetime as dt from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, Union -from typing_extensions import TypeAlias +from typing import TYPE_CHECKING from altair.utils.core import TYPECODE_MAP as _TYPE_CODE from altair.utils.core import parse_shorthand as _parse -from altair.utils.schemapi import Optional, SchemaBase, Undefined -from altair.vegalite.v5.api import Parameter +from altair.utils.schemapi import SchemaBase, Undefined from altair.vegalite.v5.schema import channels -from altair.vegalite.v5.schema._typing import ( - BinnedTimeUnit_T, - Map, - MultiTimeUnit_T, - SingleTimeUnit_T, - Type_T, -) from altair.vegalite.v5.schema.core import ( FieldEqualPredicate, FieldGTEPredicate, @@ -41,24 +33,37 @@ ) if TYPE_CHECKING: - from altair.utils.core import DataFrameLike - from altair.vegalite.v5.schema._typing import AggregateOp_T - from altair.vegalite.v5.schema.core import Predicate + import sys + from typing import Any, Literal + + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + from narwhals.typing import IntoDataFrame + + from altair.utils.schemapi import Optional + from altair.vegalite.v5.api import Parameter, _FieldEqualType + from altair.vegalite.v5.schema._typing import ( + AggregateOp_T, + BinnedTimeUnit_T, + Map, + MultiTimeUnit_T, + OneOrSeq, + SingleTimeUnit_T, + Temporal, + Type_T, + ) + from altair.vegalite.v5.schema.core import Predicate, TimeUnitParams __all__ = ["agg", "field"] -EncodeType: TypeAlias = Union[Type_T, Literal["O", "N", "Q", "T", "G"], None] -AnyTimeUnit: TypeAlias = Union[MultiTimeUnit_T, BinnedTimeUnit_T, SingleTimeUnit_T] -TimeUnitType: TypeAlias = Optional[Union[dict[str, Any], SchemaBase, AnyTimeUnit]] -RangeType: TypeAlias = Union[ - dict[str, Any], - Parameter, - SchemaBase, - Sequence[Union[dict[str, Any], None, float, Parameter, SchemaBase]], -] -ValueType: TypeAlias = Union[str, bool, float, dict[str, Any], Parameter, SchemaBase] -OneOfType: TypeAlias = Union[str, bool, float, dict[str, Any], SchemaBase] +_Type: TypeAlias = 'Type_T | Literal["O", "N", "Q", "T", "G"] | None' +_TimeUnit: TypeAlias = "Optional[TimeUnitParams | SingleTimeUnit_T | MultiTimeUnit_T | BinnedTimeUnit_T | Map]" +_Range: TypeAlias = "Parameter | SchemaBase | Sequence[float | Temporal | Parameter | SchemaBase | Map | None] | Map" +_Value: TypeAlias = "str | float | Temporal | Parameter | SchemaBase | Map" +_OneOf: TypeAlias = "str | bool | float | Temporal | SchemaBase | Map" _ENCODINGS = frozenset( @@ -79,7 +84,7 @@ def _parse_aggregate( - aggregate: AggregateOp_T, name: str | None, encode_type: EncodeType, / + aggregate: AggregateOp_T, name: str | None, encode_type: _Type, / ) -> dict[str, Any]: if encode_type in _ENCODINGS: enc = f":{_TYPE_CODE.get(s, s)}" if (s := encode_type) else "" @@ -94,11 +99,13 @@ def _parse_aggregate( def _one_of_flatten( - values: tuple[OneOfType, ...] | tuple[Sequence[OneOfType]] | tuple[Any, ...], / -) -> Sequence[OneOfType]: + values: tuple[_OneOf, ...] | tuple[Sequence[_OneOf]] | tuple[Any, ...], / +) -> Sequence[_OneOf]: if ( len(values) == 1 - and not isinstance(values[0], (str, bool, float, int, Mapping, SchemaBase)) + and not isinstance( + values[0], (str, bool, float, int, dt.date, Mapping, SchemaBase) + ) and isinstance(values[0], Sequence) ): return values[0] @@ -112,7 +119,7 @@ def _one_of_flatten( raise TypeError(msg) -def _one_of_variance(val_1: Any, *rest: OneOfType) -> Sequence[Any]: +def _one_of_variance(val_1: Any, *rest: _OneOf) -> Sequence[Any]: # Required that all elements are the same type tp = type(val_1) if all(isinstance(v, tp) for v in rest): @@ -133,157 +140,139 @@ class agg: """ def __new__( # type: ignore[misc] - cls, shorthand: dict[str, Any] | str, /, data: DataFrameLike | None = None + cls, shorthand: dict[str, Any] | str, /, data: IntoDataFrame | None = None ) -> dict[str, Any]: return _parse(shorthand=shorthand, data=data) @classmethod def argmin( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("argmin", col_name, type) @classmethod def argmax( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("argmax", col_name, type) @classmethod def average( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("average", col_name, type) @classmethod - def count( - cls, col_name: str | None = None, /, type: EncodeType = "Q" - ) -> dict[str, Any]: + def count(cls, col_name: str | None = None, /, type: _Type = "Q") -> dict[str, Any]: return _parse_aggregate("count", col_name, type) @classmethod def distinct( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("distinct", col_name, type) @classmethod - def max( - cls, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: + def max(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: return _parse_aggregate("max", col_name, type) @classmethod - def mean( - cls, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: + def mean(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: return _parse_aggregate("mean", col_name, type) @classmethod def median( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("median", col_name, type) @classmethod - def min( - cls, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: + def min(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: return _parse_aggregate("min", col_name, type) @classmethod def missing( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("missing", col_name, type) @classmethod def product( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("product", col_name, type) @classmethod - def q1( - cls, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: + def q1(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: return _parse_aggregate("q1", col_name, type) @classmethod - def q3( - cls, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: + def q3(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: return _parse_aggregate("q3", col_name, type) @classmethod - def ci0( - cls, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: + def ci0(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: return _parse_aggregate("ci0", col_name, type) @classmethod - def ci1( - cls, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: + def ci1(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: return _parse_aggregate("ci1", col_name, type) @classmethod def stderr( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("stderr", col_name, type) @classmethod def stdev( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("stdev", col_name, type) @classmethod def stdevp( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("stdevp", col_name, type) @classmethod - def sum( - cls, col_name: str | None = None, /, type: EncodeType = None - ) -> dict[str, Any]: + def sum(cls, col_name: str | None = None, /, type: _Type = None) -> dict[str, Any]: return _parse_aggregate("sum", col_name, type) @classmethod def valid( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("valid", col_name, type) @classmethod def values( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("values", col_name, type) @classmethod def variance( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("variance", col_name, type) @classmethod def variancep( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("variancep", col_name, type) @classmethod def exponential( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("exponential", col_name, type) @classmethod def exponentialb( - cls, col_name: str | None = None, /, type: EncodeType = None + cls, col_name: str | None = None, /, type: _Type = None ) -> dict[str, Any]: return _parse_aggregate("exponentialb", col_name, type) @@ -308,17 +297,13 @@ class field: """ def __new__( # type: ignore[misc] - cls, shorthand: dict[str, Any] | str, /, data: DataFrameLike | None = None + cls, shorthand: dict[str, Any] | str, /, data: IntoDataFrame | None = None ) -> dict[str, Any]: return _parse(shorthand=shorthand, data=data) @classmethod def one_of( - cls, - field: str, - /, - *values: OneOfType | Sequence[OneOfType], - timeUnit: TimeUnitType = Undefined, + cls, field: str, /, *values: OneOrSeq[_OneOf], timeUnit: _TimeUnit = Undefined ) -> Predicate: seq = _one_of_flatten(values) one_of = _one_of_variance(*seq) @@ -326,43 +311,46 @@ def one_of( @classmethod def eq( - cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined + cls, field: str, value: _FieldEqualType, /, *, timeUnit: _TimeUnit = Undefined ) -> Predicate: + if value is None: + # NOTE: Unclear why this is allowed for `datum` but not in `FieldEqualPredicate` + raise TypeError(value) return FieldEqualPredicate(field=field, equal=value, timeUnit=timeUnit) @classmethod def lt( - cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined + cls, field: str, value: _Value, /, *, timeUnit: _TimeUnit = Undefined ) -> Predicate: return FieldLTPredicate(field=field, lt=value, timeUnit=timeUnit) @classmethod def lte( - cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined + cls, field: str, value: _Value, /, *, timeUnit: _TimeUnit = Undefined ) -> Predicate: return FieldLTEPredicate(field=field, lte=value, timeUnit=timeUnit) @classmethod def gt( - cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined + cls, field: str, value: _Value, /, *, timeUnit: _TimeUnit = Undefined ) -> Predicate: return FieldGTPredicate(field=field, gt=value, timeUnit=timeUnit) @classmethod def gte( - cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined + cls, field: str, value: _Value, /, *, timeUnit: _TimeUnit = Undefined ) -> Predicate: return FieldGTEPredicate(field=field, gte=value, timeUnit=timeUnit) @classmethod def valid( - cls, field: str, value: bool, /, *, timeUnit: TimeUnitType = Undefined + cls, field: str, value: bool, /, *, timeUnit: _TimeUnit = Undefined ) -> Predicate: return FieldValidPredicate(field=field, valid=value, timeUnit=timeUnit) @classmethod def range( - cls, field: str, value: RangeType, /, *, timeUnit: TimeUnitType = Undefined + cls, field: str, value: _Range, /, *, timeUnit: _TimeUnit = Undefined ) -> Predicate: return FieldRangePredicate(field=field, range=value, timeUnit=timeUnit) diff --git a/tests/vegalite/v5/test__api_rfc.py b/tests/vegalite/v5/test__api_rfc.py index aa2fe6297..de4f7f9af 100644 --- a/tests/vegalite/v5/test__api_rfc.py +++ b/tests/vegalite/v5/test__api_rfc.py @@ -8,7 +8,7 @@ import altair as alt from altair.utils.core import INV_TYPECODE_MAP, TYPECODE_MAP -from altair.vegalite.v5._api_rfc import EncodeType, agg, field +from altair.vegalite.v5._api_rfc import _Type, agg, field if TYPE_CHECKING: from altair.vegalite.v5.schema._typing import AggregateOp_T @@ -68,9 +68,7 @@ def test_agg_type_invalid() -> None: "exponentialb", ], ) -def test_agg_methods( - method_name: AggregateOp_T, col_name: str | None, enc_type: EncodeType -): +def test_agg_methods(method_name: AggregateOp_T, col_name: str | None, enc_type: _Type): actual = getattr(agg, method_name)(col_name, enc_type) assert isinstance(actual, dict) assert actual["aggregate"] == method_name From 0e806ce55db1acd7e8591c7eb0c30101baf26d2b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 4 Nov 2024 15:12:21 +0000 Subject: [PATCH 19/19] feat: Permit `field.eq("field", None)` Simply rewrites as `FieldValidPredicate("field", valid=False)` > If set to true the field's value has to be valid, meaning both not ``null`` and not NaN --- altair/vegalite/v5/_api_rfc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/altair/vegalite/v5/_api_rfc.py b/altair/vegalite/v5/_api_rfc.py index ea4ac1277..6a100387c 100644 --- a/altair/vegalite/v5/_api_rfc.py +++ b/altair/vegalite/v5/_api_rfc.py @@ -314,8 +314,7 @@ def eq( cls, field: str, value: _FieldEqualType, /, *, timeUnit: _TimeUnit = Undefined ) -> Predicate: if value is None: - # NOTE: Unclear why this is allowed for `datum` but not in `FieldEqualPredicate` - raise TypeError(value) + return cls.valid(field, False, timeUnit=timeUnit) return FieldEqualPredicate(field=field, equal=value, timeUnit=timeUnit) @classmethod