diff --git a/astroid/brain/brain_dataclasses.py b/astroid/brain/brain_dataclasses.py index a05c46f9b6..bfdbbe09e5 100644 --- a/astroid/brain/brain_dataclasses.py +++ b/astroid/brain/brain_dataclasses.py @@ -10,7 +10,8 @@ - https://lovasoa.github.io/marshmallow_dataclass/ """ -from typing import FrozenSet, Generator, List, Optional, Tuple +import sys +from typing import FrozenSet, Generator, List, Optional, Tuple, Union from astroid import context, inference_tip from astroid.builder import parse @@ -36,6 +37,15 @@ from astroid.nodes.scoped_nodes import ClassDef, FunctionDef from astroid.util import Uninferable +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + +_FieldDefaultReturn = Union[ + None, Tuple[Literal["default"], NodeNG], Tuple[Literal["default_factory"], Call] +] + DATACLASSES_DECORATORS = frozenset(("dataclass",)) FIELD_NAME = "field" DATACLASS_MODULES = frozenset( @@ -115,7 +125,7 @@ def _get_dataclass_attributes(node: ClassDef, init: bool = False) -> Generator: ): continue - if _is_class_var(assign_node.annotation): + if _is_class_var(assign_node.annotation): # type: ignore[arg-type] # annotation is never None continue if init: @@ -124,12 +134,13 @@ def _get_dataclass_attributes(node: ClassDef, init: bool = False) -> Generator: isinstance(value, Call) and _looks_like_dataclass_field_call(value, check_scope=False) and any( - keyword.arg == "init" and not keyword.value.bool_value() + keyword.arg == "init" + and not keyword.value.bool_value() # type: ignore[union-attr] # value is never None for keyword in value.keywords ) ): continue - elif _is_init_var(assign_node.annotation): + elif _is_init_var(assign_node.annotation): # type: ignore[arg-type] # annotation is never None continue yield assign_node @@ -159,7 +170,8 @@ def _check_generate_dataclass_init(node: ClassDef) -> bool: # Check for keyword arguments of the form init=False return all( - keyword.arg != "init" or keyword.value.bool_value() + keyword.arg != "init" + and keyword.value.bool_value() # type: ignore[union-attr] # value is never None for keyword in found.keywords ) @@ -174,7 +186,7 @@ def _generate_dataclass_init(assigns: List[AnnAssign]) -> str: name, annotation, value = assign.target.name, assign.annotation, assign.value target_names.append(name) - if _is_init_var(annotation): + if _is_init_var(annotation): # type: ignore[arg-type] # annotation is never None init_var = True if isinstance(annotation, Subscript): annotation = annotation.slice @@ -196,16 +208,16 @@ def _generate_dataclass_init(assigns: List[AnnAssign]) -> str: value, check_scope=False ): result = _get_field_default(value) - - default_type, default_node = result - if default_type == "default": - param_str += f" = {default_node.as_string()}" - elif default_type == "default_factory": - param_str += f" = {DEFAULT_FACTORY}" - assignment_str = ( - f"self.{name} = {default_node.as_string()} " - f"if {name} is {DEFAULT_FACTORY} else {name}" - ) + if result: + default_type, default_node = result + if default_type == "default": + param_str += f" = {default_node.as_string()}" + elif default_type == "default_factory": + param_str += f" = {DEFAULT_FACTORY}" + assignment_str = ( + f"self.{name} = {default_node.as_string()} " + f"if {name} is {DEFAULT_FACTORY} else {name}" + ) else: param_str += f" = {value.as_string()}" @@ -219,7 +231,7 @@ def _generate_dataclass_init(assigns: List[AnnAssign]) -> str: def infer_dataclass_attribute( - node: Unknown, ctx: context.InferenceContext = None + node: Unknown, ctx: Optional[context.InferenceContext] = None ) -> Generator: """Inference tip for an Unknown node that was dynamically generated to represent a dataclass attribute. @@ -247,16 +259,17 @@ def infer_dataclass_field_call( """Inference tip for dataclass field calls.""" if not isinstance(node.parent, (AnnAssign, Assign)): raise UseInferenceDefault - field_call = node.parent.value - default_type, default = _get_field_default(field_call) - if not default_type: + result = _get_field_default(node) + if not result: yield Uninferable - elif default_type == "default": - yield from default.infer(context=ctx) else: - new_call = parse(default.as_string()).body[0].value - new_call.parent = field_call.parent - yield from new_call.infer(context=ctx) + default_type, default = result + if default_type == "default": + yield from default.infer(context=ctx) + else: + new_call = parse(default.as_string()).body[0].value + new_call.parent = node.parent + yield from new_call.infer(context=ctx) def _looks_like_dataclass_decorator( @@ -294,6 +307,9 @@ def _looks_like_dataclass_attribute(node: Unknown) -> bool: statement. """ parent = node.parent + if not parent: + return False + scope = parent.scope() return ( isinstance(parent, AnnAssign) @@ -330,7 +346,7 @@ def _looks_like_dataclass_field_call(node: Call, check_scope: bool = True) -> bo return inferred.name == FIELD_NAME and inferred.root().name in DATACLASS_MODULES -def _get_field_default(field_call: Call) -> Tuple[str, Optional[NodeNG]]: +def _get_field_default(field_call: Call) -> _FieldDefaultReturn: """Return a the default value of a field call, and the corresponding keyword argument name. field(default=...) results in the ... node @@ -358,7 +374,7 @@ def _get_field_default(field_call: Call) -> Tuple[str, Optional[NodeNG]]: new_call.postinit(func=default_factory) return "default_factory", new_call - return "", None + return None def _is_class_var(node: NodeNG) -> bool: @@ -404,7 +420,7 @@ def _is_init_var(node: NodeNG) -> bool: def _infer_instance_from_annotation( - node: NodeNG, ctx: context.InferenceContext = None + node: NodeNG, ctx: Optional[context.InferenceContext] = None ) -> Generator: """Infer an instance corresponding to the type annotation represented by node.