Skip to content

Commit

Permalink
Make an alias for tree branches
Browse files Browse the repository at this point in the history
Branch must be used as a generic so that mypy understands that _Leaf_T is the same type associated with the Tree
  • Loading branch information
plannigan committed Dec 19, 2021
1 parent 8a8b70d commit acf2a26
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
9 changes: 5 additions & 4 deletions lark/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self):


_Leaf_T = TypeVar("_Leaf_T")
Branch = Union[_Leaf_T, 'Tree[_Leaf_T]']


class Tree(Generic[_Leaf_T]):
Expand All @@ -52,9 +53,9 @@ class Tree(Generic[_Leaf_T]):
"""

data: str
children: 'List[Union[_Leaf_T, Tree[_Leaf_T]]]'
children: 'List[Branch[_Leaf_T]]'

def __init__(self, data: str, children: 'List[Union[_Leaf_T, Tree[_Leaf_T]]]', meta: Optional[Meta]=None) -> None:
def __init__(self, data: str, children: 'List[Branch[_Leaf_T]]', meta: Optional[Meta]=None) -> None:
self.data = data
self.children = children
self._meta = meta
Expand Down Expand Up @@ -135,7 +136,7 @@ def expand_kids_by_index(self, *indices: int) -> None:
kid = self.children[i]
self.children[i:i+1] = kid.children

def scan_values(self, pred: 'Callable[[Union[_Leaf_T, Tree[_Leaf_T]]], bool]') -> Iterator[_Leaf_T]:
def scan_values(self, pred: 'Callable[[Branch[_Leaf_T]], bool]') -> Iterator[_Leaf_T]:
"""Return all values in the tree that evaluate pred(value) as true.
This can be used to find all the tokens in the tree.
Expand Down Expand Up @@ -171,7 +172,7 @@ def __deepcopy__(self, memo):
def copy(self) -> 'Tree[_Leaf_T]':
return type(self)(self.data, self.children)

def set(self, data: str, children: 'List[Union[_Leaf_T, Tree[_Leaf_T]]]') -> None:
def set(self, data: str, children: 'List[Branch[_Leaf_T]]') -> None:
self.data = data
self.children = children

Expand Down
4 changes: 2 additions & 2 deletions lark/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import wraps

from .utils import smart_decorator, combine_alternatives
from .tree import Tree
from .tree import Tree, Branch
from .exceptions import VisitError, GrammarError
from .lexer import Token

Expand Down Expand Up @@ -206,7 +206,7 @@ class Transformer_NonRecursive(Transformer):
def transform(self, tree: Tree[_Leaf_T]) -> _Return_T:
# Tree to postfix
rev_postfix = []
q: List[Union[_Leaf_T, Tree[_Leaf_T]]] = [tree]
q: List[Branch[_Leaf_T]] = [tree]
while q:
t = q.pop()
rev_postfix.append(t)
Expand Down

0 comments on commit acf2a26

Please sign in to comment.