diff --git a/altair/__init__.py b/altair/__init__.py index 073a1a53e..55feb3cdb 100644 --- a/altair/__init__.py +++ b/altair/__init__.py @@ -56,6 +56,7 @@ "Categorical", "Chart", "ChartDataType", + "ChartType", "Color", "ColorDatum", "ColorDef", @@ -580,6 +581,7 @@ "expr", "graticule", "hconcat", + "is_chart_type", "jupyter", "layer", "limit_rows", diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py index 1d886eb5e..7a616b9c5 100644 --- a/altair/utils/_transformed_data.py +++ b/altair/utils/_transformed_data.py @@ -180,8 +180,8 @@ def name_views( List of the names of the charts and subcharts """ exclude = set(exclude) if exclude is not None else set() - if isinstance(chart, _chart_class_mapping[Chart]) or isinstance( - chart, _chart_class_mapping[FacetChart] + if isinstance( + chart, (_chart_class_mapping[Chart], _chart_class_mapping[FacetChart]) ): if chart.name not in exclude: if chart.name in (None, Undefined): diff --git a/altair/utils/save.py b/altair/utils/save.py index 11db77bf1..609486bc6 100644 --- a/altair/utils/save.py +++ b/altair/utils/save.py @@ -10,14 +10,14 @@ def write_file_or_filename( - fp: Union[str, pathlib.PurePath, IO], + fp: Union[str, pathlib.Path, IO], content: Union[str, bytes], mode: str = "w", encoding: Optional[str] = None, ) -> None: """Write content to fp, whether fp is a string, a pathlib Path or a file-like object""" - if isinstance(fp, str) or isinstance(fp, pathlib.PurePath): + if isinstance(fp, (str, pathlib.Path)): with open(file=fp, mode=mode, encoding=encoding) as f: f.write(content) else: @@ -25,14 +25,12 @@ def write_file_or_filename( def set_inspect_format_argument( - format: Optional[str], fp: Union[str, pathlib.PurePath, IO], inline: bool + format: Optional[str], fp: Union[str, pathlib.Path, IO], inline: bool ) -> str: """Inspect the format argument in the save function""" if format is None: - if isinstance(fp, str): - format = fp.split(".")[-1] - elif isinstance(fp, pathlib.PurePath): - format = fp.suffix.lstrip(".") + if isinstance(fp, (str, pathlib.Path)): + format = pathlib.Path(fp).suffix.lstrip(".") else: raise ValueError( "must specify file format: " @@ -71,7 +69,7 @@ def set_inspect_mode_argument( def save( chart, - fp: Union[str, pathlib.PurePath, IO], + fp: Union[str, pathlib.Path, IO], vega_version: Optional[str], vegaembed_version: Optional[str], format: Optional[Literal["json", "html", "png", "svg", "pdf"]] = None, @@ -140,7 +138,7 @@ def save( if json_kwds is None: json_kwds = {} - + encoding = kwargs.get("encoding", "utf-8") format = set_inspect_format_argument(format, fp, inline) # type: ignore[assignment] def perform_save(): @@ -152,9 +150,7 @@ def perform_save(): if format == "json": json_spec = json.dumps(spec, **json_kwds) - write_file_or_filename( - fp, json_spec, mode="w", encoding=kwargs.get("encoding", "utf-8") - ) + write_file_or_filename(fp, json_spec, mode="w", encoding=encoding) elif format == "html": if inline: kwargs["template"] = "inline" @@ -170,10 +166,7 @@ def perform_save(): **kwargs, ) write_file_or_filename( - fp, - mimebundle["text/html"], - mode="w", - encoding=kwargs.get("encoding", "utf-8"), + fp, mimebundle["text/html"], mode="w", encoding=encoding ) elif format in ["png", "svg", "pdf", "vega"]: mimebundle = spec_to_mimebundle( @@ -193,7 +186,6 @@ def perform_save(): elif format == "pdf": write_file_or_filename(fp, mimebundle["application/pdf"], mode="wb") else: - encoding = kwargs.get("encoding", "utf-8") write_file_or_filename( fp, mimebundle["image/svg+xml"], mode="w", encoding=encoding ) diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py index 9325b3a70..10b2597e6 100644 --- a/altair/utils/schemapi.py +++ b/altair/utils/schemapi.py @@ -1193,7 +1193,7 @@ def _freeze(val): return frozenset((k, _freeze(v)) for k, v in val.items()) elif isinstance(val, set): return frozenset(map(_freeze, val)) - elif isinstance(val, list) or isinstance(val, tuple): + elif isinstance(val, (list, tuple)): return tuple(map(_freeze, val)) else: return val diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 0bd9ae389..37a13c2ae 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -8,6 +8,8 @@ from toolz.curried import pipe as _pipe import itertools import sys +import pathlib +import typing from typing import cast, List, Optional, Any, Iterable, Union, Literal, IO # Have to rename it here as else it overlaps with schema.core.Type and schema.core.Dict @@ -30,6 +32,10 @@ from ...utils.data import DataType from ...utils.deprecation import AltairDeprecationWarning +if sys.version_info >= (3, 13): + from typing import TypeIs +else: + from typing_extensions import TypeIs if sys.version_info >= (3, 11): from typing import Self else: @@ -1138,7 +1144,7 @@ def open_editor(self, *, fullscreen: bool = False) -> None: def save( self, - fp: Union[str, IO], + fp: Union[str, pathlib.Path, IO], format: Optional[Literal["json", "html", "png", "svg", "pdf"]] = None, override_data_transformer: bool = True, scale_factor: float = 1.0, @@ -4098,3 +4104,18 @@ def graticule(**kwds): def sphere() -> core.SphereGenerator: """Sphere generator.""" return core.SphereGenerator(sphere=True) + + +ChartType = Union[ + Chart, RepeatChart, ConcatChart, HConcatChart, VConcatChart, FacetChart, LayerChart +] + + +def is_chart_type(obj: Any) -> TypeIs[ChartType]: + """Return `True` if the object is an Altair chart. This can be a basic chart + but also a repeat, concat, or facet chart. + """ + return isinstance( + obj, + typing.get_args(ChartType), + ) diff --git a/doc/user_guide/api.rst b/doc/user_guide/api.rst index 99aa5f69b..08013023f 100644 --- a/doc/user_guide/api.rst +++ b/doc/user_guide/api.rst @@ -152,6 +152,7 @@ API Functions condition graticule hconcat + is_chart_type layer param repeat diff --git a/pyproject.toml b/pyproject.toml index 36c95b317..a7ad687f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ build-backend = "hatchling.build" name = "altair" authors = [{ name = "Vega-Altair Contributors" }] dependencies = [ - "typing_extensions>=4.0.1; python_version<\"3.11\"", + "typing_extensions>=4.10.0; python_version<\"3.13\"", "jinja2", # If you update the minimum required jsonschema version, also update it in build.yml "jsonschema>=3.0", @@ -164,6 +164,7 @@ exclude = [ ] [tool.ruff.lint] +extend-safe-fixes=["SIM101"] select = [ # flake8-bugbear "B", @@ -177,6 +178,8 @@ select = [ "F", # flake8-tidy-imports "TID", + # flake8-simplify + "SIM101" ] ignore = [ # Whitespace before ':' diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py index 29da970ec..ce1268c04 100644 --- a/tools/schemapi/schemapi.py +++ b/tools/schemapi/schemapi.py @@ -1191,7 +1191,7 @@ def _freeze(val): return frozenset((k, _freeze(v)) for k, v in val.items()) elif isinstance(val, set): return frozenset(map(_freeze, val)) - elif isinstance(val, list) or isinstance(val, tuple): + elif isinstance(val, (list, tuple)): return tuple(map(_freeze, val)) else: return val diff --git a/tools/update_init_file.py b/tools/update_init_file.py index 4f342f8ef..e4fa65a86 100644 --- a/tools/update_init_file.py +++ b/tools/update_init_file.py @@ -21,6 +21,10 @@ cast, ) +if sys.version_info >= (3, 13): + from typing import TypeIs +else: + from typing_extensions import TypeIs if sys.version_info >= (3, 11): from typing import Self else: @@ -97,6 +101,7 @@ def _is_relevant_attribute(attr_name: str) -> bool: or attr is Protocol or attr is Sequence or attr is IO + or attr is TypeIs or attr_name == "TypingDict" or attr_name == "TypingGenerator" or attr_name == "ValueOrDatum"