Skip to content

Commit

Permalink
Add support for @functools.singledispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Dec 1, 2023
1 parent b2638c6 commit 1a4c324
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 40 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Test module."""
from __future__ import annotations

from functools import singledispatch
from typing import TYPE_CHECKING

from numpy import asarray
from numpy.typing import ArrayLike
from scipy.sparse import spmatrix
from pandas import DataFrame

if TYPE_CHECKING:
from numpy import ndarray


@singledispatch
def to_array_or_mat(a: ArrayLike | spmatrix) -> ndarray | spmatrix:
"""Convert arg to array or leaves it as sparse matrix."""
msg = f"Unhandled type {type(a)}"
raise NotImplementedError(msg)


@to_array_or_mat.register
def _(a: ArrayLike) -> ndarray:
return asarray(a)


@to_array_or_mat.register
def _(a: spmatrix) -> spmatrix:
return a


def _(a: DataFrame) -> DataFrame:
return a
45 changes: 27 additions & 18 deletions crates/ruff_linter/src/checkers/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,13 @@ where
// are enabled.
let runtime_annotation = !self.semantic.future_annotations();

// The first parameter may be a single dispatch.
let mut singledispatch =
flake8_type_checking::helpers::is_singledispatch_implementation(
function_def,
self.semantic(),
);

self.semantic.push_scope(ScopeKind::Type);

if let Some(type_params) = type_params {
Expand All @@ -505,7 +512,7 @@ where
.chain(&parameters.kwonlyargs)
{
if let Some(expr) = &parameter_with_default.parameter.annotation {
if runtime_annotation {
if runtime_annotation || singledispatch {
self.visit_runtime_annotation(expr);
} else {
self.visit_annotation(expr);
Expand All @@ -514,6 +521,7 @@ where
if let Some(expr) = &parameter_with_default.default {
self.visit_expr(expr);
}
singledispatch = false;
}
if let Some(arg) = &parameters.vararg {
if let Some(expr) = &arg.annotation {
Expand Down Expand Up @@ -670,23 +678,24 @@ where
// available at runtime.
// See: https://docs.python.org/3/reference/simple_stmts.html#annotated-assignment-statements
let runtime_annotation = if self.semantic.future_annotations() {
if self.semantic.current_scope().kind.is_class() {
let baseclasses = &self
.settings
.flake8_type_checking
.runtime_evaluated_base_classes;
let decorators = &self
.settings
.flake8_type_checking
.runtime_evaluated_decorators;
flake8_type_checking::helpers::runtime_evaluated(
baseclasses,
decorators,
&self.semantic,
)
} else {
false
}
self.semantic
.current_scope()
.kind
.as_class()
.is_some_and(|class_def| {
flake8_type_checking::helpers::runtime_evaluated_class(
class_def,
&self
.settings
.flake8_type_checking
.runtime_evaluated_base_classes,
&self
.settings
.flake8_type_checking
.runtime_evaluated_decorators,
&self.semantic,
)
})
} else {
matches!(
self.semantic.current_scope().kind,
Expand Down
116 changes: 94 additions & 22 deletions crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use ruff_python_ast::call_path::from_qualified_name;
use ruff_python_ast::helpers::{map_callable, map_subscript};
use ruff_python_ast::{self as ast};
use ruff_python_semantic::{Binding, BindingId, BindingKind, ScopeKind, SemanticModel};
use ruff_python_ast::{self as ast, Expr};
use ruff_python_semantic::{Binding, BindingId, BindingKind, SemanticModel};
use rustc_hash::FxHashSet;

pub(crate) fn is_valid_runtime_import(binding: &Binding, semantic: &SemanticModel) -> bool {
Expand All @@ -18,25 +18,26 @@ pub(crate) fn is_valid_runtime_import(binding: &Binding, semantic: &SemanticMode
}
}

pub(crate) fn runtime_evaluated(
pub(crate) fn runtime_evaluated_class(
class_def: &ast::StmtClassDef,
base_classes: &[String],
decorators: &[String],
semantic: &SemanticModel,
) -> bool {
if !base_classes.is_empty() {
if runtime_evaluated_base_class(base_classes, semantic) {
return true;
}
if runtime_evaluated_base_class(class_def, base_classes, semantic) {
return true;
}
if !decorators.is_empty() {
if runtime_evaluated_decorators(decorators, semantic) {
return true;
}
if runtime_evaluated_decorators(class_def, decorators, semantic) {
return true;
}
false
}

fn runtime_evaluated_base_class(base_classes: &[String], semantic: &SemanticModel) -> bool {
fn runtime_evaluated_base_class(
class_def: &ast::StmtClassDef,
base_classes: &[String],
semantic: &SemanticModel,
) -> bool {
fn inner(
class_def: &ast::StmtClassDef,
base_classes: &[String],
Expand Down Expand Up @@ -78,19 +79,21 @@ fn runtime_evaluated_base_class(base_classes: &[String], semantic: &SemanticMode
})
}

semantic
.current_scope()
.kind
.as_class()
.is_some_and(|class_def| {
inner(class_def, base_classes, semantic, &mut FxHashSet::default())
})
if base_classes.is_empty() {
return false;
}

inner(class_def, base_classes, semantic, &mut FxHashSet::default())
}

fn runtime_evaluated_decorators(decorators: &[String], semantic: &SemanticModel) -> bool {
let ScopeKind::Class(class_def) = &semantic.current_scope().kind else {
fn runtime_evaluated_decorators(
class_def: &ast::StmtClassDef,
decorators: &[String],
semantic: &SemanticModel,
) -> bool {
if decorators.is_empty() {
return false;
};
}

class_def.decorator_list.iter().any(|decorator| {
semantic
Expand All @@ -102,3 +105,72 @@ fn runtime_evaluated_decorators(decorators: &[String], semantic: &SemanticModel)
})
})
}

/// Returns `true` if a function is registered as a `singledispatch` interface.
///
/// For example, `fun` below is a `singledispatch` interface:
/// ```python
/// from functools import singledispatch
///
/// @singledispatch
/// def fun(arg, verbose=False):
/// ...
/// ```
pub(crate) fn is_singledispatch_interface(
function_def: &ast::StmtFunctionDef,
semantic: &SemanticModel,
) -> bool {
function_def.decorator_list.iter().any(|decorator| {
semantic
.resolve_call_path(&decorator.expression)
.is_some_and(|call_path| {
matches!(call_path.as_slice(), ["functools", "singledispatch"])
})
})
}

/// Returns `true` if a function is registered as a `singledispatch` implementation.
///
/// For example, `_` below is a `singledispatch` implementation:
/// For example:
/// ```python
/// from functools import singledispatch
///
/// @singledispatch
/// def fun(arg, verbose=False):
/// ...
///
/// @fun.register
/// def _(arg: int, verbose=False):
/// ...
/// ```
pub(crate) fn is_singledispatch_implementation(
function_def: &ast::StmtFunctionDef,
semantic: &SemanticModel,
) -> bool {
function_def.decorator_list.iter().any(|decorator| {
let Expr::Attribute(attribute) = &decorator.expression else {
return false;
};

if attribute.attr.as_str() != "register" {
return false;
};

let Some(id) = semantic.lookup_attribute(attribute.value.as_ref()) else {
return false;
};

let binding = semantic.binding(id);
let Some(function_def) = binding
.kind
.as_function_definition()
.map(|id| &semantic.scopes[*id])
.and_then(|scope| scope.kind.as_function())
else {
return false;
};

is_singledispatch_interface(function_def, semantic)
})
}
1 change: 1 addition & 0 deletions crates/ruff_linter/src/rules/flake8_type_checking/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ mod tests {
#[test_case(Rule::TypingOnlyStandardLibraryImport, Path::new("TCH003.py"))]
#[test_case(Rule::TypingOnlyStandardLibraryImport, Path::new("snapshot.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("TCH002.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("singledispatch.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("strict.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("typing_modules_1.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("typing_modules_2.py"))]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
---
source: crates/ruff_linter/src/rules/flake8_type_checking/mod.rs
---
singledispatch.py:10:20: TCH002 [*] Move third-party import `pandas.DataFrame` into a type-checking block
|
8 | from numpy.typing import ArrayLike
9 | from scipy.sparse import spmatrix
10 | from pandas import DataFrame
| ^^^^^^^^^ TCH002
11 |
12 | if TYPE_CHECKING:
|
= help: Move into type-checking block

Unsafe fix
7 7 | from numpy import asarray
8 8 | from numpy.typing import ArrayLike
9 9 | from scipy.sparse import spmatrix
10 |-from pandas import DataFrame
11 10 |
12 11 | if TYPE_CHECKING:
12 |+ from pandas import DataFrame
13 13 | from numpy import ndarray
14 14 |
15 15 |


0 comments on commit 1a4c324

Please sign in to comment.