From bf9469fd9c2bf1218aec6a68982b5e45a19cdc22 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 9 Mar 2024 17:42:29 -0800 Subject: [PATCH] Fix AST safety check false negative (#4270) Fixes #4268 Previously we would allow whitespace changes in all strings, now only in docstrings. Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com> --- CHANGES.md | 4 ++ src/black/__init__.py | 15 ++++-- src/black/parsing.py | 42 ++++++++++++--- tests/test_black.py | 122 +++++++++++++++++++++++++++++++++++++----- 4 files changed, 156 insertions(+), 27 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index c4f1d1da16f..33d43125c25 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,6 +11,10 @@ - Don't move comments along with delimiters, which could cause crashes (#4248) +- Strengthen AST safety check to catch more unsafe changes to strings. Previous versions + of Black would incorrectly format the contents of certain unusual f-strings containing + nested strings with the same quote type. Now, Black will crash on such strings until + support for the new f-string syntax is implemented. (#4270) ### Preview style diff --git a/src/black/__init__.py b/src/black/__init__.py index f82b9fec5b7..da884e6027e 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -77,8 +77,13 @@ syms, ) from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out -from black.parsing import InvalidInput # noqa F401 -from black.parsing import lib2to3_parse, parse_ast, stringify_ast +from black.parsing import ( # noqa F401 + ASTSafetyError, + InvalidInput, + lib2to3_parse, + parse_ast, + stringify_ast, +) from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges from black.report import Changed, NothingChanged, Report from black.trans import iter_fexpr_spans @@ -1511,7 +1516,7 @@ def assert_equivalent(src: str, dst: str) -> None: try: src_ast = parse_ast(src) except Exception as exc: - raise AssertionError( + raise ASTSafetyError( "cannot use --safe with this file; failed to parse source file AST: " f"{exc}\n" "This could be caused by running Black with an older Python version " @@ -1522,7 +1527,7 @@ def assert_equivalent(src: str, dst: str) -> None: dst_ast = parse_ast(dst) except Exception as exc: log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst) - raise AssertionError( + raise ASTSafetyError( f"INTERNAL ERROR: Black produced invalid code: {exc}. " "Please report a bug on https://github.com/psf/black/issues. " f"This invalid output might be helpful: {log}" @@ -1532,7 +1537,7 @@ def assert_equivalent(src: str, dst: str) -> None: dst_ast_str = "\n".join(stringify_ast(dst_ast)) if src_ast_str != dst_ast_str: log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst")) - raise AssertionError( + raise ASTSafetyError( "INTERNAL ERROR: Black produced code that is not equivalent to the" " source. Please report a bug on " f"https://github.com/psf/black/issues. This diff might be helpful: {log}" diff --git a/src/black/parsing.py b/src/black/parsing.py index 63c5e71a0fe..aa97a8cecea 100644 --- a/src/black/parsing.py +++ b/src/black/parsing.py @@ -110,6 +110,10 @@ def lib2to3_unparse(node: Node) -> str: return code +class ASTSafetyError(Exception): + """Raised when Black's generated code is not equivalent to the old AST.""" + + def _parse_single_version( src: str, version: Tuple[int, int], *, type_comments: bool ) -> ast.AST: @@ -154,9 +158,20 @@ def _normalize(lineend: str, value: str) -> str: return normalized.strip() -def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]: +def stringify_ast(node: ast.AST) -> Iterator[str]: """Simple visitor generating strings to compare ASTs by content.""" + return _stringify_ast(node, []) + +def _stringify_ast_with_new_parent( + node: ast.AST, parent_stack: List[ast.AST], new_parent: ast.AST +) -> Iterator[str]: + parent_stack.append(new_parent) + yield from _stringify_ast(node, parent_stack) + parent_stack.pop() + + +def _stringify_ast(node: ast.AST, parent_stack: List[ast.AST]) -> Iterator[str]: if ( isinstance(node, ast.Constant) and isinstance(node.value, str) @@ -167,7 +182,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]: # over the kind node.kind = None - yield f"{' ' * depth}{node.__class__.__name__}(" + yield f"{' ' * len(parent_stack)}{node.__class__.__name__}(" for field in sorted(node._fields): # noqa: F402 # TypeIgnore has only one field 'lineno' which breaks this comparison @@ -179,7 +194,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]: except AttributeError: continue - yield f"{' ' * (depth + 1)}{field}=" + yield f"{' ' * (len(parent_stack) + 1)}{field}=" if isinstance(value, list): for item in value: @@ -191,13 +206,15 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]: and isinstance(item, ast.Tuple) ): for elt in item.elts: - yield from stringify_ast(elt, depth + 2) + yield from _stringify_ast_with_new_parent( + elt, parent_stack, node + ) elif isinstance(item, ast.AST): - yield from stringify_ast(item, depth + 2) + yield from _stringify_ast_with_new_parent(item, parent_stack, node) elif isinstance(value, ast.AST): - yield from stringify_ast(value, depth + 2) + yield from _stringify_ast_with_new_parent(value, parent_stack, node) else: normalized: object @@ -205,6 +222,12 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]: isinstance(node, ast.Constant) and field == "value" and isinstance(value, str) + and len(parent_stack) >= 2 + and isinstance(parent_stack[-1], ast.Expr) + and isinstance( + parent_stack[-2], + (ast.FunctionDef, ast.AsyncFunctionDef, ast.Module, ast.ClassDef), + ) ): # Constant strings may be indented across newlines, if they are # docstrings; fold spaces after newlines when comparing. Similarly, @@ -215,6 +238,9 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]: normalized = value.rstrip() else: normalized = value - yield f"{' ' * (depth + 2)}{normalized!r}, # {value.__class__.__name__}" + yield ( + f"{' ' * (len(parent_stack) + 1)}{normalized!r}, #" + f" {value.__class__.__name__}" + ) - yield f"{' ' * depth}) # /{node.__class__.__name__}" + yield f"{' ' * len(parent_stack)}) # /{node.__class__.__name__}" diff --git a/tests/test_black.py b/tests/test_black.py index 41f87cd16f8..96f53d5e5f3 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -46,6 +46,7 @@ from black.debug import DebugVisitor from black.mode import Mode, Preview from black.output import color_diff, diff +from black.parsing import ASTSafetyError from black.report import Report # Import other test classes @@ -1473,10 +1474,6 @@ def test_normalize_line_endings(self) -> None: ff(test_file, write_back=black.WriteBack.YES) self.assertEqual(test_file.read_bytes(), expected) - def test_assert_equivalent_different_asts(self) -> None: - with self.assertRaises(AssertionError): - black.assert_equivalent("{}", "None") - def test_root_logger_not_used_directly(self) -> None: def fail(*args: Any, **kwargs: Any) -> None: self.fail("Record created with root logger") @@ -1962,16 +1959,6 @@ def test_for_handled_unexpected_eof_error(self) -> None: exc_info.match("Cannot parse: 2:0: EOF in multi-line statement") - def test_equivalency_ast_parse_failure_includes_error(self) -> None: - with pytest.raises(AssertionError) as err: - black.assert_equivalent("a«»a = 1", "a«»a = 1") - - err.match("--safe") - # Unfortunately the SyntaxError message has changed in newer versions so we - # can't match it directly. - err.match("invalid character") - err.match(r"\(, line 1\)") - def test_line_ranges_with_code_option(self) -> None: code = textwrap.dedent("""\ if a == b: @@ -2822,6 +2809,113 @@ def test_format_file_contents(self) -> None: black.format_file_contents("x = 1\n", fast=True, mode=black.Mode()) +class TestASTSafety(BlackBaseTestCase): + def check_ast_equivalence( + self, source: str, dest: str, *, should_fail: bool = False + ) -> None: + # If we get a failure, make sure it's not because the code itself + # is invalid, since that will also cause assert_equivalent() to throw + # ASTSafetyError. + source = textwrap.dedent(source) + dest = textwrap.dedent(dest) + black.parse_ast(source) + black.parse_ast(dest) + if should_fail: + with self.assertRaises(ASTSafetyError): + black.assert_equivalent(source, dest) + else: + black.assert_equivalent(source, dest) + + def test_assert_equivalent_basic(self) -> None: + self.check_ast_equivalence("{}", "None", should_fail=True) + self.check_ast_equivalence("1+2", "1 + 2") + self.check_ast_equivalence("hi # comment", "hi") + + def test_assert_equivalent_del(self) -> None: + self.check_ast_equivalence("del (a, b)", "del a, b") + + def test_assert_equivalent_strings(self) -> None: + self.check_ast_equivalence('x = "x"', 'x = " x "', should_fail=True) + self.check_ast_equivalence( + ''' + """docstring """ + ''', + ''' + """docstring""" + ''', + ) + self.check_ast_equivalence( + ''' + """docstring """ + ''', + ''' + """ddocstring""" + ''', + should_fail=True, + ) + self.check_ast_equivalence( + ''' + class A: + """ + + docstring + + + """ + ''', + ''' + class A: + """docstring""" + ''', + ) + self.check_ast_equivalence( + """ + def f(): + " docstring " + """, + ''' + def f(): + """docstring""" + ''', + ) + self.check_ast_equivalence( + """ + async def f(): + " docstring " + """, + ''' + async def f(): + """docstring""" + ''', + ) + + def test_assert_equivalent_fstring(self) -> None: + major, minor = sys.version_info[:2] + if major < 3 or (major == 3 and minor < 12): + pytest.skip("relies on 3.12+ syntax") + # https://github.com/psf/black/issues/4268 + self.check_ast_equivalence( + """print(f"{"|".join([a,b,c])}")""", + """print(f"{" | ".join([a,b,c])}")""", + should_fail=True, + ) + self.check_ast_equivalence( + """print(f"{"|".join(['a','b','c'])}")""", + """print(f"{" | ".join(['a','b','c'])}")""", + should_fail=True, + ) + + def test_equivalency_ast_parse_failure_includes_error(self) -> None: + with pytest.raises(ASTSafetyError) as err: + black.assert_equivalent("a«»a = 1", "a«»a = 1") + + err.match("--safe") + # Unfortunately the SyntaxError message has changed in newer versions so we + # can't match it directly. + err.match("invalid character") + err.match(r"\(, line 1\)") + + try: with open(black.__file__, "r", encoding="utf-8") as _bf: black_source_lines = _bf.readlines()