Skip to content

Commit

Permalink
Fix feature detection for parenthesized context managers (#4104)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Dec 12, 2023
1 parent eb7661f commit ebd543c
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 58 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

- Fix bug where `# fmt: off` automatically dedents when used with the `--line-ranges`
option, even when it is not within the specified line range. (#4084)
- Fix feature detection for parenthesized context managers (#4104)

### Preview style

Expand Down
18 changes: 17 additions & 1 deletion src/black/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1351,7 +1351,7 @@ def get_features_used( # noqa: C901
if (
len(atom_children) == 3
and atom_children[0].type == token.LPAR
and atom_children[1].type == syms.testlist_gexp
and _contains_asexpr(atom_children[1])
and atom_children[2].type == token.RPAR
):
features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)
Expand Down Expand Up @@ -1384,6 +1384,22 @@ def get_features_used( # noqa: C901
return features


def _contains_asexpr(node: Union[Node, Leaf]) -> bool:
"""Return True if `node` contains an as-pattern."""
if node.type == syms.asexpr_test:
return True
elif node.type == syms.atom:
if (
len(node.children) == 3
and node.children[0].type == token.LPAR
and node.children[2].type == token.RPAR
):
return _contains_asexpr(node.children[1])
elif node.type == syms.testlist_gexp:
return any(_contains_asexpr(child) for child in node.children)
return False


def detect_target_versions(
node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
Expand Down
2 changes: 1 addition & 1 deletion tests/data/cases/pep_572_remove_parens.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# flags: --minimum-version=3.8 --no-preview-line-length-1
# flags: --minimum-version=3.8
if (foo := 0):
pass

Expand Down
130 changes: 74 additions & 56 deletions tests/test_black.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
List,
Optional,
Sequence,
Set,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -874,71 +875,88 @@ def test_get_features_used_decorator(self) -> None:
)

def test_get_features_used(self) -> None:
node = black.lib2to3_parse("def f(*, arg): ...\n")
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("def f(*, arg,): ...\n")
self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
node = black.lib2to3_parse("f(*arg,)\n")
self.assertEqual(
black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
self.check_features_used("def f(*, arg): ...\n", set())
self.check_features_used(
"def f(*, arg,): ...\n", {Feature.TRAILING_COMMA_IN_DEF}
)
node = black.lib2to3_parse("def f(*, arg): f'string'\n")
self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
node = black.lib2to3_parse("123_456\n")
self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
node = black.lib2to3_parse("123456\n")
self.assertEqual(black.get_features_used(node), set())
self.check_features_used("f(*arg,)\n", {Feature.TRAILING_COMMA_IN_CALL})
self.check_features_used("def f(*, arg): f'string'\n", {Feature.F_STRINGS})
self.check_features_used("123_456\n", {Feature.NUMERIC_UNDERSCORES})
self.check_features_used("123456\n", set())

source, expected = read_data("cases", "function")
node = black.lib2to3_parse(source)
expected_features = {
Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
Feature.F_STRINGS,
}
self.assertEqual(black.get_features_used(node), expected_features)
node = black.lib2to3_parse(expected)
self.assertEqual(black.get_features_used(node), expected_features)
self.check_features_used(source, expected_features)
self.check_features_used(expected, expected_features)

source, expected = read_data("cases", "expression")
node = black.lib2to3_parse(source)
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse(expected)
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("lambda a, /, b: ...")
self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
node = black.lib2to3_parse("def fn(a, /, b): ...")
self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
node = black.lib2to3_parse("def fn(): yield a, b")
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("def fn(): return a, b")
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("def fn(): yield *b, c")
self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
node = black.lib2to3_parse("def fn(): return a, *b, c")
self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
node = black.lib2to3_parse("x = a, *b, c")
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("x: Any = regular")
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("x: Any = (regular, regular)")
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
self.assertEqual(
black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
self.check_features_used(source, set())
self.check_features_used(expected, set())

self.check_features_used("lambda a, /, b: ...\n", {Feature.POS_ONLY_ARGUMENTS})
self.check_features_used("def fn(a, /, b): ...", {Feature.POS_ONLY_ARGUMENTS})

self.check_features_used("def fn(): yield a, b", set())
self.check_features_used("def fn(): return a, b", set())
self.check_features_used("def fn(): yield *b, c", {Feature.UNPACKING_ON_FLOW})
self.check_features_used(
"def fn(): return a, *b, c", {Feature.UNPACKING_ON_FLOW}
)
node = black.lib2to3_parse("try: pass\nexcept Something: pass")
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
node = black.lib2to3_parse("a[*b]")
self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
node = black.lib2to3_parse("a[x, *y(), z] = t")
self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
node = black.lib2to3_parse("def fn(*args: *T): pass")
self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
self.check_features_used("x = a, *b, c", set())

self.check_features_used("x: Any = regular", set())
self.check_features_used("x: Any = (regular, regular)", set())
self.check_features_used("x: Any = Complex(Type(1))[something]", set())
self.check_features_used(
"x: Tuple[int, ...] = a, b, c", {Feature.ANN_ASSIGN_EXTENDED_RHS}
)

self.check_features_used("try: pass\nexcept Something: pass", set())
self.check_features_used("try: pass\nexcept (*Something,): pass", set())
self.check_features_used(
"try: pass\nexcept *Group: pass", {Feature.EXCEPT_STAR}
)

self.check_features_used("a[*b]", {Feature.VARIADIC_GENERICS})
self.check_features_used("a[x, *y(), z] = t", {Feature.VARIADIC_GENERICS})
self.check_features_used("def fn(*args: *T): pass", {Feature.VARIADIC_GENERICS})

self.check_features_used("with a: pass", set())
self.check_features_used("with a, b: pass", set())
self.check_features_used("with a as b: pass", set())
self.check_features_used("with a as b, c as d: pass", set())
self.check_features_used("with (a): pass", set())
self.check_features_used("with (a, b): pass", set())
self.check_features_used("with (a, b) as (c, d): pass", set())
self.check_features_used(
"with (a as b): pass", {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
)
self.check_features_used(
"with ((a as b)): pass", {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
)
self.check_features_used(
"with (a, b as c): pass", {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
)
self.check_features_used(
"with (a, (b as c)): pass", {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
)
self.check_features_used(
"with ((a, ((b as c)))): pass", {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
)

def check_features_used(self, source: str, expected: Set[Feature]) -> None:
node = black.lib2to3_parse(source)
actual = black.get_features_used(node)
msg = f"Expected {expected} but got {actual} for {source!r}"
try:
self.assertEqual(actual, expected, msg=msg)
except AssertionError:
DebugVisitor.show(node)
raise

def test_get_features_used_for_future_flags(self) -> None:
for src, features in [
Expand Down

0 comments on commit ebd543c

Please sign in to comment.