Skip to content

Commit

Permalink
Re-use typing_extensions.TYPE_CHECKING if available
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Nov 9, 2023
1 parent 4419953 commit bbc3d3b
Show file tree
Hide file tree
Showing 17 changed files with 198 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,14 @@ def f():
from module import Member

x: Member = 1


def f():
from typing_extensions import TYPE_CHECKING

from pandas import y

if TYPE_CHECKING:
_type = x
elif True:
_type = y
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from typing_extensions import TYPE_CHECKING

if TYPE_CHECKING:
from pandas import DataFrame


def example() -> DataFrame:
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from typing_extensions import TYPE_CHECKING

if TYPE_CHECKING:
from pandas import DataFrame


def example() -> DataFrame:
x = DataFrame()
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,9 @@ class Test:

if 0:
x: List


from typing_extensions import TYPE_CHECKING

if TYPE_CHECKING:
pass # TCH005
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from typing_extensions import Self


def func():
from pandas import DataFrame

df: DataFrame
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

import typing_extensions


def func():
from pandas import DataFrame

df: DataFrame
30 changes: 25 additions & 5 deletions crates/ruff_linter/src/importer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,7 @@ impl<'a> Importer<'a> {
)?;

// Import the `TYPE_CHECKING` symbol from the typing module.
let (type_checking_edit, type_checking) = self.get_or_import_symbol(
&ImportRequest::import_from("typing", "TYPE_CHECKING"),
at,
semantic,
)?;
let (type_checking_edit, type_checking) = self.get_or_import_type_checking(at, semantic)?;

// Add the import to a `TYPE_CHECKING` block.
let add_import_edit = if let Some(block) = self.preceding_type_checking_block(at) {
Expand All @@ -161,6 +157,30 @@ impl<'a> Importer<'a> {
})
}

/// Generate an [`Edit`] to reference `typing.TYPE_CHECKING`. Returns the [`Edit`] necessary to
/// make the symbol available in the current scope along with the bound name of the symbol.
fn get_or_import_type_checking(
&self,
at: TextSize,
semantic: &SemanticModel,
) -> Result<(Edit, String), ResolutionError> {
for module in semantic.typing_modules() {
if let Some((edit, name)) = self.get_symbol(
&ImportRequest::import_from(module, "TYPE_CHECKING"),
at,
semantic,
)? {
return Ok((edit, name));
}
}

self.get_or_import_symbol(
&ImportRequest::import_from("typing", "TYPE_CHECKING"),
at,
semantic,
)
}

/// Generate an [`Edit`] to reference the given symbol. Returns the [`Edit`] necessary to make
/// the symbol available in the current scope along with the bound name of the symbol.
///
Expand Down
4 changes: 4 additions & 0 deletions crates/ruff_linter/src/rules/flake8_type_checking/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ mod tests {
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_13.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_14.pyi"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_15.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_16.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_17.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_2.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_3.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_4.py"))]
Expand All @@ -36,6 +38,8 @@ mod tests {
#[test_case(Rule::TypingOnlyStandardLibraryImport, Path::new("snapshot.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("TCH002.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("strict.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("typing_modules_1.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("typing_modules_2.py"))]
fn rules(rule_code: Rule, path: &Path) -> Result<()> {
let snapshot = format!("{}_{}", rule_code.as_ref(), path.to_string_lossy());
let diagnostics = test_path(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,19 @@ TCH005.py:22:9: TCH005 [*] Found empty type-checking block
24 22 |
25 23 |

TCH005.py:45:5: TCH005 [*] Found empty type-checking block
|
44 | if TYPE_CHECKING:
45 | pass # TCH005
| ^^^^ TCH005
|
= help: Delete empty type-checking block

Safe fix
41 41 |
42 42 | from typing_extensions import TYPE_CHECKING
43 43 |
44 |-if TYPE_CHECKING:
45 |- pass # TCH005


Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
source: crates/ruff_linter/src/rules/flake8_type_checking/mod.rs
---

Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
---
source: crates/ruff_linter/src/rules/flake8_type_checking/mod.rs
---
TCH004_17.py:6:24: TCH004 [*] Move import `pandas.DataFrame` out of type-checking block. Import is used for more than type hinting.
|
5 | if TYPE_CHECKING:
6 | from pandas import DataFrame
| ^^^^^^^^^ TCH004
|
= help: Move out of type-checking block

Unsafe fix
1 1 | from __future__ import annotations
2 2 |
3 3 | from typing_extensions import TYPE_CHECKING
4 |+from pandas import DataFrame
4 5 |
5 6 | if TYPE_CHECKING:
6 |- from pandas import DataFrame
7 |+ pass
7 8 |
8 9 |
9 10 | def example() -> DataFrame:


Original file line number Diff line number Diff line change
Expand Up @@ -248,5 +248,6 @@ TCH002.py:172:24: TCH002 [*] Move third-party import `module.Member` into a type
172 |- from module import Member
173 176 |
174 177 | x: Member = 1
175 178 |


Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
---
source: crates/ruff_linter/src/rules/flake8_type_checking/mod.rs
---
typing_modules_1.py:7:24: TCH002 [*] Move third-party import `pandas.DataFrame` into a type-checking block
|
6 | def func():
7 | from pandas import DataFrame
| ^^^^^^^^^ TCH002
8 |
9 | df: DataFrame
|
= help: Move into type-checking block

Unsafe fix
1 1 | from __future__ import annotations
2 2 |
3 3 | from typing_extensions import Self
4 |+from typing import TYPE_CHECKING
5 |+
6 |+if TYPE_CHECKING:
7 |+ from pandas import DataFrame
4 8 |
5 9 |
6 10 | def func():
7 |- from pandas import DataFrame
8 11 |
9 12 | df: DataFrame


Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
---
source: crates/ruff_linter/src/rules/flake8_type_checking/mod.rs
---
typing_modules_2.py:7:24: TCH002 [*] Move third-party import `pandas.DataFrame` into a type-checking block
|
6 | def func():
7 | from pandas import DataFrame
| ^^^^^^^^^ TCH002
8 |
9 | df: DataFrame
|
= help: Move into type-checking block

Unsafe fix
2 2 |
3 3 | import typing_extensions
4 4 |
5 |+if typing_extensions.TYPE_CHECKING:
6 |+ from pandas import DataFrame
7 |+
5 8 |
6 9 | def func():
7 |- from pandas import DataFrame
8 10 |
9 11 | df: DataFrame


7 changes: 1 addition & 6 deletions crates/ruff_python_semantic/src/analyze/typing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,7 @@ pub fn is_type_checking_block(stmt: &ast::StmtIf, semantic: &SemanticModel) -> b
}

// Ex) `if typing.TYPE_CHECKING:`
if semantic.resolve_call_path(test).is_some_and(|call_path| {
matches!(
call_path.as_slice(),
["typing", "TYPE_CHECKING"] | ["typing_extensions" | "TYPE_CHECKING"]
)
}) {
if semantic.match_typing_expr(test, "TYPE_CHECKING") {
return true;
}

Expand Down
23 changes: 12 additions & 11 deletions crates/ruff_python_semantic/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,20 +175,13 @@ impl<'a> SemanticModel<'a> {

/// Return `true` if the call path is a reference to `typing.${target}`.
pub fn match_typing_call_path(&self, call_path: &CallPath, target: &str) -> bool {
if call_path.as_slice() == ["typing", target] {
if matches!(
call_path.as_slice(),
["typing" | "_typeshed" | "typing_extensions", target]
) {
return true;
}

if call_path.as_slice() == ["_typeshed", target] {
return true;
}

if is_typing_extension(target) {
if call_path.as_slice() == ["typing_extensions", target] {
return true;
}
}

if self.typing_modules.iter().any(|module| {
let mut module: CallPath = from_unqualified_name(module);
module.push(target);
Expand All @@ -200,6 +193,14 @@ impl<'a> SemanticModel<'a> {
false
}

/// Return an iterator over the set of `typing` modules allowed in the semantic model.
pub fn typing_modules(&self) -> impl Iterator<Item = &str> {
["typing", "_typeshed", "typing_extensions"]
.iter()
.copied()
.chain(self.typing_modules.iter().map(String::as_str))
}

/// Create a new [`Binding`] for a builtin.
pub fn push_builtin(&mut self) -> BindingId {
self.bindings.push(Binding {
Expand Down
60 changes: 0 additions & 60 deletions crates/ruff_python_stdlib/src/typing.rs
Original file line number Diff line number Diff line change
@@ -1,63 +1,3 @@
/// Returns `true` if a name is a member of Python's `typing_extensions` module.
///
/// See: <https://pypi.org/project/typing-extensions/>
pub fn is_typing_extension(member: &str) -> bool {
matches!(
member,
"Annotated"
| "Any"
| "AsyncContextManager"
| "AsyncGenerator"
| "AsyncIterable"
| "AsyncIterator"
| "Awaitable"
| "ChainMap"
| "ClassVar"
| "Concatenate"
| "ContextManager"
| "Coroutine"
| "Counter"
| "DefaultDict"
| "Deque"
| "Final"
| "Literal"
| "LiteralString"
| "NamedTuple"
| "Never"
| "NewType"
| "NotRequired"
| "OrderedDict"
| "ParamSpec"
| "ParamSpecArgs"
| "ParamSpecKwargs"
| "Protocol"
| "Required"
| "Self"
| "TYPE_CHECKING"
| "Text"
| "Type"
| "TypeAlias"
| "TypeGuard"
| "TypeVar"
| "TypeVarTuple"
| "TypedDict"
| "Unpack"
| "assert_never"
| "assert_type"
| "clear_overloads"
| "final"
| "get_type_hints"
| "get_args"
| "get_origin"
| "get_overloads"
| "is_typeddict"
| "overload"
| "override"
| "reveal_type"
| "runtime_checkable"
)
}

/// Returns `true` if a call path is a generic from the Python standard library (e.g. `list`, which
/// can be used as `list[int]`).
///
Expand Down

0 comments on commit bbc3d3b

Please sign in to comment.