Skip to content

Commit

Permalink
[flake8-pyi] Implement autofix for redundant-numeric-union (`PYI0…
Browse files Browse the repository at this point in the history
…41`) (astral-sh#14273)

## Summary

This PR adds autofix for `redundant-numeric-union` (`PYI041`)

There are some comments below to explain the reasoning behind some
choices that might help review.

<!-- What's the purpose of the change? What does it do, and why? -->

Resolves part of astral-sh#14185.

## Test Plan

<!-- How was it tested? -->

---------

Co-authored-by: Micha Reiser <micha@reiser.io>
Co-authored-by: Charlie Marsh <charlie.r.marsh@gmail.com>
  • Loading branch information
3 people authored and dylwil3 committed Nov 17, 2024
1 parent 4b7fe4b commit 5c797e6
Show file tree
Hide file tree
Showing 7 changed files with 424 additions and 72 deletions.
29 changes: 29 additions & 0 deletions crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI041.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,38 @@ async def f4(**kwargs: int | int | float) -> None:
...


def f5(
arg: Union[ # comment
float, # another
complex, int]
) -> None:
...

def f6(
arg: (
int | # comment
float | # another
complex
)
) -> None:
...


class Foo:
def good(self, arg: int) -> None:
...

def bad(self, arg: int | float | complex) -> None:
...

def bad2(self, arg: int | Union[float, complex]) -> None:
...

def bad3(self, arg: Union[Union[float, complex], int]) -> None:
...

def bad4(self, arg: Union[float | complex, int]) -> None:
...

def bad5(self, arg: int | (float | complex)) -> None:
...
21 changes: 21 additions & 0 deletions crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI041.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,29 @@ def f3(arg1: int, *args: Union[int | int | float]) -> None: ... # PYI041

async def f4(**kwargs: int | int | float) -> None: ... # PYI041

def f5(
arg: Union[ # comment
float, # another
complex, int]
) -> None: ... # PYI041

def f6(
arg: (
int | # comment
float | # another
complex
)
) -> None: ... # PYI041

class Foo:
def good(self, arg: int) -> None: ...

def bad(self, arg: int | float | complex) -> None: ... # PYI041

def bad2(self, arg: int | Union[float, complex]) -> None: ... # PYI041

def bad3(self, arg: Union[Union[float, complex], int]) -> None: ... # PYI041

def bad4(self, arg: Union[float | complex, int]) -> None: ... # PYI041

def bad5(self, arg: int | (float | complex)) -> None: ... # PYI041
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use crate::checkers::ast::Checker;
/// ## Fix safety
/// This rule's fix is marked as safe, unless the type annotation contains comments.
///
/// Note that the fix will flatten nested literals into a single top-level
/// literal.
/// Note that while the fix may flatten nested literals into a single top-level literal,
/// the semantics of the annotation will remain unchanged.
///
/// ## References
/// - [Python documentation: `typing.Literal`](https://docs.python.org/3/library/typing.html#typing.Literal)
Expand Down
260 changes: 217 additions & 43 deletions crates/ruff_linter/src/rules/flake8_pyi/rules/redundant_numeric_union.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
use ruff_diagnostics::{Diagnostic, Violation};
use bitflags::bitflags;

use anyhow::Result;

use ruff_diagnostics::{Applicability, Diagnostic, Edit, Fix, FixAvailability, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::{AnyParameterRef, Expr, Parameters};
use ruff_python_ast::{
name::Name, AnyParameterRef, Expr, ExprBinOp, ExprContext, ExprName, ExprSubscript, ExprTuple,
Operator, Parameters,
};
use ruff_python_semantic::analyze::typing::traverse_union;
use ruff_text_size::Ranged;
use ruff_text_size::{Ranged, TextRange};

use crate::checkers::ast::Checker;
use crate::{checkers::ast::Checker, importer::ImportRequest};

/// ## What it does
/// Checks for parameter annotations that contain redundant unions between
Expand Down Expand Up @@ -37,6 +44,12 @@ use crate::checkers::ast::Checker;
/// def foo(x: float | str) -> None: ...
/// ```
///
/// ## Fix safety
/// This rule's fix is marked as safe, unless the type annotation contains comments.
///
/// Note that while the fix may flatten nested unions into a single top-level union,
/// the semantics of the annotation will remain unchanged.
///
/// ## References
/// - [Python documentation: The numeric tower](https://docs.python.org/3/library/numbers.html#the-numeric-tower)
/// - [PEP 484: The numeric tower](https://peps.python.org/pep-0484/#the-numeric-tower)
Expand All @@ -48,15 +61,23 @@ pub struct RedundantNumericUnion {
}

impl Violation for RedundantNumericUnion {
// Always fixable, but currently under preview.
const FIX_AVAILABILITY: FixAvailability = FixAvailability::Sometimes;

#[derive_message_formats]
fn message(&self) -> String {
let (subtype, supertype) = match self.redundancy {
Redundancy::IntFloatComplex => ("int | float", "complex"),
Redundancy::FloatComplex => ("float", "complex"),
Redundancy::IntComplex => ("int", "complex"),
Redundancy::IntFloat => ("int", "float"),
};
format!("Use `{supertype}` instead of `{subtype} | {supertype}`")
}

fn fix_title(&self) -> Option<String> {
Some("Remove redundant type".to_string())
}
}

/// PYI041
Expand All @@ -66,57 +87,210 @@ pub(crate) fn redundant_numeric_union(checker: &mut Checker, parameters: &Parame
}
}

#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Redundancy {
FloatComplex,
IntComplex,
IntFloat,
}

fn check_annotation(checker: &mut Checker, annotation: &Expr) {
let mut has_float = false;
let mut has_complex = false;
let mut has_int = false;
fn check_annotation<'a>(checker: &mut Checker, annotation: &'a Expr) {
let mut numeric_flags = NumericFlags::empty();

let mut find_numeric_type = |expr: &Expr, _parent: &Expr| {
let Some(builtin_type) = checker.semantic().resolve_builtin_symbol(expr) else {
return;
};

match builtin_type {
"int" => has_int = true,
"float" => has_float = true,
"complex" => has_complex = true,
_ => {}
}
numeric_flags.seen_builtin_type(builtin_type);
};

// Traverse the union, and remember which numeric types are found.
traverse_union(&mut find_numeric_type, checker.semantic(), annotation);

if has_complex {
if has_float {
checker.diagnostics.push(Diagnostic::new(
RedundantNumericUnion {
redundancy: Redundancy::FloatComplex,
},
annotation.range(),
));
let Some(redundancy) = Redundancy::from_numeric_flags(numeric_flags) else {
return;
};

// Traverse the union a second time to construct the fix.
let mut necessary_nodes: Vec<&Expr> = Vec::new();

let mut union_type = UnionKind::TypingUnion;
let mut remove_numeric_type = |expr: &'a Expr, parent: &'a Expr| {
let Some(builtin_type) = checker.semantic().resolve_builtin_symbol(expr) else {
// Keep type annotations that are not numeric.
necessary_nodes.push(expr);
return;
};

if matches!(parent, Expr::BinOp(_)) {
union_type = UnionKind::PEP604;
}

if has_int {
checker.diagnostics.push(Diagnostic::new(
RedundantNumericUnion {
redundancy: Redundancy::IntComplex,
},
annotation.range(),
));
// `int` is always dropped, since `float` or `complex` must be present.
// `float` is only dropped if `complex`` is present.
if (builtin_type == "float" && !numeric_flags.contains(NumericFlags::COMPLEX))
|| (builtin_type != "float" && builtin_type != "int")
{
necessary_nodes.push(expr);
}
};

// Traverse the union a second time to construct a [`Fix`].
traverse_union(&mut remove_numeric_type, checker.semantic(), annotation);

let mut diagnostic = Diagnostic::new(RedundantNumericUnion { redundancy }, annotation.range());
if checker.settings.preview.is_enabled() {
// Mark [`Fix`] as unsafe when comments are in range.
let applicability = if checker.comment_ranges().intersects(annotation.range()) {
Applicability::Unsafe
} else {
Applicability::Safe
};

// Generate the flattened fix once.
let fix = if let &[edit_expr] = necessary_nodes.as_slice() {
// Generate a [`Fix`] for a single type expression, e.g. `int`.
Fix::applicable_edit(
Edit::range_replacement(checker.generator().expr(edit_expr), annotation.range()),
applicability,
)
} else {
match union_type {
UnionKind::PEP604 => {
generate_pep604_fix(checker, necessary_nodes, annotation, applicability)
}
UnionKind::TypingUnion => {
generate_union_fix(checker, necessary_nodes, annotation, applicability)
.ok()
.unwrap()
}
}
};
diagnostic.set_fix(fix);
};

checker.diagnostics.push(diagnostic);
}

#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Redundancy {
IntFloatComplex,
FloatComplex,
IntComplex,
IntFloat,
}

impl Redundancy {
pub(super) fn from_numeric_flags(numeric_flags: NumericFlags) -> Option<Self> {
if numeric_flags == NumericFlags::INT | NumericFlags::FLOAT | NumericFlags::COMPLEX {
Some(Self::IntFloatComplex)
} else if numeric_flags == NumericFlags::FLOAT | NumericFlags::COMPLEX {
Some(Self::FloatComplex)
} else if numeric_flags == NumericFlags::INT | NumericFlags::COMPLEX {
Some(Self::IntComplex)
} else if numeric_flags == NumericFlags::FLOAT | NumericFlags::INT {
Some(Self::IntFloat)
} else {
None
}
} else if has_float && has_int {
checker.diagnostics.push(Diagnostic::new(
RedundantNumericUnion {
redundancy: Redundancy::IntFloat,
},
annotation.range(),
));
}
}

bitflags! {
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) struct NumericFlags: u8 {
/// `int`
const INT = 1 << 0;
/// `float`
const FLOAT = 1 << 1;
/// `complex`
const COMPLEX = 1 << 2;
}
}

impl NumericFlags {
pub(super) fn seen_builtin_type(&mut self, name: &str) {
let flag: NumericFlags = match name {
"int" => NumericFlags::INT,
"float" => NumericFlags::FLOAT,
"complex" => NumericFlags::COMPLEX,
_ => {
return;
}
};
self.insert(flag);
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UnionKind {
/// E.g., `typing.Union[int, str]`
TypingUnion,
/// E.g., `int | str`
PEP604,
}

// Generate a [`Fix`] for two or more type expressions, e.g. `int | float | complex`.
fn generate_pep604_fix(
checker: &Checker,
nodes: Vec<&Expr>,
annotation: &Expr,
applicability: Applicability,
) -> Fix {
debug_assert!(nodes.len() >= 2, "At least two nodes required");

let new_expr = nodes
.into_iter()
.fold(None, |acc: Option<Expr>, right: &Expr| {
if let Some(left) = acc {
Some(Expr::BinOp(ExprBinOp {
left: Box::new(left),
op: Operator::BitOr,
right: Box::new(right.clone()),
range: TextRange::default(),
}))
} else {
Some(right.clone())
}
})
.unwrap();

Fix::applicable_edit(
Edit::range_replacement(checker.generator().expr(&new_expr), annotation.range()),
applicability,
)
}

// Generate a [`Fix`] for two or more type expresisons, e.g. `typing.Union[int, float, complex]`.
fn generate_union_fix(
checker: &Checker,
nodes: Vec<&Expr>,
annotation: &Expr,
applicability: Applicability,
) -> Result<Fix> {
debug_assert!(nodes.len() >= 2, "At least two nodes required");

// Request `typing.Union`
let (import_edit, binding) = checker.importer().get_or_import_symbol(
&ImportRequest::import_from("typing", "Union"),
annotation.start(),
checker.semantic(),
)?;

// Construct the expression as `Subscript[typing.Union, Tuple[expr, [expr, ...]]]`
let new_expr = Expr::Subscript(ExprSubscript {
range: TextRange::default(),
value: Box::new(Expr::Name(ExprName {
id: Name::new(binding),
ctx: ExprContext::Store,
range: TextRange::default(),
})),
slice: Box::new(Expr::Tuple(ExprTuple {
elts: nodes.into_iter().cloned().collect(),
range: TextRange::default(),
ctx: ExprContext::Load,
parenthesized: false,
})),
ctx: ExprContext::Load,
});

Ok(Fix::applicable_edits(
Edit::range_replacement(checker.generator().expr(&new_expr), annotation.range()),
[import_edit],
applicability,
))
}
Loading

0 comments on commit 5c797e6

Please sign in to comment.