Skip to content

Commit

Permalink
fix(typing): proper generic type for BinaryTreeNode
Browse files Browse the repository at this point in the history
 - pyright finally came calling on our swapping of left/right/parent fields in the MathExpression superclass.
  • Loading branch information
justindujardin committed Dec 15, 2023
1 parent 311240e commit 6ec236e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
8 changes: 4 additions & 4 deletions mathy_core/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
MathTypeKeysMax = max(MathTypeKeys.values()) + 1


class MathExpression(BinaryTreeNode):
class MathExpression(BinaryTreeNode["MathExpression"]):
"""Math tree node with helpers for manipulating expressions.
`mathy:x+y=z`
Expand Down Expand Up @@ -117,7 +117,7 @@ def visit_fn(
def with_color(self, text: str, style: str = "bright") -> str:
"""Render a string that is colored if something has changed"""
if self._rendering_change is True and self._changed is True:
return color(text, fore=self.color, style=style)
return f"{color(text, fore=self.color, style=style)}"
return text

def add_class(self, classes: Union[List[str], str]) -> "MathExpression":
Expand Down Expand Up @@ -147,7 +147,7 @@ def visit_fn(

def to_list(self, visit: str = "preorder") -> List["MathExpression"]:
"""Convert this node hierarchy into a list."""
results = []
results: List[MathExpression] = []

def visit_fn(
node: MathExpression, depth: int, data: Any
Expand Down Expand Up @@ -688,7 +688,7 @@ def clone(self) -> "ConstantExpression": # type:ignore[override]
result.value = self.value
return result # type:ignore

def evaluate(self, _context: Optional[Dict[str, NumberType]] = None) -> NumberType:
def evaluate(self, context: Optional[Dict[str, NumberType]] = None) -> NumberType:
assert self.value is not None
return self.value

Expand Down
28 changes: 14 additions & 14 deletions mathy_core/tree.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, List, Optional, TypeVar, Union, cast
from typing import Any, Callable, Generic, List, Optional, TypeVar, Union, cast

from .types import Literal

Expand Down Expand Up @@ -29,7 +29,7 @@
]


class BinaryTreeNode:
class BinaryTreeNode(Generic[NodeType]):
"""
The binary tree node is the base node for all of our trees, and provides a
rich set of methods for constructing, inspecting, and modifying them.
Expand All @@ -44,18 +44,18 @@ class BinaryTreeNode:
y: Optional[float]
offset: Optional[float]
level: Optional[int]
thread: Optional["BinaryTreeNode"]
thread: Optional[NodeType]

left: Optional["BinaryTreeNode"]
right: Optional["BinaryTreeNode"]
parent: Optional["BinaryTreeNode"]
left: Optional[NodeType]
right: Optional[NodeType]
parent: Optional[NodeType]

# Allow specifying children in the constructor
def __init__(
self,
left: Optional["BinaryTreeNode"] = None,
right: Optional["BinaryTreeNode"] = None,
parent: Optional["BinaryTreeNode"] = None,
self: NodeType,
left: Optional[NodeType] = None,
right: Optional[NodeType] = None,
parent: Optional[NodeType] = None,
id: Optional[str] = None,
):
if id is None:
Expand Down Expand Up @@ -208,7 +208,7 @@ def get_root(self: NodeType) -> NodeType:

return cast(NodeType, result)

def get_root_side(self: "BinaryTreeNode") -> Literal["left", "right"]:
def get_root_side(self: NodeType) -> Literal["left", "right"]:
"""Return the side of the tree that this node lives on"""
result = self
last_child = None
Expand All @@ -225,7 +225,7 @@ def get_root_side(self: "BinaryTreeNode") -> Literal["left", "right"]:

def set_left(
self: NodeType,
child: Optional["BinaryTreeNode"] = None,
child: Optional[NodeType] = None,
clear_old_child_parent: bool = False,
) -> NodeType:
"""Set the left node to the passed `child`"""
Expand All @@ -241,7 +241,7 @@ def set_left(

def set_right(
self: NodeType,
child: Optional["BinaryTreeNode"] = None,
child: Optional[NodeType] = None,
clear_old_child_parent: bool = False,
) -> NodeType:
"""Set the right node to the passed `child`"""
Expand All @@ -255,7 +255,7 @@ def set_right(

return self

def get_side(self, child: Optional["BinaryTreeNode"]) -> Literal["left", "right"]:
def get_side(self, child: Optional[NodeType]) -> Literal["left", "right"]:
"""Determine whether the given `child` is the left or right child of this
node"""
if child == self.left:
Expand Down

0 comments on commit 6ec236e

Please sign in to comment.