diff --git a/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/compare.py b/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/compare.py index 906d5710aa333..88b6da20bd861 100644 --- a/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/compare.py +++ b/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/compare.py @@ -59,3 +59,55 @@ >= c ) ] + +def f(): + return ( + unicodedata.normalize("NFKC", s1).casefold() + == unicodedata.normalize("NFKC", s2).casefold() + ) + +# Call expressions with trailing attributes. + +ct_match = ( + aaaaaaaaaaact_id == self.get_content_type(obj=rel_obj, using=instance._state.db).id +) + +ct_match = ( + {aaaaaaaaaaaaaaaa} == self.get_content_type(obj=rel_obj, using=instance._state.db).id +) + +ct_match = ( + (aaaaaaaaaaaaaaaa) == self.get_content_type(obj=rel_obj, using=instance._state.db).id +) + +ct_match = aaaaaaaaaaact_id == self.get_content_type( + obj=rel_obj, using=instance._state.db +) + +# Call expressions with trailing subscripts. + +ct_match = ( + aaaaaaaaaaact_id == self.get_content_type(obj=rel_obj, using=instance._state.db)[id] +) + +ct_match = ( + {aaaaaaaaaaaaaaaa} == self.get_content_type(obj=rel_obj, using=instance._state.db)[id] +) + +ct_match = ( + (aaaaaaaaaaaaaaaa) == self.get_content_type(obj=rel_obj, using=instance._state.db)[id] +) + +# Subscripts expressions with trailing attributes. + +ct_match = ( + aaaaaaaaaaact_id == self.get_content_type[obj, rel_obj, using, instance._state.db].id +) + +ct_match = ( + {aaaaaaaaaaaaaaaa} == self.get_content_type[obj, rel_obj, using, instance._state.db].id +) + +ct_match = ( + (aaaaaaaaaaaaaaaa) == self.get_content_type[obj, rel_obj, using, instance._state.db].id +) diff --git a/crates/ruff_python_formatter/src/expression/mod.rs b/crates/ruff_python_formatter/src/expression/mod.rs index f11b2a1f0ddb7..971dc0898ebcc 100644 --- a/crates/ruff_python_formatter/src/expression/mod.rs +++ b/crates/ruff_python_formatter/src/expression/mod.rs @@ -178,10 +178,9 @@ impl Format> for MaybeParenthesizeExpression<'_> { Parenthesize::Optional | Parenthesize::IfBreaks => needs_parentheses, }; - let can_omit_optional_parentheses = can_omit_optional_parentheses(expression, f.context()); match needs_parentheses { OptionalParentheses::Multiline if *parenthesize != Parenthesize::IfRequired => { - if can_omit_optional_parentheses { + if can_omit_optional_parentheses(expression, f.context()) { optional_parentheses(&expression.format().with_options(Parentheses::Never)) .fmt(f) } else { @@ -407,9 +406,12 @@ impl<'input> CanOmitOptionalParenthesesVisitor<'input> { attr: _, ctx: _, }) => { + self.visit_expr(value); if has_parentheses(value, self.source) { self.update_max_priority(OperatorPriority::Attribute); } + self.last = Some(expr); + return; } Expr::NamedExpr(_) diff --git a/crates/ruff_python_formatter/tests/snapshots/format@expression__compare.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@expression__compare.py.snap index f62138f5d9b97..91690dd6de18f 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@expression__compare.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@expression__compare.py.snap @@ -65,6 +65,58 @@ return 1 == 2 and ( >= c ) ] + +def f(): + return ( + unicodedata.normalize("NFKC", s1).casefold() + == unicodedata.normalize("NFKC", s2).casefold() + ) + +# Call expressions with trailing attributes. + +ct_match = ( + aaaaaaaaaaact_id == self.get_content_type(obj=rel_obj, using=instance._state.db).id +) + +ct_match = ( + {aaaaaaaaaaaaaaaa} == self.get_content_type(obj=rel_obj, using=instance._state.db).id +) + +ct_match = ( + (aaaaaaaaaaaaaaaa) == self.get_content_type(obj=rel_obj, using=instance._state.db).id +) + +ct_match = aaaaaaaaaaact_id == self.get_content_type( + obj=rel_obj, using=instance._state.db +) + +# Call expressions with trailing subscripts. + +ct_match = ( + aaaaaaaaaaact_id == self.get_content_type(obj=rel_obj, using=instance._state.db)[id] +) + +ct_match = ( + {aaaaaaaaaaaaaaaa} == self.get_content_type(obj=rel_obj, using=instance._state.db)[id] +) + +ct_match = ( + (aaaaaaaaaaaaaaaa) == self.get_content_type(obj=rel_obj, using=instance._state.db)[id] +) + +# Subscripts expressions with trailing attributes. + +ct_match = ( + aaaaaaaaaaact_id == self.get_content_type[obj, rel_obj, using, instance._state.db].id +) + +ct_match = ( + {aaaaaaaaaaaaaaaa} == self.get_content_type[obj, rel_obj, using, instance._state.db].id +) + +ct_match = ( + (aaaaaaaaaaaaaaaa) == self.get_content_type[obj, rel_obj, using, instance._state.db].id +) ``` ## Output @@ -171,6 +223,60 @@ return 1 == 2 and ( >= c ) ] + + +def f(): + return unicodedata.normalize("NFKC", s1).casefold() == unicodedata.normalize( + "NFKC", s2 + ).casefold() + + +# Call expressions with trailing attributes. + +ct_match = ( + aaaaaaaaaaact_id == self.get_content_type(obj=rel_obj, using=instance._state.db).id +) + +ct_match = {aaaaaaaaaaaaaaaa} == self.get_content_type( + obj=rel_obj, using=instance._state.db +).id + +ct_match = (aaaaaaaaaaaaaaaa) == self.get_content_type( + obj=rel_obj, using=instance._state.db +).id + +ct_match = aaaaaaaaaaact_id == self.get_content_type( + obj=rel_obj, using=instance._state.db +) + +# Call expressions with trailing subscripts. + +ct_match = ( + aaaaaaaaaaact_id == self.get_content_type(obj=rel_obj, using=instance._state.db)[id] +) + +ct_match = { + aaaaaaaaaaaaaaaa +} == self.get_content_type(obj=rel_obj, using=instance._state.db)[id] + +ct_match = ( + aaaaaaaaaaaaaaaa +) == self.get_content_type(obj=rel_obj, using=instance._state.db)[id] + +# Subscripts expressions with trailing attributes. + +ct_match = ( + aaaaaaaaaaact_id + == self.get_content_type[obj, rel_obj, using, instance._state.db].id +) + +ct_match = { + aaaaaaaaaaaaaaaa +} == self.get_content_type[obj, rel_obj, using, instance._state.db].id + +ct_match = ( + aaaaaaaaaaaaaaaa +) == self.get_content_type[obj, rel_obj, using, instance._state.db].id ```