From 52630a1d5539dfa78738544657c7bff968842624 Mon Sep 17 00:00:00 2001 From: Dylan <53534755+dylwil3@users.noreply.github.com> Date: Mon, 5 Aug 2024 21:30:58 -0500 Subject: [PATCH] [`flake8-comprehensions`] Set comprehensions not a violation for `sum` in `unnecessary-comprehension-in-call` (`C419`) (#12691) ## Summary Removes set comprehension as a violation for `sum` when checking `C419`, because set comprehension may de-duplicate entries in a generator, thereby modifying the value of the sum. Closes #12690. --- .../fixtures/flake8_comprehensions/C419.py | 14 +++ .../unnecessary_comprehension_in_call.rs | 104 +++++++++++++++--- ...8_comprehensions__tests__C419_C419.py.snap | 58 ++++++++-- ...sions__tests__preview__C419_C419_1.py.snap | 10 +- 4 files changed, 160 insertions(+), 26 deletions(-) diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C419.py b/crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C419.py index b0a15cf2d6aac..311364095af1e 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C419.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C419.py @@ -41,3 +41,17 @@ async def f() -> bool: i.bit_count() for i in range(5) # rbracket comment ] # rpar comment ) + +## Set comprehensions should only be linted +## when function is invariant under duplication of inputs + +# should be linted... +any({x.id for x in bar}) +all({x.id for x in bar}) + +# should be linted in preview... +min({x.id for x in bar}) +max({x.id for x in bar}) + +# should not be linted... +sum({x.id for x in bar}) diff --git a/crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_comprehension_in_call.rs b/crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_comprehension_in_call.rs index 0ce5f88f1a3ca..6897e224f3bbe 100644 --- a/crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_comprehension_in_call.rs +++ b/crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_comprehension_in_call.rs @@ -1,17 +1,18 @@ -use ruff_python_ast::{self as ast, Expr, Keyword}; - use ruff_diagnostics::{Diagnostic, FixAvailability}; use ruff_diagnostics::{Edit, Fix, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::helpers::any_over_expr; +use ruff_python_ast::{self as ast, Expr, Keyword}; use ruff_text_size::{Ranged, TextSize}; use crate::checkers::ast::Checker; - use crate::rules::flake8_comprehensions::fixes; /// ## What it does -/// Checks for unnecessary list comprehensions passed to builtin functions that take an iterable. +/// Checks for unnecessary list or set comprehensions passed to builtin functions that take an iterable. +/// +/// Set comprehensions are only a violation in the case where the builtin function does not care about +/// duplication of elements in the passed iterable. /// /// ## Why is this bad? /// Many builtin functions (this rule currently covers `any` and `all` in stable, along with `min`, @@ -65,18 +66,23 @@ use crate::rules::flake8_comprehensions::fixes; /// /// [preview]: https://docs.astral.sh/ruff/preview/ #[violation] -pub struct UnnecessaryComprehensionInCall; +pub struct UnnecessaryComprehensionInCall { + comprehension_kind: ComprehensionKind, +} impl Violation for UnnecessaryComprehensionInCall { const FIX_AVAILABILITY: FixAvailability = FixAvailability::Sometimes; #[derive_message_formats] fn message(&self) -> String { - format!("Unnecessary list comprehension") + match self.comprehension_kind { + ComprehensionKind::List => format!("Unnecessary list comprehension"), + ComprehensionKind::Set => format!("Unnecessary set comprehension"), + } } fn fix_title(&self) -> Option { - Some("Remove unnecessary list comprehension".to_string()) + Some("Remove unnecessary comprehension".to_string()) } } @@ -102,18 +108,42 @@ pub(crate) fn unnecessary_comprehension_in_call( if contains_await(elt) { return; } - let Some(builtin_function) = checker.semantic().resolve_builtin_symbol(func) else { + let Some(Ok(builtin_function)) = checker + .semantic() + .resolve_builtin_symbol(func) + .map(SupportedBuiltins::try_from) + else { return; }; - if !(matches!(builtin_function, "any" | "all") - || (checker.settings.preview.is_enabled() - && matches!(builtin_function, "sum" | "min" | "max"))) + if !(matches!( + builtin_function, + SupportedBuiltins::Any | SupportedBuiltins::All + ) || (checker.settings.preview.is_enabled() + && matches!( + builtin_function, + SupportedBuiltins::Sum | SupportedBuiltins::Min | SupportedBuiltins::Max + ))) { return; } - let mut diagnostic = Diagnostic::new(UnnecessaryComprehensionInCall, arg.range()); - + let mut diagnostic = match (arg, builtin_function.duplication_variance()) { + (Expr::ListComp(_), _) => Diagnostic::new( + UnnecessaryComprehensionInCall { + comprehension_kind: ComprehensionKind::List, + }, + arg.range(), + ), + (Expr::SetComp(_), DuplicationVariance::Invariant) => Diagnostic::new( + UnnecessaryComprehensionInCall { + comprehension_kind: ComprehensionKind::Set, + }, + arg.range(), + ), + _ => { + return; + } + }; if args.len() == 1 { // If there's only one argument, remove the list or set brackets. diagnostic.try_set_fix(|| { @@ -144,3 +174,51 @@ pub(crate) fn unnecessary_comprehension_in_call( fn contains_await(expr: &Expr) -> bool { any_over_expr(expr, &Expr::is_await_expr) } + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum DuplicationVariance { + Invariant, + Variant, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum ComprehensionKind { + List, + Set, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum SupportedBuiltins { + All, + Any, + Sum, + Min, + Max, +} + +impl TryFrom<&str> for SupportedBuiltins { + type Error = &'static str; + + fn try_from(value: &str) -> Result { + match value { + "all" => Ok(Self::All), + "any" => Ok(Self::Any), + "sum" => Ok(Self::Sum), + "min" => Ok(Self::Min), + "max" => Ok(Self::Max), + _ => Err("Unsupported builtin for `unnecessary-comprehension-in-call`"), + } + } +} + +impl SupportedBuiltins { + fn duplication_variance(self) -> DuplicationVariance { + match self { + SupportedBuiltins::All + | SupportedBuiltins::Any + | SupportedBuiltins::Min + | SupportedBuiltins::Max => DuplicationVariance::Invariant, + SupportedBuiltins::Sum => DuplicationVariance::Variant, + } + } +} diff --git a/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C419_C419.py.snap b/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C419_C419.py.snap index 4f47e3af10fe2..d1b04aebaa8bf 100644 --- a/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C419_C419.py.snap +++ b/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C419_C419.py.snap @@ -8,7 +8,7 @@ C419.py:1:5: C419 [*] Unnecessary list comprehension 2 | all([x.id for x in bar]) 3 | any( # first comment | - = help: Remove unnecessary list comprehension + = help: Remove unnecessary comprehension ℹ Unsafe fix 1 |-any([x.id for x in bar]) @@ -25,7 +25,7 @@ C419.py:2:5: C419 [*] Unnecessary list comprehension 3 | any( # first comment 4 | [x.id for x in bar], # second comment | - = help: Remove unnecessary list comprehension + = help: Remove unnecessary comprehension ℹ Unsafe fix 1 1 | any([x.id for x in bar]) @@ -44,7 +44,7 @@ C419.py:4:5: C419 [*] Unnecessary list comprehension 5 | ) # third comment 6 | all( # first comment | - = help: Remove unnecessary list comprehension + = help: Remove unnecessary comprehension ℹ Unsafe fix 1 1 | any([x.id for x in bar]) @@ -65,7 +65,7 @@ C419.py:7:5: C419 [*] Unnecessary list comprehension 8 | ) # third comment 9 | any({x.id for x in bar}) | - = help: Remove unnecessary list comprehension + = help: Remove unnecessary comprehension ℹ Unsafe fix 4 4 | [x.id for x in bar], # second comment @@ -77,7 +77,7 @@ C419.py:7:5: C419 [*] Unnecessary list comprehension 9 9 | any({x.id for x in bar}) 10 10 | -C419.py:9:5: C419 [*] Unnecessary list comprehension +C419.py:9:5: C419 [*] Unnecessary set comprehension | 7 | [x.id for x in bar], # second comment 8 | ) # third comment @@ -86,7 +86,7 @@ C419.py:9:5: C419 [*] Unnecessary list comprehension 10 | 11 | # OK | - = help: Remove unnecessary list comprehension + = help: Remove unnecessary comprehension ℹ Unsafe fix 6 6 | all( # first comment @@ -113,7 +113,7 @@ C419.py:28:5: C419 [*] Unnecessary list comprehension 34 | # trailing comment 35 | ) | - = help: Remove unnecessary list comprehension + = help: Remove unnecessary comprehension ℹ Unsafe fix 25 25 | @@ -145,7 +145,7 @@ C419.py:39:5: C419 [*] Unnecessary list comprehension | |_____^ C419 43 | ) | - = help: Remove unnecessary list comprehension + = help: Remove unnecessary comprehension ℹ Unsafe fix 36 36 | @@ -160,3 +160,45 @@ C419.py:39:5: C419 [*] Unnecessary list comprehension 41 |+# second line comment 42 |+i.bit_count() for i in range(5) # rbracket comment # rpar comment 43 43 | ) +44 44 | +45 45 | ## Set comprehensions should only be linted + +C419.py:49:5: C419 [*] Unnecessary set comprehension + | +48 | # should be linted... +49 | any({x.id for x in bar}) + | ^^^^^^^^^^^^^^^^^^^ C419 +50 | all({x.id for x in bar}) + | + = help: Remove unnecessary comprehension + +ℹ Unsafe fix +46 46 | ## when function is invariant under duplication of inputs +47 47 | +48 48 | # should be linted... +49 |-any({x.id for x in bar}) + 49 |+any(x.id for x in bar) +50 50 | all({x.id for x in bar}) +51 51 | +52 52 | # should be linted in preview... + +C419.py:50:5: C419 [*] Unnecessary set comprehension + | +48 | # should be linted... +49 | any({x.id for x in bar}) +50 | all({x.id for x in bar}) + | ^^^^^^^^^^^^^^^^^^^ C419 +51 | +52 | # should be linted in preview... + | + = help: Remove unnecessary comprehension + +ℹ Unsafe fix +47 47 | +48 48 | # should be linted... +49 49 | any({x.id for x in bar}) +50 |-all({x.id for x in bar}) + 50 |+all(x.id for x in bar) +51 51 | +52 52 | # should be linted in preview... +53 53 | min({x.id for x in bar}) diff --git a/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__preview__C419_C419_1.py.snap b/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__preview__C419_C419_1.py.snap index 1c30178ac47d2..9bc26685fbd88 100644 --- a/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__preview__C419_C419_1.py.snap +++ b/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__preview__C419_C419_1.py.snap @@ -8,7 +8,7 @@ C419_1.py:1:5: C419 [*] Unnecessary list comprehension 2 | min([x.val for x in bar]) 3 | max([x.val for x in bar]) | - = help: Remove unnecessary list comprehension + = help: Remove unnecessary comprehension ℹ Unsafe fix 1 |-sum([x.val for x in bar]) @@ -25,7 +25,7 @@ C419_1.py:2:5: C419 [*] Unnecessary list comprehension 3 | max([x.val for x in bar]) 4 | sum([x.val for x in bar], 0) | - = help: Remove unnecessary list comprehension + = help: Remove unnecessary comprehension ℹ Unsafe fix 1 1 | sum([x.val for x in bar]) @@ -43,7 +43,7 @@ C419_1.py:3:5: C419 [*] Unnecessary list comprehension | ^^^^^^^^^^^^^^^^^^^^ C419 4 | sum([x.val for x in bar], 0) | - = help: Remove unnecessary list comprehension + = help: Remove unnecessary comprehension ℹ Unsafe fix 1 1 | sum([x.val for x in bar]) @@ -63,7 +63,7 @@ C419_1.py:4:5: C419 [*] Unnecessary list comprehension 5 | 6 | # OK | - = help: Remove unnecessary list comprehension + = help: Remove unnecessary comprehension ℹ Unsafe fix 1 1 | sum([x.val for x in bar]) @@ -89,7 +89,7 @@ C419_1.py:14:5: C419 [*] Unnecessary list comprehension 19 | dt.timedelta(), 20 | ) | - = help: Remove unnecessary list comprehension + = help: Remove unnecessary comprehension ℹ Unsafe fix 11 11 |