diff --git a/docs/source/codemods.rst b/docs/source/codemods.rst index 4e5634723..3711a8f32 100644 --- a/docs/source/codemods.rst +++ b/docs/source/codemods.rst @@ -153,3 +153,5 @@ inside codemods. As of now, the list includes the following helpers. :exclude-members: CONTEXT_KEY, visit_Module, leave_ImportFrom, leave_Module .. autoclass:: libcst.codemod.visitors.RemoveImportsVisitor :exclude-members: CONTEXT_KEY, METADATA_DEPENDENCIES, visit_Module, leave_ImportFrom, leave_Import +.. autoclass:: libcst.codemod.visitors.ApplyTypeAnnotationsVisitor + :exclude-members: CONTEXT_KEY, transform_module_impl, visit_ClassDef, visit_Comment, visit_FunctionDef, leave_Assign, leave_ClassDef, leave_FunctionDef, leave_ImportFrom, leave_Module diff --git a/libcst/codemod/visitors/__init__.py b/libcst/codemod/visitors/__init__.py index ac3bec554..83deb2513 100644 --- a/libcst/codemod/visitors/__init__.py +++ b/libcst/codemod/visitors/__init__.py @@ -5,6 +5,7 @@ # # pyre-strict from libcst.codemod.visitors._add_imports import AddImportsVisitor +from libcst.codemod.visitors._apply_type_annotations import ApplyTypeAnnotationsVisitor from libcst.codemod.visitors._gather_exports import GatherExportsVisitor from libcst.codemod.visitors._gather_imports import GatherImportsVisitor from libcst.codemod.visitors._remove_imports import RemoveImportsVisitor @@ -14,5 +15,6 @@ "AddImportsVisitor", "GatherImportsVisitor", "GatherExportsVisitor", + "ApplyTypeAnnotationsVisitor", "RemoveImportsVisitor", ] diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py new file mode 100644 index 000000000..85ee2be61 --- /dev/null +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -0,0 +1,481 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree +# +# pyre-strict + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Sequence, Set, Tuple, Union + +import libcst as cst +from libcst.codemod._context import CodemodContext +from libcst.codemod._visitor import ContextAwareTransformer +from libcst.codemod.visitors._gather_imports import GatherImportsVisitor +from libcst.codemod.visitors._add_imports import AddImportsVisitor +from libcst.helpers import get_full_name_for_node + +def _get_import_alias_names(import_aliases: Sequence[cst.ImportAlias]) -> Set[str]: + import_names = set() + for imported_name in import_aliases: + asname = imported_name.asname + if asname: + import_names.add(get_full_name_for_node(asname.name)) + else: + import_names.add(get_full_name_for_node(imported_name.name)) + return import_names + + +def _get_import_names(imports: Sequence[Union[cst.Import, cst.ImportFrom]]) -> Set[str]: + import_names = set() + for _import in imports: + if isinstance(_import, cst.Import): + import_names.update(_get_import_alias_names(_import.names)) + else: + names = _import.names + if not isinstance(names, cst.ImportStar): + import_names.update(_get_import_alias_names(names)) + return import_names + + +@dataclass(frozen=True) +class FunctionAnnotation: + parameters: cst.Parameters + returns: Optional[cst.Annotation] + + +class TypeCollector(cst.CSTVisitor): + """ + Collect type annotations from a stub module. + """ + + def __init__(self, existing_imports: Set[str], context: CodemodContext) -> None: + # Qualifier for storing the canonical name of the current function. + self.qualifier: List[str] = [] + # Store the annotations. + self.function_annotations: Dict[str, FunctionAnnotation] = {} + self.attribute_annotations: Dict[str, cst.Annotation] = {} + self.existing_imports: Set[str] = existing_imports + self.class_definitions: Dict[str, cst.ClassDef] = {} + self.context = context + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + self.qualifier.append(node.name.value) + self.class_definitions[node.name.value] = node + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + self.qualifier.pop() + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + self.qualifier.append(node.name.value) + returns = node.returns + if returns is not None: + return_annotation = self._create_import_from_annotation(returns) + parameter_annotations = self._import_parameter_annotations(node.params) + self.function_annotations[".".join(self.qualifier)] = FunctionAnnotation( + parameters=parameter_annotations, returns=return_annotation + ) + # pyi files don't support inner functions, return False to stop the traversal. + return False + + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: + self.qualifier.pop() + + def visit_AnnAssign(self, node: cst.AnnAssign) -> bool: + name = get_full_name_for_node(node.target) + if name is not None: + self.qualifier.append(name) + annotation_value = self._create_import_from_annotation(node.annotation) + self.attribute_annotations[".".join(self.qualifier)] = annotation_value + return True + + def leave_AnnAssign(self, original_node: cst.AnnAssign) -> None: + self.qualifier.pop() + + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + module = node.module + names = node.names + + # module is None for relative imports like `from .. import foo`. + # We ignore these for now. + if module is None or isinstance(names, cst.ImportStar): + return + module_name = get_full_name_for_node(module) + if module_name is not None: + for import_name in _get_import_alias_names(names): + AddImportsVisitor.add_needed_import(self.context, module_name, import_name) + + def _add_annotation_to_imports( + self, annotation: cst.Attribute + ) -> Union[cst.Name, cst.Attribute]: + key = get_full_name_for_node(annotation.value) + if key is not None: + # Don't attempt to re-import existing imports. + if key in self.existing_imports: + return annotation + import_name = get_full_name_for_node(annotation.attr) + if import_name is not None: + AddImportsVisitor.add_needed_import(self.context, key, import_name) + return annotation.attr + + def _handle_Index(self, slice: cst.Index, node: cst.Subscript) -> cst.Subscript: + value = slice.value + if isinstance(value, cst.Subscript): + new_slice = slice.with_changes(value=self._handle_Subscript(value)) + return node.with_changes(slice=new_slice) + elif isinstance(value, cst.Attribute): + new_slice = slice.with_changes(value=self._add_annotation_to_imports(value)) + return node.with_changes(slice=new_slice) + else: + return node + + def _handle_Subscript(self, node: cst.Subscript) -> cst.Subscript: + slice = node.slice + if isinstance(slice, list): + new_slice = [] + for item in slice: + value = item.slice.value + if isinstance(value, cst.Attribute): + name = self._add_annotation_to_imports(item.slice.value) + new_index = item.slice.with_changes(value=name) + new_slice.append(item.with_changes(slice=new_index)) + else: + if isinstance(item.slice, cst.Index) and not isinstance( + item.slice.value, cst.Name + ): + new_index = item.slice.with_changes( + value=self._handle_Index(item.slice, item) + ) + item = item.with_changes(slice=new_index, comma=None) + new_slice.append(item) + return node.with_changes(slice=new_slice) + elif isinstance(slice, cst.Index): + return self._handle_Index(slice, node) + else: + return node + + def _create_import_from_annotation(self, returns: cst.Annotation) -> cst.Annotation: + annotation = returns.annotation + if isinstance(annotation, cst.Attribute): + attr = self._add_annotation_to_imports(annotation) + return cst.Annotation(annotation=attr) + if isinstance(annotation, cst.Subscript): + value = annotation.value + if isinstance(value, cst.Name) and value.value == "Type": + return returns + return cst.Annotation(annotation=self._handle_Subscript(annotation)) + else: + return returns + + def _import_parameter_annotations( + self, parameters: cst.Parameters + ) -> cst.Parameters: + def update_annotations(parameters: Sequence[cst.Param]) -> List[cst.Param]: + updated_parameters = [] + for parameter in list(parameters): + annotation = parameter.annotation + if annotation is not None: + parameter = parameter.with_changes( + annotation=self._create_import_from_annotation(annotation) + ) + updated_parameters.append(parameter) + return updated_parameters + + return parameters.with_changes(params=update_annotations(parameters.params)) + + +@dataclass(frozen=True) +class Annotations: + function_annotations: Dict[str, FunctionAnnotation] = field(default_factory=dict) + attribute_annotations: Dict[str, cst.Annotation] = field(default_factory=dict) + class_definitions: Dict[str, cst.ClassDef] = field(default_factory=dict) + + +class ApplyTypeAnnotationsVisitor(ContextAwareTransformer): + """ + Apply type annotations to a source module using the given stub mdules. + You can also pass in explicit annotations for functions and attributes and + pass in new class definitions that need to be added to the source module. + + This is one of the transforms that is available automatically to you when + running a codemod. To use it in this manner, import + :class:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor` and then call the static + :meth:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor.add_stub_to_context` method, + giving it the current context (found as ``self.context`` for all subclasses of + :class:`~libcst.codemod.Codemod`), the stub module from which you wish to add annotations. + + For example, you can store the type annotation ``int`` for ``x`` using:: + + stub_module = parse_module("x: int = ...") + + ApplyTypeAnnotationsVisitor.add_stub_to_context(self.context, stub_module) + + You can apply the type annotation using:: + + source_module = parse_module("x = 1") + ApplyTypeAnnotationsVisitor.transform_module(source_module) + + This will produce the following code:: + + x: int = 1 + + If the function or attribute already has a type annotation, it will not be overwritten. + """ + + CONTEXT_KEY = "ApplyTypeAnnotationsVisitor" + + def __init__( + self, context: CodemodContext, annotations: Optional[Annotations] = None + ) -> None: + super().__init__(context) + # Qualifier for storing the canonical name of the current function. + self.qualifier: List[str] = [] + self.annotations: Annotations = annotations or Annotations() + self.toplevel_annotations: Dict[str, cst.Annotation] = {} + self.visited_classes: Set[str] = set() + + # We use this to determine the end of the import block so that we can + # insert top-level annotations. + self.import_statements: List[cst.ImportFrom] = [] + self.is_generated: bool = False + + @staticmethod + def add_stub_to_context(context: CodemodContext, stub: cst.Module) -> None: + """ + Add a stub module to the :class:`~libcst.codemod.CodemodContext` so + that type annotations from the stub can be applied in a later + invocation of this class. + """ + context.scratch.setdefault(ApplyTypeAnnotationsVisitor.CONTEXT_KEY, []).append( + stub + ) + + def transform_module_impl(self, tree: cst.Module) -> cst.Module: + """ + Collect type annotations from all stubs and apply them to ``tree``. + + Gather existing imports from ``tree`` so that we don't add duplicate imports. + """ + import_gatherer = GatherImportsVisitor(CodemodContext()) + tree.visit(import_gatherer) + existing_import_names = _get_import_names(import_gatherer.all_imports) + + stubs = self.context.scratch.get(ApplyTypeAnnotationsVisitor.CONTEXT_KEY, []) + for stub in stubs: + visitor = TypeCollector(existing_import_names, self.context) + stub.visit(visitor) + self.annotations.function_annotations.update(visitor.function_annotations) + self.annotations.attribute_annotations.update(visitor.attribute_annotations) + self.annotations.class_definitions.update(visitor.class_definitions) + + tree_with_imports = AddImportsVisitor(self.context).transform_module(tree) + return tree_with_imports.visit(self) + + def _qualifier_name(self) -> str: + return ".".join(self.qualifier) + + def _annotate_single_target( + self, node: cst.Assign, updated_node: cst.Assign + ) -> Union[cst.Assign, cst.AnnAssign]: + only_target = node.targets[0].target + if isinstance(only_target, (cst.Tuple, cst.List)): + for element in only_target.elements: + value = element.value + name = get_full_name_for_node(value) + if name: + self._add_to_toplevel_annotations(name) + elif isinstance(only_target, (cst.Subscript)): + pass + else: + name = get_full_name_for_node(only_target) + if name is not None: + self.qualifier.append(name) + if self._qualifier_name() in self.annotations.attribute_annotations and not isinstance( + only_target, cst.Subscript + ): + annotation = self.annotations.attribute_annotations[ + self._qualifier_name() + ] + self.qualifier.pop() + return cst.AnnAssign(cst.Name(name), annotation, node.value) + else: + self.qualifier.pop() + return updated_node + + def _split_module( + self, module: cst.Module, updated_module: cst.Module + ) -> Tuple[ + List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]], + List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]], + ]: + import_add_location = 0 + # This works under the principle that while we might modify node contents, + # we have yet to modify the number of statements. So we can match on the + # original tree but break up the statements of the modified tree. If we + # change this assumption in this visitor, we will have to change this code. + for i, statement in enumerate(module.body): + if isinstance(statement, cst.SimpleStatementLine): + for possible_import in statement.body: + for last_import in self.import_statements: + if possible_import is last_import: + import_add_location = i + 1 + break + + return ( + list(updated_module.body[:import_add_location]), + list(updated_module.body[import_add_location:]), + ) + + def _add_to_toplevel_annotations(self, name: str) -> None: + self.qualifier.append(name) + if self._qualifier_name() in self.annotations.attribute_annotations: + annotation = self.annotations.attribute_annotations[self._qualifier_name()] + self.toplevel_annotations[name] = annotation + self.qualifier.pop() + + def _update_parameters( + self, annotations: FunctionAnnotation, updated_node: cst.FunctionDef + ) -> cst.Parameters: + # Update params and default params with annotations + # don't override existing annotations or default values + def update_annotation( + parameters: Sequence[cst.Param], annotations: Sequence[cst.Param] + ) -> List[cst.Param]: + parameter_annotations = {} + annotated_parameters = [] + for parameter in annotations: + if parameter.annotation: + parameter_annotations[parameter.name.value] = parameter.annotation + for parameter in parameters: + key = parameter.name.value + if key in parameter_annotations and not parameter.annotation: + parameter = parameter.with_changes( + annotation=parameter_annotations[key] + ) + annotated_parameters.append(parameter) + return annotated_parameters + + return annotations.parameters.with_changes( + params=update_annotation( + updated_node.params.params, annotations.parameters.params + ) + ) + + def _insert_empty_line( + self, + statements: List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]], + ) -> List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]]: + if len(statements) < 1: + # No statements, nothing to add to + return statements + if len(statements[0].leading_lines) == 0: + # Statement has no leading lines, add one! + return [ + statements[0].with_changes(leading_lines=(cst.EmptyLine(),)), + *statements[1:], + ] + if statements[0].leading_lines[0].comment is None: + # First line is empty, so its safe to leave as-is + return statements + # Statement has a comment first line, so lets add one more empty line + return [ + statements[0].with_changes( + leading_lines=(cst.EmptyLine(), *statements[0].leading_lines) + ), + *statements[1:], + ] + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + self.qualifier.append(node.name.value) + self.visited_classes.add(node.name.value) + + def leave_ClassDef( + self, original_node: cst.ClassDef, updated_node: cst.ClassDef + ) -> cst.ClassDef: + self.qualifier.pop() + return updated_node + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + self.qualifier.append(node.name.value) + # pyi files don't support inner functions, return False to stop the traversal. + return False + + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.FunctionDef: + key = self._qualifier_name() + self.qualifier.pop() + if key in self.annotations.function_annotations: + function_annotation = self.annotations.function_annotations[key] + # Only add new annotation if one doesn't already exist + if not updated_node.returns: + updated_node = updated_node.with_changes(returns=function_annotation.returns) + # Don't override default values when annotating functions + new_parameters = self._update_parameters(function_annotation, updated_node) + return updated_node.with_changes(params=new_parameters) + return updated_node + + def leave_Assign( + self, original_node: cst.Assign, updated_node: cst.Assign + ) -> Union[cst.Assign, cst.AnnAssign]: + + if len(original_node.targets) > 1: + for assign in original_node.targets: + target = assign.target + if isinstance(target, (cst.Name, cst.Attribute)): + name = get_full_name_for_node(target) + if name is not None: + # Add separate top-level annotations for `a = b = 1` + # as `a: int` and `b: int`. + self._add_to_toplevel_annotations(name) + return updated_node + else: + return self._annotate_single_target(original_node, updated_node) + + def leave_ImportFrom( + self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom + ) -> cst.ImportFrom: + self.import_statements.append(original_node) + return updated_node + + def visit_Comment(self, node: cst.Comment) -> None: + if "@" "generated" in node.value: + self.is_generated = True + + def leave_Module( + self, original_node: cst.Module, updated_node: cst.Module + ) -> cst.Module: + fresh_class_definitions = [ + definition + for name, definition in self.annotations.class_definitions.items() + if name not in self.visited_classes + ] + if self.is_generated: + return original_node + if ( + not self.toplevel_annotations + and not fresh_class_definitions + ): + return updated_node + toplevel_statements = [] + # First, find the insertion point for imports + statements_before_imports, statements_after_imports = self._split_module( + original_node, updated_node + ) + + # Make sure there's at least one empty line before the first non-import + statements_after_imports = self._insert_empty_line(statements_after_imports) + + for name, annotation in self.toplevel_annotations.items(): + annotated_assign = cst.AnnAssign(cst.Name(name), annotation, None) + toplevel_statements.append(cst.SimpleStatementLine([annotated_assign])) + + toplevel_statements.extend(fresh_class_definitions) + + return updated_node.with_changes( + body=[ + *statements_before_imports, + *toplevel_statements, + *statements_after_imports, + ] + ) diff --git a/libcst/codemod/visitors/tests/test_apply_type_annotations.py b/libcst/codemod/visitors/tests/test_apply_type_annotations.py new file mode 100644 index 000000000..9357569c0 --- /dev/null +++ b/libcst/codemod/visitors/tests/test_apply_type_annotations.py @@ -0,0 +1,630 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# pyre-strict + +import textwrap +from typing import Type + +import libcst as cst +from libcst import parse_module +from libcst.codemod import Codemod, CodemodContext, CodemodTest +from libcst.codemod.visitors._apply_type_annotations import ApplyTypeAnnotationsVisitor +from libcst.testing.utils import data_provider + + +class TestApplyAnnotationsVisitor(CodemodTest): + TRANSFORM: Type[Codemod] = ApplyTypeAnnotationsVisitor + + @data_provider( + ( + ( + """ + def foo() -> int: ... + """, + """ + def foo(): + return 1 + """, + """ + def foo() -> int: + return 1 + """, + ), + ( + """ + import bar + + def foo() -> bar.Baz: ... + """, + """ + def foo(): + return returns_baz() + """, + """ + from bar import Baz + + def foo() -> Baz: + return returns_baz() + """, + ), + # Keep the existing `import A` instead of using `from A import B`. + ( + """ + import bar + + def foo() -> bar.Baz: ... + """, + """ + import bar + + def foo(): + return returns_baz() + """, + """ + import bar + + def foo() -> bar.Baz: + return returns_baz() + """, + ), + ( + """ + def foo() -> int: ... + + class A: + def foo() -> str: ... + """, + """ + def foo(): + return 1 + class A: + def foo(): + return '' + """, + """ + def foo() -> int: + return 1 + class A: + def foo() -> str: + return '' + """, + ), + ( + """ + bar: int = ... + """, + """ + bar = foo() + """, + """ + bar: int = foo() + """, + ), + ( + """ + bar: int = ... + """, + """ + bar: str = foo() + """, + """ + bar: str = foo() + """, + ), + ( + """ + bar: int = ... + class A: + bar: str = ... + """, + """ + bar = foo() + class A: + bar = foobar() + """, + """ + bar: int = foo() + class A: + bar: str = foobar() + """, + ), + ( + """ + bar: int = ... + class A: + bar: str = ... + """, + """ + bar = foo() + class A: + bar = foobar() + """, + """ + bar: int = foo() + class A: + bar: str = foobar() + """, + ), + ( + """ + a: int = ... + b: str = ... + """, + """ + def foo() -> Tuple[int, str]: + return (1, "") + + a, b = foo() + """, + """ + a: int + b: str + + def foo() -> Tuple[int, str]: + return (1, "") + + a, b = foo() + """, + ), + ( + """ + a: int = ... + b: str = ... + """, + """ + def foo() -> Tuple[int, str]: + return (1, "") + + [a, b] = foo() + """, + """ + a: int + b: str + + def foo() -> Tuple[int, str]: + return (1, "") + + [a, b] = foo() + """, + ), + ( + """ + x: int = ... + y: int = ... + z: int = ... + """, + """ + x = y = z = 1 + """, + """ + x: int + y: int + z: int + + x = y = z = 1 + """, + ), + # Don't add annotations if one is already present + ( + """ + def foo(x: int = 1) -> List[str]: ... + """, + """ + from typing import Iterable, Any + + def foo(x = 1) -> Iterable[Any]: + return [''] + """, + """ + from typing import Iterable, Any + + def foo(x: int = 1) -> Iterable[Any]: + return [''] + """, + ), + ( + """ + from typing import List + + def foo() -> List[int]: ... + """, + """ + def foo(): + return [1] + """, + """ + from typing import List + + def foo() -> List[int]: + return [1] + """, + ), + ( + """ + from typing import List + + def foo() -> List[int]: ... + """, + """ + from typing import Union + + def foo(): + return [1] + """, + """ + from typing import List, Union + + def foo() -> List[int]: + return [1] + """, + ), + ( + """ + a: Dict[str, int] = ... + """, + """ + def foo() -> int: + return 1 + a = {} + a['x'] = foo() + """, + """ + def foo() -> int: + return 1 + a: Dict[str, int] = {} + a['x'] = foo() + """, + ), + # Test that tuples with subscripts are handled correctly + # and top level annotations are added in the correct place + ( + """ + a: int = ... + """, + """ + from typing import Tuple + + def foo() -> Tuple[str, int]: + return "", 1 + + b['z'], a = foo() + """, + """ + from typing import Tuple + a: int + + def foo() -> Tuple[str, int]: + return "", 1 + + b['z'], a = foo() + """, + ), + # Don't override existing default parameter values + ( + """ + class B: + def foo(self, x: int = a.b.A.__add__(1), y=None) -> int: ... + """, + """ + class B: + def foo(self, x = A + 1, y = None) -> int: + return x + + """, + """ + class B: + def foo(self, x: int = A + 1, y = None) -> int: + return x + """, + ), + ( + """ + def foo(x: int) -> int: ... + """, + """ + def foo(x) -> int: + return x + """, + """ + def foo(x: int) -> int: + return x + """, + ), + ( + """ + async def a(r: Request, z=None) -> django.http.response.HttpResponse: ... + async def b(r: Request, z=None) -> django.http.response.HttpResponse: ... + async def c(r: Request, z=None) -> django.http.response.HttpResponse: ... + """, + """ + async def a(r: Request, z=None): ... + async def b(r: Request, z=None): ... + async def c(r: Request, z=None): ... + """, + """ + from django.http.response import HttpResponse + + async def a(r: Request, z=None) -> HttpResponse: ... + async def b(r: Request, z=None) -> HttpResponse: ... + async def c(r: Request, z=None) -> HttpResponse: ... + """, + ), + ( + """ + FOO: a.b.Example = ... + """, + """ + FOO = bar() + """, + """ + from a.b import Example + + FOO: Example = bar() + """, + ), + ( + """ + FOO: Union[a.b.Example, int] = ... + """, + """ + FOO = bar() + """, + """ + from a.b import Example + + FOO: Union[Example, int] = bar() + """, + ), + ( + """ + def foo(x: int) -> List[Union[a.b.Example, str]]: ... + """, + """ + def foo(x: int): + return [barfoo(), ""] + """, + """ + from a.b import Example + + def foo(x: int) -> List[Union[Example, str]]: + return [barfoo(), ""] + """, + ), + ( + """ + def foo(x: int) -> Optional[a.b.Example]: ... + """, + """ + def foo(x: int): + pass + """, + """ + from a.b import Example + + def foo(x: int) -> Optional[Example]: + pass + """, + ), + ( + """ + def foo(x: int) -> str: ... + """, + """ + def foo(x: str): + pass + """, + """ + def foo(x: str) -> str: + pass + """, + ), + ( + """ + def foo(x: int)-> Union[ + Coroutine[Any, Any, django.http.response.HttpResponse], str + ]: + ... + """, + """ + def foo(x: int): + pass + """, + """ + from django.http.response import HttpResponse + + def foo(x: int) -> Union[ + Coroutine[Any, Any, HttpResponse], str + ]: + pass + """, + ), + ( + """ + def foo(x: django.http.response.HttpResponse) -> str: + pass + """, + """ + def foo(x) -> str: + pass + """, + """ + from django.http.response import HttpResponse + + def foo(x: HttpResponse) -> str: + pass + """, + ), + ( + """ + def foo() -> b.b.A: ... + """, + """ + from c import A as B, bar + + def foo(): + return bar() + """, + """ + from c import A as B, bar + from b.b import A + + def foo() -> A: + return bar() + """, + ), + ( + """ + def foo() -> int: ... + """, + f""" + # @generated + def foo(): + return 1 + """, + f""" + # @generated + def foo(): + return 1 + """, + ), + ( + """ + from typing import Type + + def foo() -> Type[foo.A]: ... + """, + """ + def foo(): + class A: + x = 1 + return A + + """, + """ + from typing import Type + + def foo() -> Type[foo.A]: + class A: + x = 1 + return A + """, + ), + ( + """ + def foo() -> db.Connection: ... + """, + """ + import my.cool.db as db + def foo(): + return db.Connection() + """, + """ + import my.cool.db as db + def foo() -> db.Connection: + return db.Connection() + """, + ), + ( + """ + def foo() -> typing.Sequence[int]: ... + """, + """ + import typing + def foo(): + return [] + """, + """ + import typing + def foo() -> typing.Sequence[int]: + return [] + """, + ), + # Insert a TypedDict class that is not in the source file. + ( + """ + from mypy_extensions import TypedDict + + class MovieTypedDict(TypedDict): + name: str + year: int + """, + """ + def foo() -> None: + pass + """, + """ + from mypy_extensions import TypedDict + + class MovieTypedDict(TypedDict): + name: str + year: int + + def foo() -> None: + pass + """, + ), + # Insert only the TypedDict class that is not in the source file. + ( + """ + from mypy_extensions import TypedDict + + class MovieTypedDict(TypedDict): + name: str + year: int + + class ExistingMovieTypedDict(TypedDict): + name: str + year: int + """, + """ + from mypy_extensions import TypedDict + + class ExistingMovieTypedDict(TypedDict): + name: str + year: int + + def foo() -> None: + pass + """, + """ + from mypy_extensions import TypedDict + + class MovieTypedDict(TypedDict): + name: str + year: int + + class ExistingMovieTypedDict(TypedDict): + name: str + year: int + + def foo() -> None: + pass + """, + ), + # Sanity check that we don't fail when the stub has relative imports. + # We don't do anything with those imports, though. + ( + """ + from .. import hello + def foo() -> typing.Sequence[int]: ... + """, + """ + def foo(): + return [] + """, + """ + def foo() -> typing.Sequence[int]: + return [] + """, + ), + ) + ) + def test_annotate_functions(self, stub: str, before: str, after: str) -> None: + context = CodemodContext() + ApplyTypeAnnotationsVisitor.add_stub_to_context( + context, parse_module(textwrap.dedent(stub.rstrip())) + ) + self.assertCodemod(before, after, context_override=context)