Skip to content

Commit

Permalink
refactor: move organize_edge_ids to MutableTopology
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Feb 21, 2022
1 parent cf9efab commit c6ccd6f
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 111 deletions.
125 changes: 69 additions & 56 deletions src/qrules/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,31 +195,34 @@ class Topology:

def __attrs_post_init__(self) -> None:
self.__verify()
object.__setattr__(
self,
"incoming_edge_ids",
frozenset(
edge_id
for edge_id, edge in self.edges.items()
if edge.originating_node_id is None
),
)
object.__setattr__(
self,
"outgoing_edge_ids",
frozenset(
edge_id
for edge_id, edge in self.edges.items()
if edge.ending_node_id is None
),
incoming = sorted(
edge_id
for edge_id, edge in self.edges.items()
if edge.originating_node_id is None
)
object.__setattr__(
self,
"intermediate_edge_ids",
frozenset(self.edges)
^ self.incoming_edge_ids
^ self.outgoing_edge_ids,
outgoing = sorted(
edge_id
for edge_id, edge in self.edges.items()
if edge.ending_node_id is None
)
inter = sorted(set(self.edges) - set(incoming) - set(outgoing))
expected = list(range(-len(incoming), 0))
if sorted(incoming) != expected:
raise ValueError(
f"Incoming edge IDs should be {expected}, not {incoming}."
)
n_out = len(outgoing)
expected = list(range(0, n_out))
if sorted(outgoing) != expected:
raise ValueError(
f"Outgoing edge IDs should be {expected}, not {outgoing}."
)
expected = list(range(n_out, n_out + len(inter)))
if sorted(inter) != expected:
raise ValueError(f"Intermediate edge IDs should be {expected}.")
object.__setattr__(self, "incoming_edge_ids", frozenset(incoming))
object.__setattr__(self, "outgoing_edge_ids", frozenset(outgoing))
object.__setattr__(self, "intermediate_edge_ids", frozenset(inter))

def __verify(self) -> None:
"""Verify if there are no dangling edges or nodes."""
Expand Down Expand Up @@ -314,29 +317,7 @@ def get_originating_initial_state_edge_ids(self, node_id: int) -> Set[int]:
temp_edge_list = new_temp_edge_list
return edge_ids

def organize_edge_ids(self) -> "Topology":
"""Create a new topology with edge IDs in range :code:`[-m, n+i]`.
where :code:`m` is the number of `.incoming_edge_ids`, :code:`n` is the
number of `.outgoing_edge_ids`, and :code:`i` is the number of
`.intermediate_edge_ids`.
In other words, relabel the edges so that:
- `.incoming_edge_ids` lies in the range :code:`[-1, -2, ...]`
- `.outgoing_edge_ids` lies in the range :code:`[0, 1, ..., n]`
- `.intermediate_edge_ids` lies in the range :code:`[n+1, n+2, ...]`
"""
new_to_old_id = enumerate(
tuple(self.incoming_edge_ids)
+ tuple(self.outgoing_edge_ids)
+ tuple(self.intermediate_edge_ids),
start=-len(self.incoming_edge_ids),
)
old_to_new_id = {j: i for i, j in new_to_old_id}
return self.relabel_edges(old_to_new_id)

def relabel_edges(self, old_to_new_id: Mapping[int, int]) -> "Topology":
def relabel_edges(self, old_to_new: Mapping[int, int]) -> "Topology":
"""Create a new `Topology` with new edge IDs.
This method is particularly useful when creating permutations of a
Expand All @@ -354,8 +335,10 @@ def relabel_edges(self, old_to_new_id: Mapping[int, int]) -> "Topology":
>>> len(permuted_topologies)
3
"""
new_to_old = {j: i for i, j in old_to_new.items()}
new_edges = {
old_to_new_id.get(i, i): edge for i, edge in self.edges.items()
old_to_new.get(i, new_to_old.get(i, i)): edge
for i, edge in self.edges.items()
}
return attrs.evolve(self, edges=new_edges)

Expand Down Expand Up @@ -405,12 +388,6 @@ class MutableTopology:
),
)

def freeze(self) -> Topology:
return Topology(
edges=self.edges,
nodes=self.nodes,
)

def add_node(self, node_id: int) -> None:
"""Adds a node nr. node_id.
Expand Down Expand Up @@ -482,6 +459,43 @@ def attach_edges_to_node_outgoing(
originating_node_id=node_id,
)

def organize_edge_ids(self) -> "MutableTopology":
"""Organize edge IDS so that they lie in range :code:`[-m, n+i]`.
Here, :code:`m` is the number of `.incoming_edge_ids`, :code:`n` is the
number of `.outgoing_edge_ids`, and :code:`i` is the number of
`.intermediate_edge_ids`.
In other words, relabel the edges so that:
- incoming edge IDs lie in the range :code:`[-1, -2, ...]`,
- outgoing edge IDs lie in the range :code:`[0, 1, ..., n]`,
- intermediate edge IDs lie in the range :code:`[n+1, n+2, ...]`.
"""
incoming = {
i
for i, edge in self.edges.items()
if edge.originating_node_id is None
}
outgoing = {
edge_id
for edge_id, edge in self.edges.items()
if edge.ending_node_id is None
}
intermediate = set(self.edges) - incoming - outgoing
new_to_old_id = enumerate(
list(incoming) + list(outgoing) + list(intermediate),
start=-len(incoming),
)
old_to_new_id = {j: i for i, j in new_to_old_id}
new_edges = {
old_to_new_id.get(i, i): edge for i, edge in self.edges.items()
}
return attrs.evolve(self, edges=new_edges)

def freeze(self) -> Topology:
return Topology(self.nodes, self.edges)


@define
class InteractionNode:
Expand Down Expand Up @@ -546,7 +560,6 @@ def build(
len(active_graph[1]) == number_of_final_edges
and len(active_graph[0].nodes) > 0
):
active_graph[0].freeze() # verify
graph_tuple_list.append(active_graph)
continue

Expand All @@ -556,9 +569,9 @@ def build(
# strip the current open end edges list from the result graph tuples
topologies = []
for graph_tuple in graph_tuple_list:
topology = graph_tuple[0].freeze()
topology = graph_tuple[0]
topology = topology.organize_edge_ids()
topologies.append(topology)
topologies.append(topology.freeze())
return tuple(topologies)

def _extend_graph(
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/io/test_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,11 @@ def test_write_topology(self, output_dir):
output_file = output_dir + "two_body_decay_topology.gv"
topology = Topology(
nodes={0},
edges={0: Edge(0, None), 1: Edge(None, 0), 2: Edge(None, 0)},
edges={
-1: Edge(None, 0),
0: Edge(0, None),
1: Edge(0, None),
},
)
io.write(
instance=topology,
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/test_system_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,15 @@ def make_ls_test_graph(
):
topology = Topology(
nodes={0},
edges={0: Edge(None, 0)},
edges={-1: Edge(None, 0)},
)
interactions = {
0: InteractionProperties(
s_magnitude=coupled_spin_magnitude,
l_magnitude=angular_momentum_magnitude,
)
}
states: Dict[int, ParticleWithSpin] = {0: (particle, 0)}
states: Dict[int, ParticleWithSpin] = {-1: (particle, 0)}
graph = MutableTransition(topology, states, interactions)
return graph

Expand All @@ -228,15 +228,15 @@ def make_ls_test_graph_scrambled(
):
topology = Topology(
nodes={0},
edges={0: Edge(None, 0)},
edges={-1: Edge(None, 0)},
)
interactions = {
0: InteractionProperties(
l_magnitude=angular_momentum_magnitude,
s_magnitude=coupled_spin_magnitude,
)
}
states: Dict[int, ParticleWithSpin] = {0: (particle, 0)}
states: Dict[int, ParticleWithSpin] = {-1: (particle, 0)}
graph = MutableTransition(topology, states, interactions)
return graph

Expand Down Expand Up @@ -326,7 +326,7 @@ def test_filter_graphs_for_interaction_qns(
tempgraph = attrs.evolve(
tempgraph,
states={
0: (
-1: (
Particle(name=value[0], pid=0, mass=1.0, spin=1.0),
0.0,
)
Expand Down
Loading

0 comments on commit c6ccd6f

Please sign in to comment.