Skip to content
This repository has been archived by the owner on Oct 4, 2024. It is now read-only.

Commit

Permalink
Make mypy_plugin compatible with mypy>=0.730
Browse files Browse the repository at this point in the history
* This fixes issue 21
* New type_analyze_hook to fix types before mypy's semantic analyzer runs
* _get_and_delete_cases defers if it encounters PlaceholderNodes
  • Loading branch information
wchresta committed Dec 28, 2019
1 parent 74c250f commit 22e523c
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 31 deletions.
9 changes: 7 additions & 2 deletions adt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from .case import Case
from .decorator import adt
from typing import TYPE_CHECKING

from .case import Case
from .decorator import adt

if TYPE_CHECKING:
from .case import CaseConstructor
3 changes: 2 additions & 1 deletion adt/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def adt(cls):
f'Annotation {k} should be a Case[…] constructor, got {constructor!r} instead'
)

cls._Key = Enum('_Key', list(caseConstructors.keys()))
cls._Key = Enum('_Key', # mypy: warn-unused-ignores=False # type: ignore
list(caseConstructors.keys()))

_installInit(cls)
_installRepr(cls)
Expand Down
150 changes: 123 additions & 27 deletions adt/mypy_plugin.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,68 @@
# mypy: no-warn-redundant-casts
import itertools
import re
from decimal import Decimal
from typing import Any, Callable, Iterable, List, Optional, Type
from typing import Optional, Callable, List, Type, Any, Iterable, Union
import typing

import mypy.types
from mypy.nodes import (ARG_NAMED, ARG_POS, MDEF, Argument, AssignmentStmt,
Block, FuncDef, NameExpr, PassStmt, SymbolTableNode,
TypeVarExpr, Var)
from mypy.plugin import ClassDefContext, Plugin
import mypy.typevars
from mypy.nodes import (
ARG_NAMED,
ARG_POS,
MDEF,
Argument,
AssignmentStmt,
Block,
FuncDef,
FuncBase,
NameExpr,
PassStmt,
PlaceholderNode,
SymbolTableNode,
SymbolNode,
TypeVarExpr,
Var,
)
from mypy.plugin import AnalyzeTypeContext, TypeAnalyzerPluginInterface, ClassDefContext, Plugin
from mypy.semanal import set_callable_name
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name


# mypy plugin API hook
def plugin(version: str) -> Type[Plugin]:
assert Decimal(version) >= Decimal('0.711')
"""Return plugin class depending on mypy version."""
raw_version = version.split("+", 1)[0] # Handle development versions
assert Decimal(raw_version) >= Decimal("0.711")
return ADTPlugin


# fullname and name became properties with https://github.com/python/mypy/pull/7829
# These are compatibility shims
def get_fullname(x: Union[FuncBase, SymbolNode]) -> str:
fn = x.fullname
if callable(fn):
return typing.cast(str, fn())
return typing.cast(str, fn)


def get_name(x: Union[FuncBase, SymbolNode]) -> str:
fn = x.name
if callable(fn):
return typing.cast(str, fn())
return fn


class ADTPlugin(Plugin):
# Fully-qualified name for @adt
_ADT_DECORATOR = 'adt.decorator.adt'

# mypy plugin API hook
def get_type_analyze_hook(
self, fullname: str
) -> Optional[Callable[[AnalyzeTypeContext], mypy.types.Type]]:
if fullname == "adt.case.Case":
return _convert_case_type
return None

def get_class_decorator_hook(
self,
fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
Expand All @@ -33,6 +72,33 @@ def get_class_decorator_hook(
return _transform_class


def _convert_case_type(type_context: AnalyzeTypeContext) -> mypy.types.Type:
"""Convert Case[..] type to CaseConstructor[..]"""
# We do this because the semantic analyzer runs before the class_decorator_hook
# gets a chance to remove the Cases. This will convert the type into a valid
# runtime type and allows the class_decorator hook to be executed. The hook
# then removes this type from the ADT completely.
api: TypeAnalyzerPluginInterface = type_context.api
type_to_convert: mypy.types.UnboundType = type_context.type

call_args = type_to_convert.args
function_type = type_context.api.named_type("builtins.function", [])

arg_types = list(map(api.analyze_type, call_args))
arg_kinds = [mypy.types.ARG_POS for _ in call_args]
arg_names = [None for _ in call_args]

return_type = api.named_type("adt.CaseConstructor", arg_types)

return mypy.types.CallableType(
arg_types=arg_types,
arg_kinds=arg_kinds,
arg_names=arg_names,
ret_type=return_type,
fallback=function_type,
)


class _CaseDef:
context: ClassDefContext
name: str
Expand Down Expand Up @@ -112,6 +178,9 @@ def _transform_class(context: ClassDefContext) -> None:
assert isinstance(instanceType, mypy.types.Instance)

cases = _get_and_delete_cases(context)
if cases is None: # Cases were not successfully deleted. We need to defer
context.api.defer()
return

for case in cases:
_add_constructor_for_case(context, case, selfType=instanceType)
Expand All @@ -123,31 +192,58 @@ def _transform_class(context: ClassDefContext) -> None:
# Returns ADT cases which were listed as class variables (similar to
# cls.__annotations__ at runtime), and removes those variables from
# typechecking, as they will be replaced by constructor methods.
def _get_and_delete_cases(context: ClassDefContext) -> List[_CaseDef]:
def _get_and_delete_cases(context: ClassDefContext
) -> Optional[List[_CaseDef]]:
"""Search the class body for adt's Case constructions and delete them
For a given context, search the class body for assignments of the form
`CASENAME: Case[...]`. Delete them, and return a _CaseDef for each.
In case the body is not ready (because the semantic analyzer included a
PlaceHolder expression), this function will return None and is expected
to be called again.
If no PlaceHolder is found, return a list of _CaseDef.
"""
cls = context.cls

caseDefs: List[_CaseDef] = []
removed: List[int] = []
for i, statement in enumerate(cls.defs.body):
if not isinstance(statement, AssignmentStmt):
# Any assignment that doesn't use the new type declaration
# syntax can be ignored out of hand.
if not (isinstance(statement, AssignmentStmt)
and statement.new_syntax):
continue

for lval in statement.lvalues:
if not isinstance(lval, NameExpr):
continue
# a: int, b: str = 1, 'foo' is not supported syntax so we
# don't have to worry about it.
lval = statement.lvalues[0]
if not isinstance(lval, NameExpr):
continue

sym = cls.info.names.get(lval.name)
if sym is None:
# This name is likely blocked by a star import. We don't need to defer because
# defer() is already called by mark_incomplete().
continue

var = cls.info[lval.name].node
if not isinstance(var, Var):
continue
var = sym.node
if isinstance(var, PlaceholderNode):
# This node is not ready yet.
return None

assert isinstance(var.type, mypy.types.Instance)
assert re.match(r'^adt.case.Case(T)?$',
var.type.type.defn.fullname)
assert isinstance(var, Var)
assert isinstance(var.type, mypy.types.CallableType)
assert isinstance(var.type.ret_type, mypy.types.Instance)
assert get_fullname(
var.type.ret_type.type) == "adt.case.CaseConstructor"

caseDefs.append(
_CaseDef(context=context, name=var.name(),
types=var.type.args))
removed.append(i)
caseDefs.append(
_CaseDef(context=context,
name=get_name(var),
types=var.type.ret_type.args))
removed.append(i)

for i in reversed(removed):
del cls.defs.body[i]
Expand Down Expand Up @@ -204,7 +300,7 @@ def _add_match(context: ClassDefContext, cases: Iterable[_CaseDef]) -> None:
def _add_typevar(context: ClassDefContext,
tVarName: str) -> mypy.types.TypeVarDef:
typeInfo = context.cls.info
tVarQualifiedName = f'{typeInfo.fullname()}.{tVarName}'
tVarQualifiedName = f'{get_fullname(typeInfo)}.{tVarName}'
objectType = context.api.named_type('__builtins__.object')

tVarExpr = TypeVarExpr(tVarName, tVarQualifiedName, [], objectType)
Expand Down Expand Up @@ -269,7 +365,7 @@ def _add_method(
for arg in args:
assert arg.type_annotation, 'All arguments must be fully typed.'
arg_types.append(arg.type_annotation)
arg_names.append(arg.variable.name())
arg_names.append(get_name(arg.variable))
arg_kinds.append(arg.kind)

signature = mypy.types.CallableType(arg_types, arg_kinds, arg_names,
Expand All @@ -281,7 +377,7 @@ def _add_method(
func.info = info
func.is_class = is_classmethod
func.type = set_callable_name(signature, func)
func._fullname = info.fullname() + '.' + name
func._fullname = get_fullname(info) + '.' + name
func.line = info.line

# NOTE: we would like the plugin generated node to dominate, but we still
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
yapf==0.27.0
mypy==0.711
mypy>=0.711
coverage==4.5.3
hypothesis==4.24.5
coveralls==1.8.1

0 comments on commit 22e523c

Please sign in to comment.