Skip to content

Commit

Permalink
Init obj_dict in Common, and use super()
Browse files Browse the repository at this point in the history
Give the Common class an `__init__() that sets up the `obj_dict`,
and call `super().__init__` from all of the subclasses. Makes the
code in the subclass `__init__`s simpler.
  • Loading branch information
ferdnyc committed Dec 6, 2024
1 parent cdb9037 commit aa86a71
Showing 1 changed file with 18 additions and 36 deletions.
54 changes: 18 additions & 36 deletions src/pydot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,9 @@ class Common:
this one.
"""

def __init__(self, obj_dict: Optional[AttributeDict] = None) -> None:
self.obj_dict: AttributeDict = obj_dict or {}

def __getstate__(self) -> AttributeDict:
_dict = copy.copy(self.obj_dict)
return _dict
Expand Down Expand Up @@ -668,18 +671,8 @@ def __init__(
obj_dict: Optional[AttributeDict] = None,
**attrs: Any,
) -> None:
#
# Nodes will take attributes of
# all other types because the defaults
# for any GraphViz object are dealt with
# as if they were Node definitions
#
if obj_dict is not None:
self.obj_dict = obj_dict

else:
self.obj_dict = {}

super().__init__(obj_dict)
if obj_dict is None:
# Copy the attributes
#
self.obj_dict["attributes"] = dict(attrs)
Expand Down Expand Up @@ -783,21 +776,18 @@ def __init__(
obj_dict: Optional[AttributeDict] = None,
**attrs: Any,
) -> None:
self.obj_dict = {}
if isinstance(src, (Node, Subgraph, Cluster)):
src = src.get_name()
if isinstance(dst, (Node, Subgraph, Cluster)):
dst = dst.get_name()
points = (src, dst)
self.obj_dict["points"] = points
super().__init__(obj_dict)
if obj_dict is None:
# Copy the attributes
if isinstance(src, (Node, Subgraph, Cluster)):
src = src.get_name()
if isinstance(dst, (Node, Subgraph, Cluster)):
dst = dst.get_name()
points = (src, dst)
self.obj_dict["points"] = points
self.obj_dict["attributes"] = dict(attrs)
self.obj_dict["type"] = "edge"
self.obj_dict["parent_graph"] = None
self.obj_dict["sequence"] = None
else:
self.obj_dict = obj_dict

def __str__(self) -> str:
return self.to_string()
Expand Down Expand Up @@ -964,12 +954,8 @@ def __init__(
simplify: bool = False,
**attrs: Any,
) -> None:
if obj_dict is not None:
self.obj_dict = obj_dict

else:
self.obj_dict = {}

super().__init__(obj_dict)
if obj_dict is None:
self.obj_dict["attributes"] = dict(attrs)

if graph_type not in ["graph", "digraph"]:
Expand Down Expand Up @@ -1549,15 +1535,13 @@ def __init__(
simplify: bool = False,
**attrs: Any,
) -> None:
Graph.__init__(
self,
super().__init__(
graph_name=graph_name,
obj_dict=obj_dict,
suppress_disconnected=suppress_disconnected,
simplify=simplify,
**attrs,
)

if obj_dict is None:
self.obj_dict["type"] = "subgraph"

Expand Down Expand Up @@ -1602,15 +1586,13 @@ def __init__(
simplify: bool = False,
**attrs: Any,
) -> None:
Graph.__init__(
self,
super().__init__(
graph_name=graph_name,
obj_dict=obj_dict,
suppress_disconnected=suppress_disconnected,
simplify=simplify,
**attrs,
)

if obj_dict is None:
self.obj_dict["type"] = "subgraph"
self.obj_dict["name"] = quote_id_if_necessary(
Expand All @@ -1629,8 +1611,8 @@ class Dot(Graph):
the base class 'Graph'.
"""

def __init__(self, *argsl: Any, **argsd: Any) -> None:
Graph.__init__(self, *argsl, **argsd)
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

self.shape_files: List[str] = []
self.formats = OUTPUT_FORMATS
Expand Down

0 comments on commit aa86a71

Please sign in to comment.