diff --git a/README.md b/README.md index b7762724..be100ee8 100644 --- a/README.md +++ b/README.md @@ -474,6 +474,17 @@ Availability: ``` +### `subprocess.run`: replace `stdout=subprocess.PIPE, stderr=subprocess.PIPE` with `capture_output=True` + +Availability: +- `--py37-plus` is passed on the commandline. + +```diff +-output = subprocess.run(['foo'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) ++output = subprocess.run(['foo'], capture_output=True) +``` + + ### remove parentheses from `@functools.lru_cache()` Availability: diff --git a/pyupgrade/_plugins/open_mode.py b/pyupgrade/_plugins/open_mode.py index a8cad514..75797f09 100644 --- a/pyupgrade/_plugins/open_mode.py +++ b/pyupgrade/_plugins/open_mode.py @@ -14,6 +14,7 @@ from pyupgrade._data import register from pyupgrade._data import State from pyupgrade._data import TokenFunc +from pyupgrade._token_helpers import delete_argument from pyupgrade._token_helpers import find_open_paren from pyupgrade._token_helpers import parse_call_args @@ -35,10 +36,7 @@ def _fix_open_mode(i: int, tokens: List[Token], *, arg_idx: int) -> None: mode_stripped = mode.split('=')[-1] mode_stripped = mode_stripped.strip().strip('"\'') if mode_stripped in U_MODE_REMOVE: - if arg_idx == 0: - del tokens[func_args[arg_idx][0]: func_args[arg_idx + 1][0]] - else: - del tokens[func_args[arg_idx - 1][1]:func_args[arg_idx][1]] + delete_argument(arg_idx, tokens, func_args) elif mode_stripped in U_MODE_REPLACE_R: new_mode = mode.replace('U', 'r') tokens[slice(*func_args[arg_idx])] = [Token('SRC', new_mode)] diff --git a/pyupgrade/_plugins/subprocess_run.py b/pyupgrade/_plugins/subprocess_run.py new file mode 100644 index 00000000..ebbc08a4 --- /dev/null +++ b/pyupgrade/_plugins/subprocess_run.py @@ -0,0 +1,111 @@ +import ast +import functools +from typing import Iterable +from typing import List +from typing import Tuple + +from tokenize_rt import Offset +from tokenize_rt import Token + +from pyupgrade._ast_helpers import ast_to_offset +from pyupgrade._ast_helpers import is_name_attr +from pyupgrade._data import register +from pyupgrade._data import State +from pyupgrade._data import TokenFunc +from pyupgrade._token_helpers import delete_argument +from pyupgrade._token_helpers import find_open_paren +from pyupgrade._token_helpers import parse_call_args +from pyupgrade._token_helpers import replace_argument + + +def _use_capture_output( + i: int, + tokens: List[Token], + *, + stdout_arg_idx: int, + stderr_arg_idx: int, +) -> None: + j = find_open_paren(tokens, i) + func_args, _ = parse_call_args(tokens, j) + if stdout_arg_idx < stderr_arg_idx: + delete_argument(stderr_arg_idx, tokens, func_args) + replace_argument( + stdout_arg_idx, + tokens, + func_args, + new='capture_output=True', + ) + else: + replace_argument( + stdout_arg_idx, + tokens, + func_args, + new='capture_output=True', + ) + delete_argument(stderr_arg_idx, tokens, func_args) + + +def _replace_universal_newlines_with_text( + i: int, + tokens: List[Token], + *, + arg_idx: int, +) -> None: + j = find_open_paren(tokens, i) + func_args, _ = parse_call_args(tokens, j) + for i in range(*func_args[arg_idx]): + if tokens[i].src == 'universal_newlines': + tokens[i] = tokens[i]._replace(src='text') + break + else: + raise AssertionError('`universal_newlines` argument not found') + + +@register(ast.Call) +def visit_Call( + state: State, + node: ast.Call, + parent: ast.AST, +) -> Iterable[Tuple[Offset, TokenFunc]]: + if ( + state.settings.min_version >= (3, 7) and + is_name_attr( + node.func, + state.from_imports, + 'subprocess', + ('run',), + ) + ): + stdout_idx = None + stderr_idx = None + universal_newlines_idx = None + for n, keyword in enumerate(node.keywords): + if keyword.arg == 'stdout' and is_name_attr( + keyword.value, + state.from_imports, + 'subprocess', + ('PIPE',), + ): + stdout_idx = n + elif keyword.arg == 'stderr' and is_name_attr( + keyword.value, + state.from_imports, + 'subprocess', + ('PIPE',), + ): + stderr_idx = n + elif keyword.arg == 'universal_newlines': + universal_newlines_idx = n + if universal_newlines_idx is not None: + func = functools.partial( + _replace_universal_newlines_with_text, + arg_idx=len(node.args) + universal_newlines_idx, + ) + yield ast_to_offset(node), func + if stdout_idx is not None and stderr_idx is not None: + func = functools.partial( + _use_capture_output, + stdout_arg_idx=len(node.args) + stdout_idx, + stderr_arg_idx=len(node.args) + stderr_idx, + ) + yield ast_to_offset(node), func diff --git a/pyupgrade/_plugins/universal_newlines_to_text.py b/pyupgrade/_plugins/universal_newlines_to_text.py deleted file mode 100644 index a2172579..00000000 --- a/pyupgrade/_plugins/universal_newlines_to_text.py +++ /dev/null @@ -1,63 +0,0 @@ -import ast -import functools -from typing import Iterable -from typing import List -from typing import Tuple - -from tokenize_rt import Offset -from tokenize_rt import Token - -from pyupgrade._ast_helpers import ast_to_offset -from pyupgrade._ast_helpers import is_name_attr -from pyupgrade._data import register -from pyupgrade._data import State -from pyupgrade._data import TokenFunc -from pyupgrade._token_helpers import find_open_paren -from pyupgrade._token_helpers import parse_call_args - - -def _replace_universal_newlines_with_text( - i: int, - tokens: List[Token], - *, - arg_idx: int, -) -> None: - j = find_open_paren(tokens, i) - func_args, _ = parse_call_args(tokens, j) - for i in range(*func_args[arg_idx]): - if tokens[i].src == 'universal_newlines': - tokens[i] = tokens[i]._replace(src='text') - break - else: - raise AssertionError('`universal_newlines` argument not found') - - -@register(ast.Call) -def visit_Call( - state: State, - node: ast.Call, - parent: ast.AST, -) -> Iterable[Tuple[Offset, TokenFunc]]: - if ( - state.settings.min_version >= (3, 7) and - is_name_attr( - node.func, - state.from_imports, - 'subprocess', - ('run',), - ) - ): - kwarg_idx = next( - ( - n - for n, keyword in enumerate(node.keywords) - if keyword.arg == 'universal_newlines' - ), - None, - ) - if kwarg_idx is not None: - func = functools.partial( - _replace_universal_newlines_with_text, - arg_idx=len(node.args) + kwarg_idx, - ) - yield ast_to_offset(node), func diff --git a/pyupgrade/_token_helpers.py b/pyupgrade/_token_helpers.py index 0ab815c4..eb94bc56 100644 --- a/pyupgrade/_token_helpers.py +++ b/pyupgrade/_token_helpers.py @@ -434,3 +434,32 @@ def replace_name(i: int, tokens: List[Token], *, name: str, new: str) -> None: return j += 1 tokens[i:j + 1] = [new_token] + + +def delete_argument( + i: int, tokens: List[Token], + func_args: Sequence[Tuple[int, int]], +) -> None: + if i == 0: + # delete leading whitespace before next token + end_idx, _ = func_args[i + 1] + while tokens[end_idx].name == 'UNIMPORTANT_WS': + end_idx += 1 + + del tokens[func_args[i][0]:end_idx] + else: + del tokens[func_args[i - 1][1]:func_args[i][1]] + + +def replace_argument( + i: int, + tokens: List[Token], + func_args: Sequence[Tuple[int, int]], + *, + new: str, +) -> None: + start_idx, end_idx = func_args[i] + # don't replace leading whitespace / newlines + while tokens[start_idx].name in {'UNIMPORTANT_WS', 'NL'}: + start_idx += 1 + tokens[start_idx:end_idx] = [Token('SRC', new)] diff --git a/tests/features/capture_output_test.py b/tests/features/capture_output_test.py new file mode 100644 index 00000000..571c89cd --- /dev/null +++ b/tests/features/capture_output_test.py @@ -0,0 +1,116 @@ +import pytest + +from pyupgrade._data import Settings +from pyupgrade._main import _fix_plugins + + +@pytest.mark.parametrize( + ('s', 'version'), + ( + pytest.param( + 'import subprocess\n' + 'subprocess.run(["foo"], stdout=subprocess.PIPE, ' + 'stderr=subprocess.PIPE)\n', + (3,), + id='not Python3.7+', + ), + pytest.param( + 'from foo import run\n' + 'import subprocess\n' + 'run(["foo"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n', + (3, 7), + id='run imported, but not from subprocess', + ), + pytest.param( + 'from foo import PIPE\n' + 'from subprocess import run\n' + 'subprocess.run(["foo"], stdout=PIPE, stderr=PIPE)\n', + (3, 7), + id='PIPE imported, but not from subprocess', + ), + pytest.param( + 'from subprocess import run\n' + 'run(["foo"], stdout=None, stderr=PIPE)\n', + (3, 7), + id='stdout not subprocess.PIPE', + ), + ), +) +def test_fix_capture_output_noop(s, version): + assert _fix_plugins(s, settings=Settings(min_version=version)) == s + + +@pytest.mark.parametrize( + ('s', 'expected'), + ( + pytest.param( + 'import subprocess\n' + 'subprocess.run(["foo"], stdout=subprocess.PIPE, ' + 'stderr=subprocess.PIPE)\n', + 'import subprocess\n' + 'subprocess.run(["foo"], capture_output=True)\n', + id='subprocess.run and subprocess.PIPE attributes', + ), + pytest.param( + 'from subprocess import run, PIPE\n' + 'run(["foo"], stdout=PIPE, stderr=PIPE)\n', + 'from subprocess import run, PIPE\n' + 'run(["foo"], capture_output=True)\n', + id='run and PIPE imported from subprocess', + ), + pytest.param( + 'from subprocess import run, PIPE\n' + 'run(["foo"], shell=True, stdout=PIPE, stderr=PIPE)\n', + 'from subprocess import run, PIPE\n' + 'run(["foo"], shell=True, capture_output=True)\n', + id='other argument used too', + ), + pytest.param( + 'import subprocess\n' + 'subprocess.run(["foo"], stderr=subprocess.PIPE, ' + 'stdout=subprocess.PIPE)\n', + 'import subprocess\n' + 'subprocess.run(["foo"], capture_output=True)\n', + id='stderr used before stdout', + ), + pytest.param( + 'import subprocess\n' + 'subprocess.run(stderr=subprocess.PIPE, args=["foo"], ' + 'stdout=subprocess.PIPE)\n', + 'import subprocess\n' + 'subprocess.run(args=["foo"], capture_output=True)\n', + id='stdout is first argument', + ), + pytest.param( + 'import subprocess\n' + 'subprocess.run(\n' + ' stderr=subprocess.PIPE, \n' + ' args=["foo"], \n' + ' stdout=subprocess.PIPE,\n' + ')\n', + 'import subprocess\n' + 'subprocess.run(\n' + ' args=["foo"], \n' + ' capture_output=True,\n' + ')\n', + id='stdout is first argument, multiline', + ), + pytest.param( + 'subprocess.run(\n' + ' "foo",\n' + ' stdout=subprocess.PIPE,\n' + ' stderr=subprocess.PIPE,\n' + ' universal_newlines=True,\n' + ')', + 'subprocess.run(\n' + ' "foo",\n' + ' capture_output=True,\n' + ' text=True,\n' + ')', + id='both universal_newlines and capture_output rewrite', + ), + ), +) +def test_fix_capture_output(s, expected): + ret = _fix_plugins(s, settings=Settings(min_version=(3, 7))) + assert ret == expected diff --git a/tests/features/open_mode_test.py b/tests/features/open_mode_test.py index bffc854b..c1ea2c60 100644 --- a/tests/features/open_mode_test.py +++ b/tests/features/open_mode_test.py @@ -60,7 +60,7 @@ def test_fix_open_mode_noop(s): ), ( 'open(mode="r", encoding="UTF-8", file="t.py")', - 'open( encoding="UTF-8", file="t.py")', + 'open(encoding="UTF-8", file="t.py")', ), ), )