Skip to content

Commit

Permalink
Deduplicate edits when quoting annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Dec 14, 2023
1 parent c014622 commit 6c8d155
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,10 @@ def f():

def func(value: DataFrame):
...


def f():
from pandas import DataFrame, Series

def baz() -> DataFrame | Series:
...
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::borrow::Cow;

use anyhow::Result;
use itertools::Itertools;
use rustc_hash::FxHashMap;

use ruff_diagnostics::{Diagnostic, Fix, FixAvailability, Violation};
Expand Down Expand Up @@ -262,7 +263,7 @@ pub(crate) fn runtime_import_in_type_checking_block(

/// Generate a [`Fix`] to quote runtime usages for imports in a type-checking block.
fn quote_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding]) -> Result<Fix> {
let mut quote_reference_edits = imports
let quote_reference_edits = imports
.iter()
.flat_map(|ImportBinding { binding, .. }| {
binding.references.iter().filter_map(|reference_id| {
Expand All @@ -280,14 +281,12 @@ fn quote_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding])
})
})
.collect::<Result<Vec<_>>>()?;
let quote_reference_edit = quote_reference_edits
.pop()
.expect("Expected at least one reference");
Ok(
Fix::unsafe_edits(quote_reference_edit, quote_reference_edits).isolate(Checker::isolation(
checker.semantic().parent_statement_id(node_id),
)),
)

let mut rest = quote_reference_edits.into_iter().dedup();
let head = rest.next().expect("Expected at least one reference");
Ok(Fix::unsafe_edits(head, rest).isolate(Checker::isolation(
checker.semantic().parent_statement_id(node_id),
)))
}

/// Generate a [`Fix`] to remove runtime imports from a type-checking block.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::borrow::Cow;

use anyhow::Result;
use itertools::Itertools;
use rustc_hash::FxHashMap;

use ruff_diagnostics::{Diagnostic, DiagnosticKind, Fix, FixAvailability, Violation};
Expand Down Expand Up @@ -506,7 +507,7 @@ fn fix_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding]) ->
add_import_edit
.into_edits()
.into_iter()
.chain(quote_reference_edits),
.chain(quote_reference_edits.into_iter().dedup()),
)
.isolate(Checker::isolation(
checker.semantic().parent_statement_id(node_id),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@ quote.py:64:28: TCH004 [*] Quote references to `pandas.DataFrame`. Import is in
66 |- def func(value: DataFrame):
66 |+ def func(value: "DataFrame"):
67 67 | ...
68 68 |
69 69 |


Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,60 @@ quote.py:54:24: TCH002 Move third-party import `pandas.DataFrame` into a type-ch
|
= help: Move into type-checking block

quote.py:71:24: TCH002 [*] Move third-party import `pandas.DataFrame` into a type-checking block
|
70 | def f():
71 | from pandas import DataFrame, Series
| ^^^^^^^^^ TCH002
72 |
73 | def baz() -> DataFrame | Series:
|
= help: Move into type-checking block

Unsafe fix
1 |+from typing import TYPE_CHECKING
2 |+
3 |+if TYPE_CHECKING:
4 |+ from pandas import DataFrame, Series
1 5 | def f():
2 6 | from pandas import DataFrame
3 7 |
--------------------------------------------------------------------------------
68 72 |
69 73 |
70 74 | def f():
71 |- from pandas import DataFrame, Series
72 75 |
73 |- def baz() -> DataFrame | Series:
76 |+ def baz() -> "DataFrame | Series":
74 77 | ...

quote.py:71:35: TCH002 [*] Move third-party import `pandas.Series` into a type-checking block
|
70 | def f():
71 | from pandas import DataFrame, Series
| ^^^^^^ TCH002
72 |
73 | def baz() -> DataFrame | Series:
|
= help: Move into type-checking block

Unsafe fix
1 |+from typing import TYPE_CHECKING
2 |+
3 |+if TYPE_CHECKING:
4 |+ from pandas import DataFrame, Series
1 5 | def f():
2 6 | from pandas import DataFrame
3 7 |
--------------------------------------------------------------------------------
68 72 |
69 73 |
70 74 | def f():
71 |- from pandas import DataFrame, Series
72 75 |
73 |- def baz() -> DataFrame | Series:
76 |+ def baz() -> "DataFrame | Series":
74 77 | ...


0 comments on commit 6c8d155

Please sign in to comment.