From aa86a71ede716269bdf5194d3f058e638bf58fbc Mon Sep 17 00:00:00 2001 From: "FeRD (Frank Dana)" Date: Fri, 6 Dec 2024 04:18:19 -0500 Subject: [PATCH] Init obj_dict in Common, and use super() Give the Common class an `__init__() that sets up the `obj_dict`, and call `super().__init__` from all of the subclasses. Makes the code in the subclass `__init__`s simpler. --- src/pydot/core.py | 54 ++++++++++++++++------------------------------- 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/src/pydot/core.py b/src/pydot/core.py index d357428..73ccc46 100644 --- a/src/pydot/core.py +++ b/src/pydot/core.py @@ -553,6 +553,9 @@ class Common: this one. """ + def __init__(self, obj_dict: Optional[AttributeDict] = None) -> None: + self.obj_dict: AttributeDict = obj_dict or {} + def __getstate__(self) -> AttributeDict: _dict = copy.copy(self.obj_dict) return _dict @@ -668,18 +671,8 @@ def __init__( obj_dict: Optional[AttributeDict] = None, **attrs: Any, ) -> None: - # - # Nodes will take attributes of - # all other types because the defaults - # for any GraphViz object are dealt with - # as if they were Node definitions - # - if obj_dict is not None: - self.obj_dict = obj_dict - - else: - self.obj_dict = {} - + super().__init__(obj_dict) + if obj_dict is None: # Copy the attributes # self.obj_dict["attributes"] = dict(attrs) @@ -783,21 +776,18 @@ def __init__( obj_dict: Optional[AttributeDict] = None, **attrs: Any, ) -> None: - self.obj_dict = {} - if isinstance(src, (Node, Subgraph, Cluster)): - src = src.get_name() - if isinstance(dst, (Node, Subgraph, Cluster)): - dst = dst.get_name() - points = (src, dst) - self.obj_dict["points"] = points + super().__init__(obj_dict) if obj_dict is None: - # Copy the attributes + if isinstance(src, (Node, Subgraph, Cluster)): + src = src.get_name() + if isinstance(dst, (Node, Subgraph, Cluster)): + dst = dst.get_name() + points = (src, dst) + self.obj_dict["points"] = points self.obj_dict["attributes"] = dict(attrs) self.obj_dict["type"] = "edge" self.obj_dict["parent_graph"] = None self.obj_dict["sequence"] = None - else: - self.obj_dict = obj_dict def __str__(self) -> str: return self.to_string() @@ -964,12 +954,8 @@ def __init__( simplify: bool = False, **attrs: Any, ) -> None: - if obj_dict is not None: - self.obj_dict = obj_dict - - else: - self.obj_dict = {} - + super().__init__(obj_dict) + if obj_dict is None: self.obj_dict["attributes"] = dict(attrs) if graph_type not in ["graph", "digraph"]: @@ -1549,15 +1535,13 @@ def __init__( simplify: bool = False, **attrs: Any, ) -> None: - Graph.__init__( - self, + super().__init__( graph_name=graph_name, obj_dict=obj_dict, suppress_disconnected=suppress_disconnected, simplify=simplify, **attrs, ) - if obj_dict is None: self.obj_dict["type"] = "subgraph" @@ -1602,15 +1586,13 @@ def __init__( simplify: bool = False, **attrs: Any, ) -> None: - Graph.__init__( - self, + super().__init__( graph_name=graph_name, obj_dict=obj_dict, suppress_disconnected=suppress_disconnected, simplify=simplify, **attrs, ) - if obj_dict is None: self.obj_dict["type"] = "subgraph" self.obj_dict["name"] = quote_id_if_necessary( @@ -1629,8 +1611,8 @@ class Dot(Graph): the base class 'Graph'. """ - def __init__(self, *argsl: Any, **argsd: Any) -> None: - Graph.__init__(self, *argsl, **argsd) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self.shape_files: List[str] = [] self.formats = OUTPUT_FORMATS