class LocalNamesVisitor(TraverserVisitor):
    """
    Scans a function definition, looking for all names which create bindings.
    Does not descend into the bodies of nested functions, lambdas or classes.

    NOTE(review): names declared ``global``/``nonlocal`` and then assigned are
    still recorded as bindings here; confirm whether f.locals should exclude them.
    """

    def visit(self, f: FuncDef) -> None:
        """
        Adds the binding names found in the given func def, storing them as a
        sorted, de-duplicated list in f.locals.
        """
        self.names: List[str] = []
        self.bind: bool = False  # By default, names found will not be bindings.
        self.recursive: bool = False  # True while inside a nested function, lambda or class.
        try:
            # Traverse only the body: the function's own name belongs to the
            # enclosing scope, and its parameter defaults are evaluated there too.
            f.body.accept(self)
            # Parameters are bindings in the function's scope.  Missing (None)
            # names are skipped so that sorting cannot fail on mixed types.
            self.names.extend(name for name in f.arg_names if name)
            f.locals = sorted(set(self.names))
        finally:
            # Scratch state is only meaningful during a single scan.
            del self.names, self.bind, self.recursive

    # All nodes in the tree, except for the bodies of class, function and lambda
    # defs, will be fully traversed.
    # Any assignment expressions will create bindings for the lvalue names.
    # Some nodes can create bindings for certain names found therein.
    # Only simple names are considered. Member and subscript targets are ignored.

    @contextmanager
    def binding_context(self, bind: bool = True) -> "Iterator[None]":
        """
        Within the context, simple names (not member names) are bindings when
        `bind` is true, and are not bindings when it is false.
        The previous setting is restored on exit, even if an exception is raised.
        """
        old = self.bind
        self.bind = bind
        try:
            yield
        finally:
            self.bind = old

    def bind_names(self, *names: str) -> None:
        """Records each of the given names as a binding."""
        self.names.extend(names)

    def bind_names_in(self, *nodes: Optional[Node]) -> None:
        """Bind all simple names found in these Nodes.  None entries are skipped."""
        with self.binding_context():
            for node in nodes:
                if node is not None:
                    node.accept(self)

    @contextmanager
    def recursive_context(self) -> "Iterator[None]":
        """
        Within the context, nodes are contained in a function, lambda or class def.
        The body will be skipped, other nodes will be traversed.
        The previous setting is restored on exit, even if an exception is raised.
        """
        old = self.recursive
        self.recursive = True
        try:
            yield
        finally:
            self.recursive = old

    def visit_block(self, o: Block) -> None:
        # Blocks reached inside a nested def are that def's body: skip them.
        if self.recursive:
            return
        super().visit_block(o)

    def visit_assignment_expr(self, o: AssignmentExpr) -> None:
        """
        target := expr binds target, even if within a generator or comprehension.
        The value is traversed as well (it may hold further assignment
        expressions), but its plain names are not bindings.
        """
        self.bind_names_in(o.target)
        with self.binding_context(False):
            super().visit_assignment_expr(o)

    def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
        """
        lvalue (= lvalue ...) = expr binds each lvalue.
        The rest of the statement (e.g. the rvalue) is traversed normally.
        """
        self.bind_names_in(*o.lvalues)
        super().visit_assignment_stmt(o)

    def visit_name_expr(self, o: NameExpr) -> None:
        """
        Binds simple name, if in a binding context.
        """
        if self.bind:
            self.bind_names(o.name)

    def visit_member_expr(self, o: "MemberExpr") -> None:
        """
        obj.attr never binds a simple name, even as an assignment target;
        the object expression is traversed outside any binding context.
        """
        with self.binding_context(False):
            super().visit_member_expr(o)

    def visit_index_expr(self, o: "IndexExpr") -> None:
        """
        base[index] never binds a simple name, even as an assignment target;
        base and index are traversed outside any binding context.
        """
        with self.binding_context(False):
            super().visit_index_expr(o)

    def visit_func_def(self, o: FuncDef) -> None:
        """
        Binds function name. Traverses everything other than the function body.
        """
        # o.name is a plain string, not a Node, so it is bound directly.
        self.bind_names(o.name)
        with self.recursive_context():
            super().visit_func_def(o)

    def visit_lambda_expr(self, o: "LambdaExpr") -> None:
        """
        A lambda is a nested function scope: its body is skipped (so e.g. an
        assignment expression inside it is not a binding here), while everything
        else is traversed by the base class.
        """
        with self.recursive_context():
            super().visit_lambda_expr(o)

    def visit_class_def(self, o: ClassDef) -> None:
        """
        Binds the class name. Traverses everything other than the class body.
        """
        # o.name is a plain string, not a Node, so it is bound directly.
        self.bind_names(o.name)
        with self.recursive_context():
            super().visit_class_def(o)

    def visit_del_stmt(self, o: DelStmt) -> None:
        """del target binds the simple names in target."""
        self.bind_names_in(o.expr)
        super().visit_del_stmt(o)

    def visit_for_stmt(self, o: ForStmt) -> None:
        """for index in ...: binds index; the iterable and bodies are traversed."""
        self.bind_names_in(o.index)
        super().visit_for_stmt(o)

    def visit_import(self, o: Import) -> None:
        """
        import a.b.c binds the top-level name a; import a.b.c as d binds d.
        """
        # Assumes o.ids holds (module_id, as_id) pairs, like ImportFrom.names.
        self.bind_names(*(
            as_id or module_id.split(".")[0] for module_id, as_id in o.ids
        ))

    def visit_import_from(self, o: ImportFrom) -> None:
        """from m import x (as y) binds y if given, else x."""
        self.bind_names(*(as_id or name for name, as_id in o.names))

    def visit_operator_assignment_stmt(self, o: OperatorAssignmentStmt) -> None:
        """lvalue op= expr binds lvalue; the rvalue is traversed normally."""
        self.bind_names_in(o.lvalue)
        super().visit_operator_assignment_stmt(o)

    def visit_try_stmt(self, o: TryStmt) -> None:
        """except ... as var binds var; all clauses are traversed."""
        self.bind_names_in(*o.vars)
        super().visit_try_stmt(o)

    def visit_with_stmt(self, o: WithStmt) -> None:
        """with ... as target binds target; context exprs and body are traversed."""
        self.bind_names_in(*o.target)
        super().visit_with_stmt(o)

    # Match statements and case patterns:
    # Also, bindings are created by:
    #   The name in an AS pattern.
    #   The rest name in a mapping pattern (i.e. **x).
    #   The capture name in a starred pattern (i.e. [..., *x] in a sequence pattern).

    def visit_as_pattern(self, o: AsPattern) -> None:
        if o.pattern is not None:
            o.pattern.accept(self)
        if o.name is not None:
            self.bind_names_in(o.name)

    def visit_starred_pattern(self, o: StarredPattern) -> None:
        if o.capture is not None:
            self.bind_names_in(o.capture)

    # Former misspelled name, kept so existing direct callers keep working.
    visit_starred_patten = visit_starred_pattern

    def visit_mapping_pattern(self, o: MappingPattern) -> None:
        # Keys are just names or literals. Values may have bindings. Rest is a binding.
        for value in o.values:
            value.accept(self)
        if o.rest is not None:
            self.bind_names_in(o.rest)