diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_pytest_style/PT027_1.py b/crates/ruff_linter/resources/test/fixtures/flake8_pytest_style/PT027_1.py index 708a582ad3d220..a591c19c971082 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_pytest_style/PT027_1.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_pytest_style/PT027_1.py @@ -10,3 +10,32 @@ def test_pytest_raises(self): def test_errors(self): with self.assertRaises(ValueError): raise ValueError + + def test_rewrite_references(self): + with self.assertRaises(ValueError) as e: + raise ValueError + + print(e.foo) + print(e.exception) + + def test_rewrite_references_multiple_items(self): + with self.assertRaises(ValueError) as e1, \ + self.assertRaises(ValueError) as e2: + raise ValueError + + print(e1.foo) + print(e1.exception) + + print(e2.foo) + print(e2.exception) + + def test_rewrite_references_multiple_items_nested(self): + with self.assertRaises(ValueError) as e1, \ + foo(self.assertRaises(ValueError)) as e2: + raise ValueError + + print(e1.foo) + print(e1.exception) + + print(e2.foo) + print(e2.exception) diff --git a/crates/ruff_linter/src/checkers/ast/analyze/bindings.rs b/crates/ruff_linter/src/checkers/ast/analyze/bindings.rs index 7873cadf5eb010..31097205abd4b8 100644 --- a/crates/ruff_linter/src/checkers/ast/analyze/bindings.rs +++ b/crates/ruff_linter/src/checkers/ast/analyze/bindings.rs @@ -4,7 +4,8 @@ use ruff_text_size::Ranged; use crate::checkers::ast::Checker; use crate::codes::Rule; use crate::rules::{ - flake8_import_conventions, flake8_pyi, flake8_type_checking, pyflakes, pylint, ruff, + flake8_import_conventions, flake8_pyi, flake8_pytest_style, flake8_type_checking, pyflakes, + pylint, ruff, }; /// Run lint rules over the [`Binding`]s. @@ -20,6 +21,7 @@ pub(crate) fn bindings(checker: &mut Checker) { Rule::UnusedVariable, Rule::UnquotedTypeAlias, Rule::UsedDummyVariable, + Rule::PytestUnittestRaisesAssertion, ]) { return; } @@ -100,5 +102,12 @@ pub(crate) fn bindings(checker: &mut Checker) { checker.diagnostics.push(diagnostic); } } + if checker.enabled(Rule::PytestUnittestRaisesAssertion) { + if let Some(diagnostic) = + flake8_pytest_style::rules::unittest_raises_assertion_binding(checker, binding) + { + checker.diagnostics.push(diagnostic); + } + } } } diff --git a/crates/ruff_linter/src/checkers/ast/analyze/expression.rs b/crates/ruff_linter/src/checkers/ast/analyze/expression.rs index d6321665b85dd1..ce5425a2ab95a2 100644 --- a/crates/ruff_linter/src/checkers/ast/analyze/expression.rs +++ b/crates/ruff_linter/src/checkers/ast/analyze/expression.rs @@ -940,18 +940,10 @@ pub(crate) fn expression(expr: &Expr, checker: &mut Checker) { flake8_pytest_style::rules::parametrize(checker, call); } if checker.enabled(Rule::PytestUnittestAssertion) { - if let Some(diagnostic) = flake8_pytest_style::rules::unittest_assertion( - checker, expr, func, args, keywords, - ) { - checker.diagnostics.push(diagnostic); - } + flake8_pytest_style::rules::unittest_assertion(checker, expr, func, args, keywords); } if checker.enabled(Rule::PytestUnittestRaisesAssertion) { - if let Some(diagnostic) = - flake8_pytest_style::rules::unittest_raises_assertion(checker, call) - { - checker.diagnostics.push(diagnostic); - } + flake8_pytest_style::rules::unittest_raises_assertion_call(checker, call); } if checker.enabled(Rule::SubprocessPopenPreexecFn) { pylint::rules::subprocess_popen_preexec_fn(checker, call); diff --git a/crates/ruff_linter/src/rules/flake8_pytest_style/rules/assertion.rs b/crates/ruff_linter/src/rules/flake8_pytest_style/rules/assertion.rs index ba7874c1e0c188..fc078f516e1cc6 100644 --- a/crates/ruff_linter/src/rules/flake8_pytest_style/rules/assertion.rs +++ b/crates/ruff_linter/src/rules/flake8_pytest_style/rules/assertion.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::iter; use anyhow::Result; use anyhow::{bail, Context}; @@ -13,10 +14,11 @@ use ruff_python_ast::helpers::Truthiness; use ruff_python_ast::parenthesize::parenthesized_range; use ruff_python_ast::visitor::Visitor; use ruff_python_ast::{ - self as ast, Arguments, BoolOp, ExceptHandler, Expr, Keyword, Stmt, UnaryOp, + self as ast, AnyNodeRef, Arguments, BoolOp, ExceptHandler, Expr, Keyword, Stmt, UnaryOp, }; use ruff_python_ast::{visitor, whitespace}; use ruff_python_codegen::Stylist; +use ruff_python_semantic::{Binding, BindingKind}; use ruff_source_file::LineRanges; use ruff_text_size::Ranged; @@ -266,47 +268,48 @@ fn check_assert_in_except(name: &str, body: &[Stmt]) -> Vec { /// PT009 pub(crate) fn unittest_assertion( - checker: &Checker, + checker: &mut Checker, expr: &Expr, func: &Expr, args: &[Expr], keywords: &[Keyword], -) -> Option { - match func { - Expr::Attribute(ast::ExprAttribute { attr, .. }) => { - if let Ok(unittest_assert) = UnittestAssert::try_from(attr.as_str()) { - let mut diagnostic = Diagnostic::new( - PytestUnittestAssertion { - assertion: unittest_assert.to_string(), - }, - func.range(), - ); - // We're converting an expression to a statement, so avoid applying the fix if - // the assertion is part of a larger expression. - if checker.semantic().current_statement().is_expr_stmt() - && checker.semantic().current_expression_parent().is_none() - && !checker.comment_ranges().intersects(expr.range()) - { - if let Ok(stmt) = unittest_assert.generate_assert(args, keywords) { - diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement( - checker.generator().stmt(&stmt), - parenthesized_range( - expr.into(), - checker.semantic().current_statement().into(), - checker.comment_ranges(), - checker.locator().contents(), - ) - .unwrap_or(expr.range()), - ))); - } - } - Some(diagnostic) - } else { - None - } +) { + let Expr::Attribute(ast::ExprAttribute { attr, .. }) = func else { + return; + }; + + let Ok(unittest_assert) = UnittestAssert::try_from(attr.as_str()) else { + return; + }; + + let mut diagnostic = Diagnostic::new( + PytestUnittestAssertion { + assertion: unittest_assert.to_string(), + }, + func.range(), + ); + + // We're converting an expression to a statement, so avoid applying the fix if + // the assertion is part of a larger expression. + if checker.semantic().current_statement().is_expr_stmt() + && checker.semantic().current_expression_parent().is_none() + && !checker.comment_ranges().intersects(expr.range()) + { + if let Ok(stmt) = unittest_assert.generate_assert(args, keywords) { + diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement( + checker.generator().stmt(&stmt), + parenthesized_range( + expr.into(), + checker.semantic().current_statement().into(), + checker.comment_ranges(), + checker.locator().contents(), + ) + .unwrap_or(expr.range()), + ))); } - _ => None, } + + checker.diagnostics.push(diagnostic); } /// ## What it does @@ -364,9 +367,96 @@ impl Violation for PytestUnittestRaisesAssertion { } /// PT027 -pub(crate) fn unittest_raises_assertion( +pub(crate) fn unittest_raises_assertion_call(checker: &mut Checker, call: &ast::ExprCall) { + // Bindings in `with` statements are handled by `unittest_raises_assertion_bindings`. + if let Stmt::With(ast::StmtWith { items, .. }) = checker.semantic().current_statement() { + let call_ref = AnyNodeRef::from(call); + + if items.iter().any(|item| { + AnyNodeRef::from(&item.context_expr).ptr_eq(call_ref) && item.optional_vars.is_some() + }) { + return; + } + } + + if let Some(diagnostic) = unittest_raises_assertion(call, vec![], checker) { + checker.diagnostics.push(diagnostic); + } +} + +/// PT027 +pub(crate) fn unittest_raises_assertion_binding( checker: &Checker, + binding: &Binding, +) -> Option { + if !matches!(binding.kind, BindingKind::WithItemVar) { + return None; + } + + let semantic = checker.semantic(); + + let Stmt::With(with) = binding.statement(semantic)? else { + return None; + }; + + let Expr::Call(call) = corresponding_context_expr(binding, with)? else { + return None; + }; + + let mut edits = vec![]; + + // Rewrite all references to `.exception` to `.value`: + // ```py + // # Before + // with self.assertRaises(Exception) as e: + // ... + // print(e.exception) + // + // # After + // with pytest.raises(Exception) as e: + // ... + // print(e.value) + // ``` + for reference_id in binding.references() { + let reference = semantic.reference(reference_id); + let node_id = reference.expression_id()?; + + let mut ancestors = semantic.expressions(node_id).skip(1); + + let Expr::Attribute(ast::ExprAttribute { attr, .. }) = ancestors.next()? else { + continue; + }; + + if attr.as_str() == "exception" { + edits.push(Edit::range_replacement("value".to_string(), attr.range)); + } + } + + unittest_raises_assertion(call, edits, checker) +} + +fn corresponding_context_expr<'a>(binding: &Binding, with: &'a ast::StmtWith) -> Option<&'a Expr> { + with.items.iter().find_map(|item| { + let Some(optional_var) = &item.optional_vars else { + return None; + }; + + let Expr::Name(name) = optional_var.as_ref() else { + return None; + }; + + if name.range == binding.range { + Some(&item.context_expr) + } else { + None + } + }) +} + +fn unittest_raises_assertion( call: &ast::ExprCall, + extra_edits: Vec, + checker: &Checker, ) -> Option { let Expr::Attribute(ast::ExprAttribute { attr, .. }) = call.func.as_ref() else { return None; @@ -385,19 +475,25 @@ pub(crate) fn unittest_raises_assertion( }, call.func.range(), ); + if !checker .comment_ranges() .has_comments(call, checker.source()) { if let Some(args) = to_pytest_raises_args(checker, attr.as_str(), &call.arguments) { diagnostic.try_set_fix(|| { - let (import_edit, binding) = checker.importer().get_or_import_symbol( + let (import_pytest_raises, binding) = checker.importer().get_or_import_symbol( &ImportRequest::import("pytest", "raises"), call.func.start(), checker.semantic(), )?; - let edit = Edit::range_replacement(format!("{binding}({args})"), call.range()); - Ok(Fix::unsafe_edits(import_edit, [edit])) + let replace_call = + Edit::range_replacement(format!("{binding}({args})"), call.range()); + + Ok(Fix::unsafe_edits( + import_pytest_raises, + iter::once(replace_call).chain(extra_edits), + )) }); } } diff --git a/crates/ruff_linter/src/rules/flake8_pytest_style/snapshots/ruff_linter__rules__flake8_pytest_style__tests__PT027_1.snap b/crates/ruff_linter/src/rules/flake8_pytest_style/snapshots/ruff_linter__rules__flake8_pytest_style__tests__PT027_1.snap index 6e4b4b463e52aa..3b79a1afa313e7 100644 --- a/crates/ruff_linter/src/rules/flake8_pytest_style/snapshots/ruff_linter__rules__flake8_pytest_style__tests__PT027_1.snap +++ b/crates/ruff_linter/src/rules/flake8_pytest_style/snapshots/ruff_linter__rules__flake8_pytest_style__tests__PT027_1.snap @@ -1,6 +1,5 @@ --- source: crates/ruff_linter/src/rules/flake8_pytest_style/mod.rs -snapshot_kind: text --- PT027_1.py:11:14: PT027 [*] Use `pytest.raises` instead of unittest-style `assertRaises` | @@ -18,3 +17,129 @@ PT027_1.py:11:14: PT027 [*] Use `pytest.raises` instead of unittest-style `asser 11 |- with self.assertRaises(ValueError): 11 |+ with pytest.raises(ValueError): 12 12 | raise ValueError +13 13 | +14 14 | def test_rewrite_references(self): + +PT027_1.py:15:14: PT027 [*] Use `pytest.raises` instead of unittest-style `assertRaises` + | +14 | def test_rewrite_references(self): +15 | with self.assertRaises(ValueError) as e: + | ^^^^^^^^^^^^^^^^^ PT027 +16 | raise ValueError + | + = help: Replace `assertRaises` with `pytest.raises` + +ℹ Unsafe fix +12 12 | raise ValueError +13 13 | +14 14 | def test_rewrite_references(self): +15 |- with self.assertRaises(ValueError) as e: + 15 |+ with pytest.raises(ValueError) as e: +16 16 | raise ValueError +17 17 | +18 18 | print(e.foo) +19 |- print(e.exception) + 19 |+ print(e.value) +20 20 | +21 21 | def test_rewrite_references_multiple_items(self): +22 22 | with self.assertRaises(ValueError) as e1, \ + +PT027_1.py:22:14: PT027 [*] Use `pytest.raises` instead of unittest-style `assertRaises` + | +21 | def test_rewrite_references_multiple_items(self): +22 | with self.assertRaises(ValueError) as e1, \ + | ^^^^^^^^^^^^^^^^^ PT027 +23 | self.assertRaises(ValueError) as e2: +24 | raise ValueError + | + = help: Replace `assertRaises` with `pytest.raises` + +ℹ Unsafe fix +19 19 | print(e.exception) +20 20 | +21 21 | def test_rewrite_references_multiple_items(self): +22 |- with self.assertRaises(ValueError) as e1, \ + 22 |+ with pytest.raises(ValueError) as e1, \ +23 23 | self.assertRaises(ValueError) as e2: +24 24 | raise ValueError +25 25 | +26 26 | print(e1.foo) +27 |- print(e1.exception) + 27 |+ print(e1.value) +28 28 | +29 29 | print(e2.foo) +30 30 | print(e2.exception) + +PT027_1.py:23:13: PT027 [*] Use `pytest.raises` instead of unittest-style `assertRaises` + | +21 | def test_rewrite_references_multiple_items(self): +22 | with self.assertRaises(ValueError) as e1, \ +23 | self.assertRaises(ValueError) as e2: + | ^^^^^^^^^^^^^^^^^ PT027 +24 | raise ValueError + | + = help: Replace `assertRaises` with `pytest.raises` + +ℹ Unsafe fix +20 20 | +21 21 | def test_rewrite_references_multiple_items(self): +22 22 | with self.assertRaises(ValueError) as e1, \ +23 |- self.assertRaises(ValueError) as e2: + 23 |+ pytest.raises(ValueError) as e2: +24 24 | raise ValueError +25 25 | +26 26 | print(e1.foo) +27 27 | print(e1.exception) +28 28 | +29 29 | print(e2.foo) +30 |- print(e2.exception) + 30 |+ print(e2.value) +31 31 | +32 32 | def test_rewrite_references_multiple_items_nested(self): +33 33 | with self.assertRaises(ValueError) as e1, \ + +PT027_1.py:33:14: PT027 [*] Use `pytest.raises` instead of unittest-style `assertRaises` + | +32 | def test_rewrite_references_multiple_items_nested(self): +33 | with self.assertRaises(ValueError) as e1, \ + | ^^^^^^^^^^^^^^^^^ PT027 +34 | foo(self.assertRaises(ValueError)) as e2: +35 | raise ValueError + | + = help: Replace `assertRaises` with `pytest.raises` + +ℹ Unsafe fix +30 30 | print(e2.exception) +31 31 | +32 32 | def test_rewrite_references_multiple_items_nested(self): +33 |- with self.assertRaises(ValueError) as e1, \ + 33 |+ with pytest.raises(ValueError) as e1, \ +34 34 | foo(self.assertRaises(ValueError)) as e2: +35 35 | raise ValueError +36 36 | +37 37 | print(e1.foo) +38 |- print(e1.exception) + 38 |+ print(e1.value) +39 39 | +40 40 | print(e2.foo) +41 41 | print(e2.exception) + +PT027_1.py:34:17: PT027 [*] Use `pytest.raises` instead of unittest-style `assertRaises` + | +32 | def test_rewrite_references_multiple_items_nested(self): +33 | with self.assertRaises(ValueError) as e1, \ +34 | foo(self.assertRaises(ValueError)) as e2: + | ^^^^^^^^^^^^^^^^^ PT027 +35 | raise ValueError + | + = help: Replace `assertRaises` with `pytest.raises` + +ℹ Unsafe fix +31 31 | +32 32 | def test_rewrite_references_multiple_items_nested(self): +33 33 | with self.assertRaises(ValueError) as e1, \ +34 |- foo(self.assertRaises(ValueError)) as e2: + 34 |+ foo(pytest.raises(ValueError)) as e2: +35 35 | raise ValueError +36 36 | +37 37 | print(e1.foo)