diff --git a/crates/ruff_python_formatter/resources/test/fixtures/ruff/statement/return.py b/crates/ruff_python_formatter/resources/test/fixtures/ruff/statement/return.py new file mode 100644 index 0000000000000..96960b4c3cbaf --- /dev/null +++ b/crates/ruff_python_formatter/resources/test/fixtures/ruff/statement/return.py @@ -0,0 +1,16 @@ + +return len(self.nodeseeeeeeeee), sum( + len(node.parents) for node in self.node_map.values() +) + + +return len(self.nodeseeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee), sum( + len(node.parents) for node in self.node_map.values() +) + + +return ( + len(self.nodeseeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee), sum( + len(node.parents) for node in self.node_map.values() + ) +) diff --git a/crates/ruff_python_formatter/src/expression/expr_tuple.rs b/crates/ruff_python_formatter/src/expression/expr_tuple.rs index 4e4284419b229..f62d40f138c90 100644 --- a/crates/ruff_python_formatter/src/expression/expr_tuple.rs +++ b/crates/ruff_python_formatter/src/expression/expr_tuple.rs @@ -31,6 +31,20 @@ pub enum TupleParentheses { /// ``` Preserve, + /// The same as [`Self::Default`] except that it uses [`optional_parentheses`] rather than + /// [`parenthesize_if_expands`]. This avoids adding parentheses if breaking any containing parenthesized + /// expression makes the tuple fit. + /// + /// Avoids adding parentheses around the tuple because breaking the `sum` call expression is sufficient + /// to make it fit. + /// + /// ```python + /// return len(self.nodeseeeeeeeee), sum( + // len(node.parents) for node in self.node_map.values() + // ) + /// ``` + OptionalParentheses, + /// Handle the special cases where we don't include parentheses at all. /// /// Black never formats tuple targets of for loops with parentheses if inside a comprehension. @@ -158,7 +172,7 @@ impl FormatNodeRule for FormatExprTuple { .finish() } TupleParentheses::Preserve => group(&ExprSequence::new(item)).fmt(f), - TupleParentheses::NeverPreserve => { + TupleParentheses::NeverPreserve | TupleParentheses::OptionalParentheses => { optional_parentheses(&ExprSequence::new(item)).fmt(f) } TupleParentheses::Default => { diff --git a/crates/ruff_python_formatter/src/statement/stmt_return.rs b/crates/ruff_python_formatter/src/statement/stmt_return.rs index 61beea83e6678..c6379b5a79da5 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_return.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_return.rs @@ -1,7 +1,8 @@ use ruff_formatter::write; -use ruff_python_ast::StmtReturn; +use ruff_python_ast::{Expr, StmtReturn}; use crate::comments::{SourceComment, SuppressionKind}; +use crate::expression::expr_tuple::TupleParentheses; use crate::expression::maybe_parenthesize_expression; use crate::expression::parentheses::Parenthesize; use crate::prelude::*; @@ -12,17 +13,31 @@ pub struct FormatStmtReturn; impl FormatNodeRule for FormatStmtReturn { fn fmt_fields(&self, item: &StmtReturn, f: &mut PyFormatter) -> FormatResult<()> { let StmtReturn { range: _, value } = item; - if let Some(value) = value { - write!( - f, - [ - text("return"), - space(), - maybe_parenthesize_expression(value, item, Parenthesize::IfBreaks) - ] - ) - } else { - text("return").fmt(f) + + text("return").fmt(f)?; + + match value.as_deref() { + Some(Expr::Tuple(tuple)) if !f.context().comments().has_leading(tuple) => { + write!( + f, + [ + space(), + tuple + .format() + .with_options(TupleParentheses::OptionalParentheses) + ] + ) + } + Some(value) => { + write!( + f, + [ + space(), + maybe_parenthesize_expression(value, item, Parenthesize::IfBreaks) + ] + ) + } + None => Ok(()), } } diff --git a/crates/ruff_python_formatter/tests/snapshots/format@expression__binary_implicit_string.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@expression__binary_implicit_string.py.snap index a87ae3d9d87c0..207b8772bab9a 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@expression__binary_implicit_string.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@expression__binary_implicit_string.py.snap @@ -252,33 +252,23 @@ self.assertEqual( def test(): return ( - ( - "((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -" - " (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))" - ) % {"lhs": lhs_sql, "rhs": rhs_sql}, - tuple(lhs_params) * 2 + tuple(rhs_params) * 2, - ) + "((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -" + " (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))" + ) % {"lhs": lhs_sql, "rhs": rhs_sql}, tuple(lhs_params) * 2 + tuple(rhs_params) * 2 def test2(): - return ( - "RETURNING %s INTO %s" - % ( - ", ".join(field_names), - ", ".join(["%s"] * len(params)), - ), - tuple(params), - ) + return "RETURNING %s INTO %s" % ( + ", ".join(field_names), + ", ".join(["%s"] * len(params)), + ), tuple(params) def test3(): return ( - ( - "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) " - "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)" - ) % (lhs, datatype_values, lhs, lhs), - (tuple(params) + (json_path,)) * 3, - ) + "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) " + "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)" + ) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3 ``` diff --git a/crates/ruff_python_formatter/tests/snapshots/format@statement__return.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@statement__return.py.snap new file mode 100644 index 0000000000000..15f9c6a59c223 --- /dev/null +++ b/crates/ruff_python_formatter/tests/snapshots/format@statement__return.py.snap @@ -0,0 +1,44 @@ +--- +source: crates/ruff_python_formatter/tests/fixtures.rs +input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/statement/return.py +--- +## Input +```py + +return len(self.nodeseeeeeeeee), sum( + len(node.parents) for node in self.node_map.values() +) + + +return len(self.nodeseeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee), sum( + len(node.parents) for node in self.node_map.values() +) + + +return ( + len(self.nodeseeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee), sum( + len(node.parents) for node in self.node_map.values() + ) +) +``` + +## Output +```py +return len(self.nodeseeeeeeeee), sum( + len(node.parents) for node in self.node_map.values() +) + + +return len( + self.nodeseeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee +), sum(len(node.parents) for node in self.node_map.values()) + + +return ( + len(self.nodeseeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee), + sum(len(node.parents) for node in self.node_map.values()), +) +``` + + +