diff --git a/mypy/semanal.py b/mypy/semanal.py index f2380cf4b1466..d7694e538f806 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -44,7 +44,7 @@ """ from typing import ( - List, Dict, Set, Tuple, cast, Any, TypeVar, Union, Optional, Callable + List, Dict, Set, Tuple, cast, TypeVar, Union, Optional, Callable ) from mypy.nodes import ( @@ -280,9 +280,7 @@ def visit_func_def(self, defn: FuncDef) -> None: if defn.name() in self.type.names: # Redefinition. Conditional redefinition is okay. n = self.type.names[defn.name()].node - if self.is_conditional_func(n, defn): - defn.original_def = cast(FuncDef, n) - else: + if not self.set_original_def(n, defn): self.name_already_defined(defn.name(), defn) self.type.names[defn.name()] = SymbolTableNode(MDEF, defn) self.prepare_method_signature(defn) @@ -292,9 +290,7 @@ def visit_func_def(self, defn: FuncDef) -> None: if defn.name() in self.locals[-1]: # Redefinition. Conditional redefinition is okay. n = self.locals[-1][defn.name()].node - if self.is_conditional_func(n, defn): - defn.original_def = cast(FuncDef, n) - else: + if not self.set_original_def(n, defn): self.name_already_defined(defn.name(), defn) else: self.add_local(defn, defn) @@ -304,11 +300,7 @@ def visit_func_def(self, defn: FuncDef) -> None: symbol = self.globals.get(defn.name()) if isinstance(symbol.node, FuncDef) and symbol.node != defn: # This is redefinition. Conditional redefinition is okay. - original_def = symbol.node - if self.is_conditional_func(original_def, defn): - # Conditional function definition -- multiple defs are ok. - defn.original_def = original_def - else: + if not self.set_original_def(symbol.node, defn): # Report error. self.check_no_global(defn.name(), defn, True) if phase_info == FUNCTION_FIRST_PHASE_POSTPONE_SECOND: @@ -341,11 +333,10 @@ def prepare_method_signature(self, func: FuncDef) -> None: leading_type = self.class_type(self.type) else: leading_type = fill_typevars(self.type) - sig = cast(FunctionLike, func.type) - func.type = replace_implicit_first_type(sig, leading_type) + func.type = replace_implicit_first_type(functype, leading_type) - def is_conditional_func(self, previous: Node, new: FuncDef) -> bool: - """Does 'new' conditionally redefine 'previous'? + def set_original_def(self, previous: Node, new: FuncDef) -> bool: + """If 'new' conditionally redefine 'previous', set 'previous' as original We reject straight redefinitions of functions, as they are usually a programming error. For example: @@ -353,7 +344,11 @@ def is_conditional_func(self, previous: Node, new: FuncDef) -> bool: . def f(): ... . def f(): ... # Error: 'f' redefined """ - return isinstance(previous, (FuncDef, Var)) and new.is_conditional + if isinstance(previous, (FuncDef, Var)) and new.is_conditional: + new.original_def = previous + return True + else: + return False def update_function_type_variables(self, defn: FuncDef) -> None: """Make any type variables in the signature of defn explicit. @@ -362,8 +357,8 @@ def update_function_type_variables(self, defn: FuncDef) -> None: if defn is generic. """ if defn.type: - functype = cast(CallableType, defn.type) - typevars = self.infer_type_variables(functype) + assert isinstance(defn.type, CallableType) + typevars = self.infer_type_variables(defn.type) # Do not define a new type variable if already defined in scope. typevars = [(name, tvar) for name, tvar in typevars if not self.is_defined_type_var(name, defn)] @@ -373,7 +368,7 @@ def update_function_type_variables(self, defn: FuncDef) -> None: tvar[1].values, tvar[1].upper_bound, tvar[1].variance) for i, tvar in enumerate(typevars)] - functype.variables = defs + defn.type.variables = defs def infer_type_variables(self, type: CallableType) -> List[Tuple[str, TypeVarExpr]]: @@ -387,8 +382,7 @@ def infer_type_variables(self, tvars.append(tvar_expr) return list(zip(names, tvars)) - def find_type_variables_in_type( - self, type: Type) -> List[Tuple[str, TypeVarExpr]]: + def find_type_variables_in_type(self, type: Type) -> List[Tuple[str, TypeVarExpr]]: """Return a list of all unique type variable references in type. This effectively does partial name binding, results of which are mostly thrown away. @@ -398,7 +392,8 @@ def find_type_variables_in_type( name = type.name node = self.lookup_qualified(name, type) if node and node.kind == UNBOUND_TVAR: - result.append((name, cast(TypeVarExpr, node.node))) + assert isinstance(node.node, TypeVarExpr) + result.append((name, node.node)) for arg in type.args: result.extend(self.find_type_variables_in_type(arg)) elif isinstance(type, TypeList): @@ -425,8 +420,9 @@ def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: item.is_overload = True item.func.is_overload = True item.accept(self) - t.append(cast(CallableType, function_type(item.func, - self.builtin_type('builtins.function')))) + callable = function_type(item.func, self.builtin_type('builtins.function')) + assert isinstance(callable, CallableType) + t.append(callable) if item.func.is_property and i == 0: # This defines a property, probably with a setter and/or deleter. self.analyze_property_with_multi_part_definition(defn) @@ -524,8 +520,9 @@ def add_func_type_variables_to_symbol_table( nodes = [] # type: List[SymbolTableNode] if defn.type: tt = defn.type + assert isinstance(tt, CallableType) + items = tt.variables names = self.type_var_names() - items = cast(CallableType, tt).variables for item in items: name = item.name if name in names: @@ -549,7 +546,8 @@ def bind_type_var(self, fullname: str, tvar_def: TypeVarDef, return node def check_function_signature(self, fdef: FuncItem) -> None: - sig = cast(CallableType, fdef.type) + sig = fdef.type + assert isinstance(sig, CallableType) if len(sig.arg_types) < len(fdef.arguments): self.fail('Type signature has too few arguments', fdef) # Add dummy Any arguments to prevent crashes later. @@ -725,7 +723,8 @@ def analyze_unbound_tvar(self, t: Type) -> Tuple[str, TypeVarExpr]: unbound = t sym = self.lookup_qualified(unbound.name, unbound) if sym is not None and sym.kind == UNBOUND_TVAR: - return unbound.name, cast(TypeVarExpr, sym.node) + assert isinstance(sym.node, TypeVarExpr) + return unbound.name, sym.node return None def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: @@ -922,13 +921,15 @@ def class_type(self, info: TypeInfo) -> Type: def named_type(self, qualified_name: str, args: List[Type] = None) -> Instance: sym = self.lookup_qualified(qualified_name, None) - return Instance(cast(TypeInfo, sym.node), args or []) + assert isinstance(sym.node, TypeInfo) + return Instance(sym.node, args or []) def named_type_or_none(self, qualified_name: str, args: List[Type] = None) -> Instance: sym = self.lookup_fully_qualified_or_none(qualified_name) if not sym: return None - return Instance(cast(TypeInfo, sym.node), args or []) + assert isinstance(sym.node, TypeInfo) + return Instance(sym.node, args or []) def bind_class_type_variables_in_symbol_table( self, info: TypeInfo) -> List[SymbolTableNode]: @@ -1300,11 +1301,10 @@ def analyze_lvalue(self, lval: Lvalue, nested: bool = False, lval.accept(self) elif (isinstance(lval, TupleExpr) or isinstance(lval, ListExpr)): - items = cast(Any, lval).items + items = lval.items if len(items) == 0 and isinstance(lval, TupleExpr): self.fail("Can't assign to ()", lval) - self.analyze_tuple_or_list_lvalue(cast(Union[ListExpr, TupleExpr], lval), - add_global, explicit_type) + self.analyze_tuple_or_list_lvalue(lval, add_global, explicit_type) elif isinstance(lval, StarExpr): if nested: self.analyze_lvalue(lval.expr, nested, add_global, explicit_type) @@ -1318,9 +1318,7 @@ def analyze_tuple_or_list_lvalue(self, lval: Union[ListExpr, TupleExpr], explicit_type: bool = False) -> None: """Analyze an lvalue or assignment target that is a list or tuple.""" items = lval.items - star_exprs = [cast(StarExpr, item) - for item in items - if isinstance(item, StarExpr)] + star_exprs = [item for item in items if isinstance(item, StarExpr)] if len(star_exprs) > 1: self.fail('Two starred expressions in assignment', lval) @@ -1452,14 +1450,14 @@ def check_newtype_args(self, name: str, call: CallExpr, context: Context) -> Opt if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): self.fail("Argument 1 to NewType(...) must be a string literal", context) has_failed = True - elif cast(StrExpr, call.args[0]).value != name: + elif args[0].value != name: msg = "String argument 1 '{}' to NewType(...) does not match variable name '{}'" - self.fail(msg.format(cast(StrExpr, call.args[0]).value, name), context) + self.fail(msg.format(args[0].value, name), context) has_failed = True # Check second argument try: - unanalyzed_type = expr_to_unanalyzed_type(call.args[1]) + unanalyzed_type = expr_to_unanalyzed_type(args[1]) except TypeTranslationError: self.fail("Argument 2 to NewType(...) must be a valid type", context) return None @@ -1497,7 +1495,8 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> None: if not call: return - lvalue = cast(NameExpr, s.lvalues[0]) + lvalue = s.lvalues[0] + assert isinstance(lvalue, NameExpr) name = lvalue.name if not lvalue.is_def: if s.type: @@ -1538,9 +1537,9 @@ def check_typevar_name(self, call: CallExpr, name: str, context: Context) -> boo or not call.arg_kinds[0] == ARG_POS): self.fail("TypeVar() expects a string literal as first argument", context) return False - if cast(StrExpr, call.args[0]).value != name: + elif call.args[0].value != name: msg = "String argument 1 '{}' to TypeVar(...) does not match variable name '{}'" - self.fail(msg.format(cast(StrExpr, call.args[0]).value, name), context) + self.fail(msg.format(call.args[0].value, name), context) return False return True @@ -2308,7 +2307,8 @@ def visit_member_expr(self, expr: MemberExpr) -> None: # This branch handles the case foo.bar where foo is a module. # In this case base.node is the module's MypyFile and we look up # bar in its namespace. This must be done for all types of bar. - file = cast(MypyFile, base.node) + file = base.node + assert isinstance(file, MypyFile) n = file.names.get(expr.name, None) if file is not None else None if n: n = self.normalize_type_alias(n, expr) @@ -2513,7 +2513,8 @@ def lookup(self, name: str, ctx: Context) -> SymbolTableNode: # 5. Builtins b = self.globals.get('__builtins__', None) if b: - table = cast(MypyFile, b.node).names + assert isinstance(b.node, MypyFile) + table = b.node.names if name in table: if name[0] == "_" and name[1] != "_": self.name_not_defined(name, ctx) @@ -2568,8 +2569,8 @@ def lookup_qualified(self, name: str, ctx: Context) -> SymbolTableNode: def builtin_type(self, fully_qualified_name: str) -> Instance: node = self.lookup_fully_qualified(fully_qualified_name) - info = cast(TypeInfo, node.node) - return Instance(info, []) + assert isinstance(node.node, TypeInfo) + return Instance(node.node, []) def lookup_fully_qualified(self, name: str) -> SymbolTableNode: """Lookup a fully qualified name. @@ -2581,10 +2582,12 @@ def lookup_fully_qualified(self, name: str) -> SymbolTableNode: parts = name.split('.') n = self.modules[parts[0]] for i in range(1, len(parts) - 1): - n = cast(MypyFile, n.names[parts[i]].node) - return n.names[parts[-1]] + next_sym = n.names[parts[i]] + assert isinstance(next_sym.node, MypyFile) + n = next_sym.node + return n.names.get(parts[-1]) - def lookup_fully_qualified_or_none(self, name: str) -> SymbolTableNode: + def lookup_fully_qualified_or_none(self, name: str) -> Optional[SymbolTableNode]: """Lookup a fully qualified name. Assume that the name is defined. This happens in the global namespace -- the local @@ -2597,7 +2600,8 @@ def lookup_fully_qualified_or_none(self, name: str) -> SymbolTableNode: next_sym = n.names.get(parts[i]) if not next_sym: return None - n = cast(MypyFile, next_sym.node) + assert isinstance(next_sym.node, MypyFile) + n = next_sym.node return n.names.get(parts[-1]) def qualified_name(self, n: str) -> str: @@ -2811,11 +2815,7 @@ def visit_func_def(self, func: FuncDef) -> None: # Ah this is an imported name. We can't resolve them now, so we'll postpone # this until the main phase of semantic analysis. return - original_def = original_sym.node - if sem.is_conditional_func(original_def, func): - # Conditional function definition -- multiple defs are ok. - func.original_def = cast(FuncDef, original_def) - else: + if not sem.set_original_def(original_sym.node, func): # Report error. sem.check_no_global(func.name(), func) else: @@ -3055,10 +3055,11 @@ def fill_typevars(typ: TypeInfo) -> Union[Instance, TupleType]: def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike: if isinstance(sig, CallableType): return sig.copy_modified(arg_types=[new] + sig.arg_types[1:]) - else: - sig = cast(Overloaded, sig) + elif isinstance(sig, Overloaded): return Overloaded([cast(CallableType, replace_implicit_first_type(i, new)) for i in sig.items()]) + else: + assert False def set_callable_name(sig: Type, fdef: FuncDef) -> Type: