Skip to content
This repository has been archived by the owner on Jul 11, 2022. It is now read-only.

Commit

Permalink
Only use trailing commas in function signatures when it's safe
Browse files Browse the repository at this point in the history
Trailing commas after * or ** in a function signature are only safe for Python 3.6
code.  So now Black checks whether the file was already Python 3.6 to begin
with.  If so, trailing commas are used in such cases.  Otherwise, they're not.

When * and ** don't appear in a function signature, the trailing comma is
always safe.

Fixes pytest-dev#8
  • Loading branch information
ambv authored and Lukasz Langa committed Mar 16, 2018
1 parent c26daa4 commit 5fb5cc8
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 19 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,11 @@ You can still try but prepare to be disappointed.

* added `--check`

* only put trailing commas in function signatures and calls if it's
safe to do so. If the file is Python 3.6+ it's always safe, otherwise
only safe if there are no `*args` or `**kwargs` used in the signature
or call. (#8)

* fixed invalid spacing of dots in relative imports (#6, #13)

* fixed invalid splitting after comma on unpacked variables in for-loops
Expand Down
62 changes: 54 additions & 8 deletions black.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
from pathlib import Path
import tokenize
import sys
from typing import (
Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
)
Expand Down Expand Up @@ -192,6 +193,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
comments: List[Line] = []
lines = LineGenerator()
elt = EmptyLineTracker()
py36 = is_python36(src_node)
empty_line = Line()
after = 0
for current_line in lines.visit(src_node):
Expand All @@ -204,7 +206,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
for comment in comments:
dst_contents += str(comment)
comments = []
for line in split_line(current_line, line_length=line_length):
for line in split_line(current_line, line_length=line_length, py36=py36):
dst_contents += str(line)
else:
comments.append(current_line)
Expand Down Expand Up @@ -1108,13 +1110,18 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
yield Leaf(STANDALONE_COMMENT, line)


def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]:
def split_line(
line: Line, line_length: int, inner: bool = False, py36: bool = False
) -> Iterator[Line]:
"""Splits a `line` into potentially many lines.
They should fit in the allotted `line_length` but might not be able to.
`inner` signifies that there were a pair of brackets somewhere around the
current `line`, possibly transitively. This means we can fallback to splitting
by delimiters if the LHS/RHS don't yield any results.
If `py36` is True, splitting may generate syntax that is only compatible
with Python 3.6 and later.
"""
line_str = str(line).strip('\n')
if len(line_str) <= line_length and '\n' not in line_str:
Expand All @@ -1137,11 +1144,13 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li
# split altogether.
result: List[Line] = []
try:
for l in split_func(line):
for l in split_func(line, py36=py36):
if str(l).strip('\n') == line_str:
raise CannotSplit("Split function returned an unchanged result")

result.extend(split_line(l, line_length=line_length, inner=True))
result.extend(
split_line(l, line_length=line_length, inner=True, py36=py36)
)
except CannotSplit as cs:
continue

Expand All @@ -1153,7 +1162,7 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li
yield line


def left_hand_split(line: Line) -> Iterator[Line]:
def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
"""Split line into many lines, starting with the first matching bracket pair.
Note: this usually looks weird, only use this for function definitions.
Expand Down Expand Up @@ -1208,7 +1217,7 @@ def left_hand_split(line: Line) -> Iterator[Line]:
yield result


def right_hand_split(line: Line) -> Iterator[Line]:
def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
"""Split line into many lines, starting with the last matching bracket pair."""
head = Line(depth=line.depth)
body = Line(depth=line.depth + 1, inside_brackets=True)
Expand Down Expand Up @@ -1259,10 +1268,12 @@ def right_hand_split(line: Line) -> Iterator[Line]:
yield result


def delimiter_split(line: Line) -> Iterator[Line]:
def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
"""Split according to delimiters of the highest priority.
This kind of split doesn't increase indentation.
If `py36` is True, the split will add trailing commas also in function
signatures that contain * and **.
"""
try:
last_leaf = line.leaves[-1]
Expand All @@ -1276,11 +1287,20 @@ def delimiter_split(line: Line) -> Iterator[Line]:
raise CannotSplit("No delimiters found")

current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
lowest_depth = sys.maxsize
trailing_comma_safe = True
for leaf in line.leaves:
current_line.append(leaf, preformatted=True)
comment_after = line.comments.get(id(leaf))
if comment_after:
current_line.append(comment_after, preformatted=True)
lowest_depth = min(lowest_depth, leaf.bracket_depth)
if (
leaf.bracket_depth == lowest_depth and # type: ignore
leaf.type == token.STAR or
leaf.type == token.DOUBLESTAR
):
trailing_comma_safe = trailing_comma_safe and py36
leaf_priority = delimiters.get(id(leaf))
if leaf_priority == delimiter_priority:
normalize_prefix(current_line.leaves[0])
Expand All @@ -1290,7 +1310,8 @@ def delimiter_split(line: Line) -> Iterator[Line]:
if current_line:
if (
delimiter_priority == COMMA_PRIORITY and
current_line.leaves[-1].type != token.COMMA
current_line.leaves[-1].type != token.COMMA and
trailing_comma_safe
):
current_line.append(Leaf(token.COMMA, ','))
normalize_prefix(current_line.leaves[0])
Expand Down Expand Up @@ -1325,6 +1346,31 @@ def normalize_prefix(leaf: Leaf) -> None:
leaf.prefix = ''


def is_python36(node: Node) -> bool:
"""Returns True if the current file is using Python 3.6+ features.
Currently looking for:
- f-strings; and
- trailing commas after * or ** in function signatures.
"""
for n in node.pre_order():
if n.type == token.STRING:
assert isinstance(n, Leaf)
if n.value[:2] in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
return True

elif (
n.type == syms.typedargslist and
n.children and
n.children[-1].type == token.COMMA
):
for ch in n.children:
if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
return True

return False


PYTHON_EXTENSIONS = {'.py'}
BLACKLISTED_DIRECTORIES = {
'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
Expand Down
18 changes: 8 additions & 10 deletions tests/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
call(kwarg='hey')
call(arg, kwarg='hey')
call(arg, another, kwarg='hey', **kwargs)
call(this_is_a_very_long_variable_which_will_force_a_delimiter_split, arg, another, kwarg='hey', **kwargs) # note: no trailing comma pre-3.6
lukasz.langa.pl
call.me(maybe)
1 .real
Expand All @@ -88,11 +89,6 @@
slice[1:]
slice[::-1]
(str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None)
f'f-string without formatted values is just a string'
f'{{NOT a formatted value}}'
f'some f-string with {a} {few():.2f} {formatted.values!r}'
f"{f'{nested} inner'} outer"
f'space between opening braces: { {a for a in (1, 2, 3)}}'
{'2.7': dead, '3.7': long_live or die_hard}
{'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'}
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C]
Expand Down Expand Up @@ -200,6 +196,13 @@ async def f():
call(kwarg='hey')
call(arg, kwarg='hey')
call(arg, another, kwarg='hey', **kwargs)
call(
this_is_a_very_long_variable_which_will_force_a_delimiter_split,
arg,
another,
kwarg='hey',
**kwargs
) # note: no trailing comma pre-3.6
lukasz.langa.pl
call.me(maybe)
1 .real
Expand All @@ -217,11 +220,6 @@ async def f():
slice[1:]
slice[::-1]
(str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None)
f'f-string without formatted values is just a string'
f'{{NOT a formatted value}}'
f'some f-string with {a} {few():.2f} {formatted.values!r}'
f"{f'{nested} inner'} outer"
f'space between opening braces: { {a for a in (1, 2, 3)}}'
{'2.7': dead, '3.7': long_live or die_hard}
{'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'}
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C]
Expand Down
5 changes: 5 additions & 0 deletions tests/fstring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
f'f-string without formatted values is just a string'
f'{{NOT a formatted value}}'
f'some f-string with {a} {few():.2f} {formatted.values!r}'
f"{f'{nested} inner'} outer"
f'space between opening braces: { {a for a in (1, 2, 3)}}'
4 changes: 3 additions & 1 deletion tests/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from library import some_connection, \
some_decorator

f'trigger 3.6 mode'
def func_no_args():
a; b; c
if True: raise RuntimeError
Expand Down Expand Up @@ -71,6 +71,8 @@ def long_lines():

from library import some_connection, some_decorator

f'trigger 3.6 mode'


def func_no_args():
a
Expand Down
26 changes: 26 additions & 0 deletions tests/test_black.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ def test_expression(self) -> None:
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)

@patch("black.dump_to_file", dump_to_stderr)
def test_fstring(self) -> None:
source, expected = read_data('fstring')
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)

@patch("black.dump_to_file", dump_to_stderr)
def test_comments(self) -> None:
source, expected = read_data('comments')
Expand Down Expand Up @@ -215,6 +223,24 @@ def err(msg: str, **kwargs):
)
self.assertEqual(report.return_code, 123)

def test_is_python36(self):
node = black.lib2to3_parse("def f(*, arg): ...\n")
self.assertFalse(black.is_python36(node))
node = black.lib2to3_parse("def f(*, arg,): ...\n")
self.assertTrue(black.is_python36(node))
node = black.lib2to3_parse("def f(*, arg): f'string'\n")
self.assertTrue(black.is_python36(node))
source, expected = read_data('function')
node = black.lib2to3_parse(source)
self.assertTrue(black.is_python36(node))
node = black.lib2to3_parse(expected)
self.assertTrue(black.is_python36(node))
source, expected = read_data('expression')
node = black.lib2to3_parse(source)
self.assertFalse(black.is_python36(node))
node = black.lib2to3_parse(expected)
self.assertFalse(black.is_python36(node))


if __name__ == '__main__':
unittest.main()

0 comments on commit 5fb5cc8

Please sign in to comment.