From 06152280cc4e993eb12bbcc2b604af247c56ae21 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 16 Oct 2024 18:47:07 +0100 Subject: [PATCH] refactor: Adds `tools.codemod.Ruff` https://github.com/vega/altair/pull/3630#discussion_r1801329937 **Introduces no new behavior**. --- sphinxext/code_ref.py | 8 +- tools/codemod.py | 335 +++++++++++++++++++++++++------ tools/generate_schema_wrapper.py | 9 +- tools/schemapi/utils.py | 78 +------ tools/update_init_file.py | 4 +- tools/vega_expr.py | 6 +- 6 files changed, 287 insertions(+), 153 deletions(-) diff --git a/sphinxext/code_ref.py b/sphinxext/code_ref.py index c0a1d4636..6371713de 100644 --- a/sphinxext/code_ref.py +++ b/sphinxext/code_ref.py @@ -10,7 +10,7 @@ from sphinx.util.parsing import nested_parse_to_nodes from altair.vegalite.v5.schema._typing import VegaThemes -from tools.codemod import embed_extract_func_def, extract_func_def +from tools.codemod import extract_func_def, extract_func_def_embed if TYPE_CHECKING: import sys @@ -222,11 +222,11 @@ def run(self) -> Sequence[nodes.Node]: optgroup("Carbon", (option(nm) for nm in carbon_names)), ) ) - py_code = embed_extract_func_def( + py_code = extract_func_def_embed( module_name, func_name, - before_code=_before_code(REFRESH_NAME, SELECT_ID, TARGET_DIV_ID), - after_code=f"{REFRESH_NAME}()", + before=_before_code(REFRESH_NAME, SELECT_ID, TARGET_DIV_ID), + after=f"{REFRESH_NAME}()", assign_to="chart", indent=4, ) diff --git a/tools/codemod.py b/tools/codemod.py index 52261790c..8281e0732 100644 --- a/tools/codemod.py +++ b/tools/codemod.py @@ -1,21 +1,39 @@ +# ruff: noqa: D418 from __future__ import annotations import ast import subprocess import sys import textwrap +import warnings +from collections import deque from importlib.util import find_spec from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Iterable, TypeVar, overload + +if sys.version_info >= (3, 12): + from typing import Protocol, TypeAliasType +else: + from typing_extensions import Protocol, TypeAliasType if TYPE_CHECKING: - if sys.version_info >= (3, 10): - from typing import TypeAlias + if sys.version_info >= (3, 11): + from typing import LiteralString else: - from typing_extensions import TypeAlias - from typing import Iterable, Iterator, Literal + from typing_extensions import LiteralString + + from typing import ClassVar, Iterator, Literal - _Code: TypeAlias = "str | Iterable[str]" + +__all__ = ["extract_func_def", "extract_func_def_embed", "ruff", "ruff_inline_docs"] + +T = TypeVar("T") +OneOrIterV = TypeAliasType( + "OneOrIterV", + "T | Iterable[T] | Iterable[OneOrIterV[T]]", + type_params=(T,), +) +_Code = OneOrIterV[str] def parse_module(name: str, /) -> ast.Module: @@ -52,8 +70,6 @@ def unparse(obj: ast.AST, /) -> str: """ # HACK: Will only be used during build/docs # - This branch is just to satisfy linters - import warnings - msg = f"Called `ast.unparse()` on {sys.version_info!r}\nFunction not available before {(3, 9)!r}" warnings.warn(msg, ImportWarning, stacklevel=2) return "" @@ -108,10 +124,19 @@ def validate_body(fn: ast.FunctionDef, /) -> tuple[list[ast.stmt], ast.expr]: return body, last.value -def normalize_code(code: _Code, /) -> str: - if not isinstance(code, str): - code = "\n".join(code) - return code +def iter_flatten(*elements: _Code) -> Iterator[str]: + for el in elements: + if not isinstance(el, str) and isinstance(el, Iterable): + yield from iter_flatten(*el) + elif isinstance(el, str): + yield el + else: + msg = ( + f"Expected all elements to eventually flatten to ``str``, " + f"but got: {type(el).__name__!r}\n\n" + f"{el!r}" + ) + raise TypeError(msg) def iter_func_def_unparse( @@ -138,32 +163,8 @@ def iter_func_def_unparse( elif return_transform == "assign": yield f"{assign_to} = {ret_value}" else: - raise TypeError(return_transform) - - -def embed_extract_func_def( - module_name: str, - func_name: str, - /, - before_code: _Code | None = None, - after_code: _Code | None = None, - assign_to: str = "chart", - indent: int | None = None, -): - parts: list[_Code] = [] - if before_code is not None: - parts.append(before_code) - parts.append( - iter_func_def_unparse( - module_name, func_name, return_transform="assign", assign_to=assign_to - ) - ) - if after_code is not None: - parts.append(after_code) - normed = "\n".join(normalize_code(s) for s in parts) - checked = ruff_check_str(normed) - formatted = ruff_format_str(checked) - return textwrap.indent(formatted, " " * indent) if indent else formatted + msg = f"{return_transform=}" + raise NotImplementedError(msg) def extract_func_def( @@ -208,39 +209,245 @@ def extract_func_def( if output not in {"altair-plot", "code-block", "str"}: raise TypeError(output) it = iter_func_def_unparse(module_name, func_name) - s = ruff_format_str(it, trailing_comma=False) if format else "\n".join(it) + s = ruff_inline_docs.format(it) if format else "\n".join(it) if output == "str": return s else: return f".. {output}::\n\n{textwrap.indent(s, ' ' * 4)}\n" -# TODO: Move these and `tools.schemapi.utils.ruff_` into a `Ruff` class -# - Then can support configuring an instance like the `mistune` classes -def ruff_check_str(code: _Code, /) -> str: - encoded = normalize_code(code).encode() - cmd = [ - "ruff", - "check", - "--fix", - "--ignore", - "E711", # Comparison to `None` +def extract_func_def_embed( + module_name: str, + func_name: str, + /, + before: _Code | None = None, + after: _Code | None = None, + assign_to: str = "chart", + indent: int | None = None, +) -> str: + """ + Extract the contents of a function, wrapping with ``before`` and ``after``. + + The resulting code block is run through ``ruff`` to deduplicate imports + and apply consistent formatting. + + Parameters + ---------- + module_name + Absolute, dotted import style. + func_name + Name of function in ``module_name``. + before + Code inserted before ``func_name``. + after + Code inserted after ``func_name``. + assign_to + Variable name to use as the result of ``func_name``. + + .. note:: + Allows the ``after`` block to use a consistent reference. + indent + Optionally, prefix ``indent * " "`` to final block. + + .. note:: + Occurs **after** formatting, will not contribute to line length wrap. + """ + if before is None and after is None: + msg = ( + f"At least one additional code fragment should be provided, but:\n" + f"{before=}, {after=}\n\n" + f"Use {extract_func_def.__qualname__!r} instead." + ) + warnings.warn(msg, UserWarning, stacklevel=2) + unparsed = iter_func_def_unparse( + module_name, func_name, return_transform="assign", assign_to=assign_to + ) + parts = [p for p in (before, unparsed, after) if p is not None] + formatted = ruff_inline_docs(parts) + return textwrap.indent(formatted, " " * indent) if indent else formatted + + +class CodeMod(Protocol): + def __call__(self, *code: _Code) -> str: + """ + Transform some input into a single block of modified code. + + Parameters + ---------- + *code + Arbitrarily nested code fragments. + """ + ... + + def _join(self, code: _Code, *, sep: str = "\n") -> str: + """ + Concatenate any number of code fragments. + + All nested groups are unwrapped into a flat iterable. + """ + return sep.join(iter_flatten(code)) + + +class Ruff(CodeMod): + """ + Run `ruff`_ commands against code fragments or files. + + By default, uses the same config as `pyproject.toml`_. + + Parameters + ---------- + *extend_select + `rule codes`_ to use **on top of** the default config. + ignore + `rule codes`_ to `ignore`_. + skip_magic_traling_comma + Enables `skip-magic-trailing-comma`_ during formatting. + + .. note:: + + Only use on code that is changing indent-level + (e.g. unwrapping function contents). + + .. _ruff: + https://docs.astral.sh/ruff/ + .. _pyproject.toml: + https://github.com/vega/altair/blob/main/pyproject.toml + .. _rule codes: + https://docs.astral.sh/ruff/rules/ + .. _ignore: + https://docs.astral.sh/ruff/settings/#lint_ignore + .. _skip-magic-trailing-comma: + https://docs.astral.sh/ruff/settings/#format_skip-magic-trailing-comma + """ + + _stdin_args: ClassVar[tuple[LiteralString, ...]] = ( "--stdin-filename", "placeholder.py", - ] - capture_output = True - r = subprocess.run(cmd, input=encoded, check=True, capture_output=capture_output) - return r.stdout.decode() + ) + _check_args: ClassVar[tuple[LiteralString, ...]] = ("--fix",) + + def __init__( + self, + *extend_select: str, + ignore: OneOrIterV[str] | None = None, + skip_magic_traling_comma: bool = False, + ) -> None: + self.check_args: deque[str] = deque(self._check_args) + self.format_args: deque[str] = deque() + for c in extend_select: + self.check_args.extend(("--extend-select", c)) + if ignore is not None: + self.check_args.extend( + ("--ignore", ",".join(s for s in iter_flatten(ignore))) + ) + if skip_magic_traling_comma: + self.format_args.extend( + ("--config", "format.skip-magic-trailing-comma = true") + ) + + def write_lint_format(self, fp: Path, code: _Code, /) -> None: + """ + Combined steps of writing, `ruff check`, `ruff format`. + + Parameters + ---------- + fp + Target file to write to + code + Some (potentially) nested code fragments. + + Notes + ----- + - `fp` is written to first, as the size before formatting will be the smallest + - Better utilizes `ruff` performance, rather than `python` str and io + """ + self.check(fp, code) + self.format(fp) + @overload + def check(self, *code: _Code, decode: Literal[True] = ...) -> str: + """Fixes violations and returns fixed code.""" -def ruff_format_str( - code: str | Iterable[str], /, *, trailing_comma: bool = True -) -> str: - # NOTE: Brought this back w/ changes after removing in #3536 - encoded = normalize_code(code).encode() - cmd = ["ruff", "format", "--stdin-filename", "placeholder.py"] - if not trailing_comma: - cmd.extend(("--config", "format.skip-magic-trailing-comma = true")) - - r = subprocess.run(cmd, input=encoded, check=True, capture_output=True) - return r.stdout.decode() + @overload + def check(self, *code: _Code, decode: Literal[False]) -> bytes: + """ + ``decode=False`` will return as ``bytes``. + + Helpful if piping to another command. + """ + + @overload + def check(self, _write_to: Path, /, *code: _Code) -> None: + """ + ``code`` is joined, written to provided path and then checked. + + No input returned. + """ + + def check(self, *code: Any, decode: bool = True) -> str | bytes | None: + """ + Check and fix ``ruff`` rule violations. + + All cases will join ``code`` to a single ``str``. + """ + base = "ruff", "check" + if isinstance(code[0], Path): + fp = code[0] + fp.write_text(self._join(code[1:]), encoding="utf-8") + subprocess.run((*base, fp, *self.check_args), check=True) + return None + r = subprocess.run( + (*base, *self.check_args, *self._stdin_args), + input=self._join(code).encode(), + check=True, + capture_output=True, + ) + return r.stdout.decode() if decode else r.stdout + + @overload + def format(self, *code: _Code) -> str: + """Format arbitrarily nested input as a single block.""" + + @overload + def format(self, _target_file: Path, /, *code: None) -> None: + """ + Format an existing file. + + Running on `win32` after writing lines will ensure ``LF`` is used, and not ``CRLF``: + + ruff format --diff --check _target_file + """ + + @overload + def format(self, _encoded_result: bytes, /, *code: None) -> str: + """Format the raw result of ``ruff.check``.""" + + def format(self, *code: Any) -> str | None: + """ + Format some input code, or an existing file. + + Returns decoded result unless formatting an existing file. + """ + base = "ruff", "format" + if len(code) == 1 and isinstance(code[0], Path): + subprocess.run((*base, code[0], *self.format_args), check=True) + return None + encoded = ( + code[0] + if len(code) == 1 and isinstance(code[0], bytes) + else self._join(code).encode() + ) + r = subprocess.run( + (*base, *self.format_args, *self._stdin_args), + input=encoded, + check=True, + capture_output=True, + ) + return r.stdout.decode() + + def __call__(self, *code: _Code) -> str: + return self.format(self.check(code, decode=False)) + + +ruff_inline_docs = Ruff(ignore="E711", skip_magic_traling_comma=True) +ruff = Ruff() diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 3b120bc77..0db8772e8 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -20,6 +20,7 @@ sys.path.insert(0, str(Path.cwd())) +from tools.codemod import ruff from tools.markup import rst_syntax_for_class from tools.schemapi import CodeSnippet, SchemaInfo, arg_kwds, arg_required_kwds, codegen from tools.schemapi.utils import ( @@ -31,8 +32,6 @@ import_typing_extensions, indent_docstring, resolve_references, - ruff_format_py, - ruff_write_lint_format_str, spell_literal, ) from tools.vega_expr import write_expr_module @@ -548,7 +547,7 @@ def copy_schemapi_util() -> None: dest.write(HEADER_COMMENT) dest.writelines(source.readlines()) if sys.platform == "win32": - ruff_format_py(destination_fp) + ruff.format(destination_fp) def recursive_dict_update(schema: dict, root: dict, def_dict: dict) -> None: @@ -1024,7 +1023,7 @@ def vegalite_main(skip_download: bool = False) -> None: f"SCHEMA_VERSION = '{version}'\n", f"SCHEMA_URL = {schema_url(version)!r}\n", ] - ruff_write_lint_format_str(outfile, content) + ruff.write_lint_format(outfile, content) TypeAliasTracer.update_aliases(("Map", "Mapping[str, Any]")) @@ -1106,7 +1105,7 @@ def vegalite_main(skip_download: bool = False) -> None: # Write the pre-generated modules for fp, contents in files.items(): print(f"Writing\n {schemafile!s}\n ->{fp!s}") - ruff_write_lint_format_str(fp, contents) + ruff.write_lint_format(fp, contents) def generate_encoding_artifacts( diff --git a/tools/schemapi/utils.py b/tools/schemapi/utils.py index 6bc7b1f4b..351c355d0 100644 --- a/tools/schemapi/utils.py +++ b/tools/schemapi/utils.py @@ -4,7 +4,6 @@ import json import re -import subprocess import sys import textwrap import urllib.parse @@ -19,13 +18,13 @@ Iterator, Literal, Mapping, - MutableSequence, Sequence, TypeVar, Union, overload, ) +from tools.codemod import ruff from tools.markup import RSTParseVegaLite, rst_syntax_for_class from tools.schemapi.schemapi import _resolve_references as resolve_references @@ -91,13 +90,6 @@ class _TypeAliasTracer: A format specifier to produce the `TypeAlias` name. Will be provided a `SchemaInfo.title` as a single positional argument. - *ruff_check - Optional [ruff rule codes](https://docs.astral.sh/ruff/rules/), - each prefixed with `--select ` and follow a `ruff check --fix ` call. - - If not provided, uses `[tool.ruff.lint.select]` from `pyproject.toml`. - ruff_format - Optional argument list supplied to [ruff format](https://docs.astral.sh/ruff/formatter/#ruff-format) Attributes ---------- @@ -111,12 +103,7 @@ class _TypeAliasTracer: Prefined import statements to appear at beginning of module. """ - def __init__( - self, - fmt: str = "{}_T", - *ruff_check: str, - ruff_format: Sequence[str] | None = None, - ) -> None: + def __init__(self, fmt: str = "{}_T") -> None: self.fmt: str = fmt self._literals: dict[str, str] = {} self._literals_invert: dict[str, str] = {} @@ -135,10 +122,6 @@ def __init__( import_typing_extensions((3, 10), "TypeAlias"), import_typing_extensions((3, 9), "Annotated", "get_args"), ) - self._cmd_check: list[str] = ["--fix"] - self._cmd_format: Sequence[str] = ruff_format or () - for c in ruff_check: - self._cmd_check.extend(("--extend-select", c)) def _update_literals(self, name: str, tp: str, /) -> None: """Produces an inverted index, to reuse a `Literal` when `SchemaInfo.title` is empty.""" @@ -223,13 +206,6 @@ def write_module( extra `tools.generate_schema_wrapper.TYPING_EXTRA`. """ - ruff_format: MutableSequence[str | Path] = ["ruff", "format", fp] - if self._cmd_format: - ruff_format.extend(self._cmd_format) - commands: tuple[Sequence[str | Path], ...] = ( - ["ruff", "check", fp, *self._cmd_check], - ruff_format, - ) static = (header, "\n", *self._imports, "\n\n") self.update_aliases(*sorted(self._literals.items(), key=itemgetter(0))) all_ = [*iter(self._aliases), *extra_all] @@ -238,10 +214,7 @@ def write_module( [f"__all__ = {all_}", "\n\n", extra], self.generate_aliases(), ) - fp.write_text("\n".join(it), encoding="utf-8") - for cmd in commands: - r = subprocess.run(cmd, check=True) - r.check_returncode() + ruff.write_lint_format(fp, it) @property def n_entries(self) -> int: @@ -997,49 +970,6 @@ def unwrap_literal(tp: str, /) -> str: return re.sub(r"Literal\[(.+)\]", r"\g<1>", tp) -def ruff_format_py(fp: Path, /, *extra_args: str) -> None: - """ - Format an existing file. - - Running on `win32` after writing lines will ensure "lf" is used before: - ```bash - ruff format --diff --check . - ``` - """ - cmd: MutableSequence[str | Path] = ["ruff", "format", fp] - if extra_args: - cmd.extend(extra_args) - r = subprocess.run(cmd, check=True) - r.check_returncode() - - -def ruff_write_lint_format_str( - fp: Path, code: str | Iterable[str], /, *, encoding: str = "utf-8" -) -> None: - """ - Combined steps of writing, `ruff check`, `ruff format`. - - Notes - ----- - - `fp` is written to first, as the size before formatting will be the smallest - - Better utilizes `ruff` performance, rather than `python` str and io - - `code` is no longer bound to `list` - - Encoding set as default - - `I001/2` are `isort` rules, to sort imports. - """ - commands: Iterable[Sequence[str | Path]] = ( - ["ruff", "check", fp, "--fix"], - ["ruff", "check", fp, "--fix", "--select", "I001", "--select", "I002"], - ) - if not isinstance(code, str): - code = "\n".join(code) - fp.write_text(code, encoding=encoding) - for cmd in commands: - r = subprocess.run(cmd, check=True) - r.check_returncode() - ruff_format_py(fp) - - def import_type_checking(*imports: str) -> str: """Write an `if TYPE_CHECKING` block.""" imps = "\n".join(f" {s}" for s in imports) @@ -1066,7 +996,7 @@ def import_typing_extensions( ) -TypeAliasTracer: _TypeAliasTracer = _TypeAliasTracer("{}_T", "I001", "I002") +TypeAliasTracer: _TypeAliasTracer = _TypeAliasTracer("{}_T") """An instance of `_TypeAliasTracer`. Collects a cache of unique `Literal` types used globally. diff --git a/tools/update_init_file.py b/tools/update_init_file.py index c1831093a..0592032f0 100644 --- a/tools/update_init_file.py +++ b/tools/update_init_file.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import TYPE_CHECKING -from tools.schemapi.utils import ruff_write_lint_format_str +from tools.codemod import ruff _TYPING_CONSTRUCTS = { te.TypeAlias, @@ -74,7 +74,7 @@ def update__all__variable() -> None: ] # Write new version of altair/__init__.py # Format file content with ruff - ruff_write_lint_format_str(init_path, new_lines) + ruff.write_lint_format(init_path, new_lines) def relevant_attributes(namespace: dict[str, t.Any], /) -> list[str]: diff --git a/tools/vega_expr.py b/tools/vega_expr.py index ce87cb2fb..66d6287fb 100644 --- a/tools/vega_expr.py +++ b/tools/vega_expr.py @@ -29,12 +29,10 @@ overload, ) +from tools.codemod import ruff from tools.markup import RSTParse, Token, read_ast_tokens from tools.markup import RSTRenderer as _RSTRenderer from tools.schemapi.schemapi import SchemaBase as _SchemaBase -from tools.schemapi.utils import ( - ruff_write_lint_format_str as _ruff_write_lint_format_str, -) if TYPE_CHECKING: import sys @@ -977,4 +975,4 @@ def write_expr_module(version: str, output: Path, *, header: str) -> None: [MODULE_POST], ) print(f"Generating\n {url!s}\n ->{output!s}") - _ruff_write_lint_format_str(output, contents) + ruff.write_lint_format(output, contents)