From 37f1b94bf99fef87fe703303e508dd2346260ca0 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:29 +0100 Subject: [PATCH 01/34] refactor: make MutableTopology public --- src/qrules/topology.py | 18 +++++++++--------- tests/unit/test_topology.py | 6 +++--- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index b8e6b769..eea03f10 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -377,7 +377,7 @@ def __get_originating_node(edge_id: int) -> Optional[int]: @define(kw_only=True) -class _MutableTopology: +class MutableTopology: edges: Dict[int, Edge] = field(factory=dict, converter=dict) nodes: Set[int] = field(factory=set, converter=set) @@ -506,12 +506,12 @@ def build( logging.info("building topology graphs...") # result list - graph_tuple_list: List[Tuple[_MutableTopology, List[int]]] = [] + graph_tuple_list: List[Tuple[MutableTopology, List[int]]] = [] # create seed graph - seed_graph = _MutableTopology() + seed_graph = MutableTopology() current_open_end_edges = list(range(number_of_initial_edges)) seed_graph.add_edges(current_open_end_edges) - extendable_graph_list: List[Tuple[_MutableTopology, List[int]]] = [ + extendable_graph_list: List[Tuple[MutableTopology, List[int]]] = [ (seed_graph, current_open_end_edges) ] @@ -540,9 +540,9 @@ def build( return tuple(topologies) def _extend_graph( - self, pair: Tuple[_MutableTopology, Sequence[int]] - ) -> List[Tuple[_MutableTopology, List[int]]]: - extended_graph_list: List[Tuple[_MutableTopology, List[int]]] = [] + self, pair: Tuple[MutableTopology, Sequence[int]] + ) -> List[Tuple[MutableTopology, List[int]]]: + extended_graph_list: List[Tuple[MutableTopology, List[int]]] = [] topology, current_open_end_edges = pair @@ -619,10 +619,10 @@ def create_n_body_topology( def _attach_node_to_edges( - graph: Tuple[_MutableTopology, Sequence[int]], + graph: Tuple[MutableTopology, Sequence[int]], interaction_node: InteractionNode, ingoing_edge_ids: Iterable[int], -) -> Tuple[_MutableTopology, List[int]]: +) -> Tuple[MutableTopology, List[int]]: temp_graph = copy.deepcopy(graph[0]) new_open_end_lines = list(copy.deepcopy(graph[1])) diff --git a/tests/unit/test_topology.py b/tests/unit/test_topology.py index 688ab640..bc2d8306 100644 --- a/tests/unit/test_topology.py +++ b/tests/unit/test_topology.py @@ -10,9 +10,9 @@ Edge, FrozenDict, InteractionNode, + MutableTopology, SimpleStateTransitionTopologyBuilder, Topology, - _MutableTopology, create_isobar_topologies, create_n_body_topology, get_originating_node_list, @@ -100,7 +100,7 @@ def test_constructor_exceptions(self): class TestMutableTopology: def test_add_and_attach(self, two_to_three_decay: Topology): - topology = _MutableTopology( + topology = MutableTopology( edges=two_to_three_decay.edges, nodes=two_to_three_decay.nodes, # type: ignore[arg-type] ) @@ -116,7 +116,7 @@ def test_add_and_attach(self, two_to_three_decay: Topology): assert isinstance(topology.freeze(), Topology) def test_add_exceptions(self, two_to_three_decay: Topology): - topology = _MutableTopology( + topology = MutableTopology( edges=two_to_three_decay.edges, nodes=two_to_three_decay.nodes, # type: ignore[arg-type] ) From fe2adc2647035d9c044d79c8caf1264cd2a69eb0 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:29 +0100 Subject: [PATCH 02/34] refactor: remove kw_only from MutableTopology --- src/qrules/topology.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index eea03f10..a5e33c3d 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -376,7 +376,7 @@ def __get_originating_node(edge_id: int) -> Optional[int]: ] -@define(kw_only=True) +@define class MutableTopology: edges: Dict[int, Edge] = field(factory=dict, converter=dict) nodes: Set[int] = field(factory=set, converter=set) From 113df7aea2d79aebce8117a6b483dff02235a2b5 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:30 +0100 Subject: [PATCH 03/34] chore: simplify (over)defined_assert functions --- src/qrules/topology.py | 24 +++++++++++++++++------- src/qrules/transition.py | 17 +++-------------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index a5e33c3d..2cf0ba9d 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -19,7 +19,6 @@ TYPE_CHECKING, Any, Callable, - Collection, Dict, FrozenSet, Generic, @@ -678,8 +677,8 @@ def __init__( self.topology = topology def __post_init__(self) -> None: - _assert_over_defined(self.topology.nodes, self.__node_props) - _assert_over_defined(self.topology.edges, self.__edge_props) + _assert_not_overdefined(self.topology.nodes, self.__node_props) + _assert_not_overdefined(self.topology.edges, self.__edge_props) def __eq__(self, other: object) -> bool: """Check if two `.StateTransitionGraph` instances are **identical**.""" @@ -716,13 +715,13 @@ def evolve( """ new_node_props = copy.copy(self.__node_props) if node_props: - _assert_over_defined(self.topology.nodes, node_props) + _assert_not_overdefined(self.topology.nodes, node_props) for node_id, node_prop in node_props.items(): new_node_props[node_id] = node_prop new_edge_props = copy.copy(self.__edge_props) if edge_props: - _assert_over_defined(self.topology.edges, edge_props) + _assert_not_overdefined(self.topology.edges, edge_props) for edge_id, edge_prop in edge_props.items(): new_edge_props[edge_id] = edge_prop @@ -770,10 +769,21 @@ def swap_edges(self, edge_id1: int, edge_id2: int) -> None: self.__edge_props[edge_id1] = value2 -def _assert_over_defined(items: Collection, properties: Mapping) -> None: +# pyright: reportUnusedFunction=false +def _assert_all_defined(items: Iterable, properties: Iterable) -> None: + existing = set(items) defined = set(properties) + if existing & defined != existing: + raise ValueError( + "Some items have no property assigned to them." + f" Available items: {existing}, items with property: {defined}" + ) + + +def _assert_not_overdefined(items: Iterable, properties: Iterable) -> None: existing = set(items) - over_defined = existing & defined ^ defined + defined = set(properties) + over_defined = defined - existing if over_defined: raise ValueError( "Properties have been defined for items that don't exist." diff --git a/src/qrules/transition.py b/src/qrules/transition.py index ff671ab8..5eda4ff3 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -6,11 +6,9 @@ from enum import Enum, auto from multiprocessing import Pool from typing import ( - Collection, Dict, Iterable, List, - Mapping, Optional, Sequence, Set, @@ -76,6 +74,7 @@ FrozenDict, StateTransitionGraph, Topology, + _assert_all_defined, create_isobar_topologies, create_n_body_topology, ) @@ -742,8 +741,8 @@ class StateTransition: ) def __attrs_post_init__(self) -> None: - _assert_defined(self.topology.edges, self.states) - _assert_defined(self.topology.nodes, self.interactions) + _assert_all_defined(self.topology.edges, self.states) + _assert_all_defined(self.topology.nodes, self.interactions) @staticmethod def from_graph( @@ -792,16 +791,6 @@ def particles(self) -> Dict[int, Particle]: return {i: edge_prop.particle for i, edge_prop in self.states.items()} -def _assert_defined(items: Collection, properties: Mapping) -> None: - existing = set(items) - defined = set(properties) - if existing & defined != existing: - raise ValueError( - "Some items have no property assigned to them." - f" Available items: {existing}, items with property: {defined}" - ) - - def _sort_tuple( iterable: Iterable[StateTransition], ) -> Tuple[StateTransition, ...]: From accc555fa76793e4fb90997d38fe0c3239289394 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:30 +0100 Subject: [PATCH 04/34] chore: rewrite StateTransitionGraph with attrs WARNING: this removes the __post_init__ check, but this didn't seem to be doing anything anyway. Initially, I simply renamed the check __attrs_post_init__, but this actually does perform the check and then the system crashes. This seems to be an issue that needs to be addressed separately. --- src/qrules/topology.py | 61 ++++++++++++------------------------------ 1 file changed, 17 insertions(+), 44 deletions(-) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 2cf0ba9d..fe77cc49 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -655,6 +655,8 @@ def _attach_node_to_edges( """A `~typing.TypeVar` representing the type of edge properties.""" +@implement_pretty_repr() +@define class StateTransitionGraph(Generic[EdgeType]): """Graph class that resembles a frozen `.Topology` with properties. @@ -664,44 +666,15 @@ class StateTransitionGraph(Generic[EdgeType]): error can be raised on property retrieval. """ - def __init__( - self, - topology: Topology, - node_props: Mapping[int, InteractionProperties], - edge_props: Mapping[int, EdgeType], - ): - self.__node_props = dict(node_props) - self.__edge_props = dict(edge_props) - if not isinstance(topology, Topology): - raise TypeError - self.topology = topology - - def __post_init__(self) -> None: - _assert_not_overdefined(self.topology.nodes, self.__node_props) - _assert_not_overdefined(self.topology.edges, self.__edge_props) - - def __eq__(self, other: object) -> bool: - """Check if two `.StateTransitionGraph` instances are **identical**.""" - if isinstance(other, StateTransitionGraph): - if self.topology != other.topology: - return False - for i in self.topology.edges: - if self.get_edge_props(i) != other.get_edge_props(i): - return False - for i in self.topology.nodes: - if self.get_node_props(i) != other.get_node_props(i): - return False - return True - raise NotImplementedError( - f"Cannot compare {self.__class__.__name__}" - f" with {other.__class__.__name__}" - ) + topology: Topology = field(validator=instance_of(Topology)) + node_props: Dict[int, InteractionProperties] + edge_props: Dict[int, EdgeType] def get_node_props(self, node_id: int) -> InteractionProperties: - return self.__node_props[node_id] + return self.node_props[node_id] def get_edge_props(self, edge_id: int) -> EdgeType: - return self.__edge_props[edge_id] + return self.edge_props[edge_id] def evolve( self, @@ -713,13 +686,13 @@ def evolve( Since a `.StateTransitionGraph` is frozen (cannot be modified), the evolve function will also create a shallow copy the properties. """ - new_node_props = copy.copy(self.__node_props) + new_node_props = copy.copy(self.node_props) if node_props: _assert_not_overdefined(self.topology.nodes, node_props) for node_id, node_prop in node_props.items(): new_node_props[node_id] = node_prop - new_edge_props = copy.copy(self.__edge_props) + new_edge_props = copy.copy(self.edge_props) if edge_props: _assert_not_overdefined(self.topology.edges, edge_props) for edge_id, edge_prop in edge_props.items(): @@ -744,13 +717,13 @@ def compare( if edge_comparator is not None: for i in self.topology.edges: if not edge_comparator( - self.get_edge_props(i), other.get_edge_props(i) + self.edge_props[i], other.edge_props[i] ): return False if node_comparator is not None: for i in self.topology.nodes: if not node_comparator( - self.get_node_props(i), other.get_node_props(i) + self.node_props[i], other.node_props[i] ): return False return True @@ -759,14 +732,14 @@ def swap_edges(self, edge_id1: int, edge_id2: int) -> None: self.topology = self.topology.swap_edges(edge_id1, edge_id2) value1: Optional[EdgeType] = None value2: Optional[EdgeType] = None - if edge_id1 in self.__edge_props: - value1 = self.__edge_props.pop(edge_id1) - if edge_id2 in self.__edge_props: - value2 = self.__edge_props.pop(edge_id2) + if edge_id1 in self.edge_props: + value1 = self.edge_props.pop(edge_id1) + if edge_id2 in self.edge_props: + value2 = self.edge_props.pop(edge_id2) if value1 is not None: - self.__edge_props[edge_id2] = value1 + self.edge_props[edge_id2] = value1 if value2 is not None: - self.__edge_props[edge_id1] = value2 + self.edge_props[edge_id1] = value2 # pyright: reportUnusedFunction=false From b350fe144052d70001864772a06f63fe825d4305 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:31 +0100 Subject: [PATCH 05/34] ci: disable pylint line-too-long --- .pylintrc | 1 + 1 file changed, 1 insertion(+) diff --git a/.pylintrc b/.pylintrc index c32808b8..fa0c26f6 100644 --- a/.pylintrc +++ b/.pylintrc @@ -16,6 +16,7 @@ ignore-patterns= disable= duplicate-code, # https://github.com/PyCQA/pylint/issues/214 invalid-unary-operand-type, # conflicts with attrs.field + line-too-long, # automatically fixed with black logging-fstring-interpolation, missing-class-docstring, # pydocstyle missing-function-docstring, # pydocstyle From 2a1a2613672c3b2dbeebf242da989a880447c635 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:31 +0100 Subject: [PATCH 06/34] refactor: remove get_edge/node_props() --- src/qrules/_system_control.py | 6 +++--- src/qrules/combinatorics.py | 6 +----- src/qrules/io/_dict.py | 4 ++-- src/qrules/io/_dot.py | 23 +++++++++++------------ src/qrules/topology.py | 20 ++++++++++++-------- src/qrules/transition.py | 13 +++++-------- tests/unit/io/test_dot.py | 28 +++++++++++++--------------- tests/unit/test_system_control.py | 8 ++++---- 8 files changed, 51 insertions(+), 57 deletions(-) diff --git a/src/qrules/_system_control.py b/src/qrules/_system_control.py index 3b9a2c3e..acae3754 100644 --- a/src/qrules/_system_control.py +++ b/src/qrules/_system_control.py @@ -233,7 +233,7 @@ def _remove_qns_from_graph( # pylint: disable=too-many-branches ) -> StateTransitionGraph[ParticleWithSpin]: new_node_props = {} for node_id in graph.topology.nodes: - node_props = graph.get_node_props(node_id) + node_props = graph.node_props[node_id] new_node_props[node_id] = attrs.evolve( node_props, **{x.__name__: None for x in qn_list} ) @@ -359,7 +359,7 @@ def check(graph: StateTransitionGraph[ParticleWithSpin]) -> bool: return False for i in node_ids: if ( - getattr(graph.get_node_props(i), interaction_qn.__name__) + getattr(graph.node_props[i], interaction_qn.__name__) not in allowed_values ): return False @@ -375,7 +375,7 @@ def _find_node_ids_with_ingoing_particle_name( found_node_ids = [] for node_id in topology.nodes: for edge_id in topology.get_edge_ids_ingoing_to_node(node_id): - edge_props = graph.get_edge_props(edge_id) + edge_props = graph.edge_props[edge_id] edge_particle_name = edge_props[0].name if str(ingoing_particle_name) in str(edge_particle_name): found_node_ids.append(node_id) diff --git a/src/qrules/combinatorics.py b/src/qrules/combinatorics.py index 372ca900..5674d005 100644 --- a/src/qrules/combinatorics.py +++ b/src/qrules/combinatorics.py @@ -559,8 +559,4 @@ def _calculate_swappings(id_mapping: Dict[int, int]) -> OrderedDict: def _create_edge_id_particle_mapping( graph: StateTransitionGraph[ParticleWithSpin], edge_ids: Iterable[int] ) -> Dict[int, str]: - return { - i: graph.get_edge_props(i)[0].name - for i in edge_ids - if graph.get_edge_props(i) - } + return {i: graph.edge_props[i][0].name for i in edge_ids} diff --git a/src/qrules/io/_dict.py b/src/qrules/io/_dict.py index f1d81c22..4746b977 100644 --- a/src/qrules/io/_dict.py +++ b/src/qrules/io/_dict.py @@ -37,7 +37,7 @@ def from_stg(graph: StateTransitionGraph[ParticleWithSpin]) -> dict: topology = graph.topology edge_props_def = {} for i in topology.edges: - particle, spin_projection = graph.get_edge_props(i) + particle, spin_projection = graph.edge_props[i] if isinstance(spin_projection, float) and spin_projection.is_integer(): spin_projection = int(spin_projection) edge_props_def[i] = { @@ -46,7 +46,7 @@ def from_stg(graph: StateTransitionGraph[ParticleWithSpin]) -> dict: } node_props_def = {} for i in topology.nodes: - node_prop = graph.get_node_props(i) + node_prop = graph.node_props[i] node_props_def[i] = attrs.asdict( node_prop, filter=lambda a, v: a.init and a.default != v ) diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index c3d57583..c63ecfdd 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -295,7 +295,7 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals int, InteractionProperties ] = graph.interactions else: - interactions = {i: graph.get_node_props(i) for i in topology.nodes} + interactions = {i: graph.node_props[i] for i in topology.nodes} for node_id, node_prop in interactions.items(): node_label = "" if render_node: @@ -361,7 +361,7 @@ def __get_edge_label( if isinstance(graph, StateTransition): graph = graph.to_graph() if isinstance(graph, StateTransitionGraph): - edge_prop = graph.get_edge_props(edge_id) + edge_prop = graph.edge_props[edge_id] return ___render_edge_with_id(edge_id, edge_prop, render_edge_id) if isinstance(graph, Topology): if render_edge_id: @@ -491,7 +491,7 @@ def _get_particle_graphs( inventory = sorted( inventory, key=lambda g: [ - g.get_edge_props(i).mass for i in g.topology.intermediate_edge_ids + g.edge_props[i].mass for i in g.topology.intermediate_edge_ids ], ) return inventory @@ -504,12 +504,12 @@ def _strip_projections( graph = graph.to_graph() new_edge_props = {} for edge_id in graph.topology.edges: - edge_props = graph.get_edge_props(edge_id) + edge_props = graph.edge_props[edge_id] if edge_props: new_edge_props[edge_id] = edge_props[0] new_node_props = {} for node_id in graph.topology.nodes: - node_props = graph.get_node_props(node_id) + node_props = graph.node_props[node_id] if node_props: new_node_props[node_id] = attrs.evolve( node_props, l_projection=None, s_projection=None @@ -536,8 +536,8 @@ def merge_into( "Cannot merge graphs that don't have the same edge IDs" ) for i in graph.topology.edges: - particle = graph.get_edge_props(i) - other_particles = merged_graph.get_edge_props(i) + particle = graph.edge_props[i] + other_particles = merged_graph.edge_props[i] if particle not in other_particles: other_particles += particle @@ -550,11 +550,11 @@ def is_same_shape( for edge_id in ( graph.topology.incoming_edge_ids | graph.topology.outgoing_edge_ids ): - edge_prop = merged_graph.get_edge_props(edge_id) + edge_prop = merged_graph.edge_props[edge_id] if len(edge_prop) != 1: return False other_particle = next(iter(edge_prop)) - if other_particle != graph.get_edge_props(edge_id): + if other_particle != graph.edge_props[edge_id]: return False return True @@ -569,15 +569,14 @@ def is_same_shape( break if append_to_inventory: new_edge_props = { - edge_id: ParticleCollection({graph.get_edge_props(edge_id)}) + edge_id: ParticleCollection({graph.edge_props[edge_id]}) for edge_id in graph.topology.edges } inventory.append( StateTransitionGraph[ParticleCollection]( topology=graph.topology, node_props={ - i: graph.get_node_props(i) - for i in graph.topology.nodes + i: graph.node_props[i] for i in graph.topology.nodes }, edge_props=new_edge_props, ) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index fe77cc49..88bcb602 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -655,6 +655,16 @@ def _attach_node_to_edges( """A `~typing.TypeVar` representing the type of edge properties.""" +def _cast_edges(obj: Mapping[int, EdgeType]) -> Dict[int, EdgeType]: + return dict(obj) + + +def _cast_nodes( + obj: Mapping[int, InteractionProperties] +) -> Dict[int, InteractionProperties]: + return dict(obj) + + @implement_pretty_repr() @define class StateTransitionGraph(Generic[EdgeType]): @@ -667,14 +677,8 @@ class StateTransitionGraph(Generic[EdgeType]): """ topology: Topology = field(validator=instance_of(Topology)) - node_props: Dict[int, InteractionProperties] - edge_props: Dict[int, EdgeType] - - def get_node_props(self, node_id: int) -> InteractionProperties: - return self.node_props[node_id] - - def get_edge_props(self, edge_id: int) -> EdgeType: - return self.edge_props[edge_id] + node_props: Dict[int, InteractionProperties] = field(converter=_cast_nodes) + edge_props: Dict[int, EdgeType] = field(converter=_cast_edges) def evolve( self, diff --git a/src/qrules/transition.py b/src/qrules/transition.py index 5eda4ff3..3a66625a 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -698,17 +698,17 @@ def _match_final_state_ids( particle_names = _strip_spin(state_definition) name_to_id = {name: i for i, name in enumerate(particle_names)} id_remapping = { - name_to_id[graph.get_edge_props(i)[0].name]: i + name_to_id[graph.edge_props[i][0].name]: i for i in graph.topology.outgoing_edge_ids } new_topology = graph.topology.relabel_edges(id_remapping) return StateTransitionGraph( new_topology, edge_props={ - i: graph.get_edge_props(id_remapping.get(i, i)) + i: graph.edge_props[id_remapping.get(i, i)] for i in graph.topology.edges }, - node_props={i: graph.get_node_props(i) for i in graph.topology.nodes}, + node_props={i: graph.node_props[i] for i in graph.topology.nodes}, ) @@ -751,13 +751,10 @@ def from_graph( return StateTransition( topology=graph.topology, states=FrozenDict( - { - i: State(*graph.get_edge_props(i)) - for i in graph.topology.edges - } + {i: State(*graph.edge_props[i]) for i in graph.topology.edges} ), interactions=FrozenDict( - {i: graph.get_node_props(i) for i in graph.topology.nodes} + {i: graph.node_props[i] for i in graph.topology.nodes} ), ) diff --git a/tests/unit/io/test_dot.py b/tests/unit/io/test_dot.py index b2f49636..79257ff3 100644 --- a/tests/unit/io/test_dot.py +++ b/tests/unit/io/test_dot.py @@ -215,7 +215,7 @@ def test_collapse_graphs( graph = next(iter(collapsed_graphs)) edge_id = next(iter(graph.topology.intermediate_edge_ids)) f_resonances = pdg.filter(lambda p: p.name in ["f(0)(980)", "f(0)(1500)"]) - intermediate_states = graph.get_edge_props(edge_id) + intermediate_states = graph.edge_props[edge_id] assert isinstance(intermediate_states, ParticleCollection) assert intermediate_states == f_resonances @@ -224,15 +224,13 @@ def test_get_particle_graphs( reaction: ReactionInfo, particle_database: ParticleCollection ): pdg = particle_database - particle_graphs = _get_particle_graphs(reaction.to_graphs()) - assert len(particle_graphs) == 2 - assert particle_graphs[0].get_edge_props(3) == pdg["f(0)(980)"] - assert particle_graphs[1].get_edge_props(3) == pdg["f(0)(1500)"] - assert len(particle_graphs[0].topology.edges) == 5 - for edge_id in range(-1, 3): - assert particle_graphs[0].get_edge_props(edge_id) is particle_graphs[ - 1 - ].get_edge_props(edge_id) + graphs = _get_particle_graphs(reaction.to_graphs()) + assert len(graphs) == 2 + assert graphs[0].edge_props[3] == pdg["f(0)(980)"] + assert graphs[1].edge_props[3] == pdg["f(0)(1500)"] + assert len(graphs[0].topology.edges) == 5 + for i in range(-1, 3): + assert graphs[0].edge_props[i] is graphs[1].edge_props[i] def test_strip_projections(): @@ -256,8 +254,8 @@ def test_strip_projections(): assert transition.interactions[1].l_projection == 0 stripped_transition = _strip_projections(transition) # type: ignore[arg-type] - assert stripped_transition.get_edge_props(3).name == resonance - assert stripped_transition.get_node_props(0).s_projection is None - assert stripped_transition.get_node_props(0).l_projection is None - assert stripped_transition.get_node_props(1).s_projection is None - assert stripped_transition.get_node_props(1).l_projection is None + assert stripped_transition.edge_props[3].name == resonance + assert stripped_transition.node_props[0].s_projection is None + assert stripped_transition.node_props[0].l_projection is None + assert stripped_transition.node_props[1].s_projection is None + assert stripped_transition.node_props[1].l_projection is None diff --git a/tests/unit/test_system_control.py b/tests/unit/test_system_control.py index a537550b..6c0d2cd5 100644 --- a/tests/unit/test_system_control.py +++ b/tests/unit/test_system_control.py @@ -380,15 +380,15 @@ def test_edge_swap(particle_database, initial_state, final_state): edge_keys = list(ref_mapping.keys()) edge1 = edge_keys[0] edge1_val = graph.topology.edges[edge1] - edge1_props = deepcopy(graph.get_edge_props(edge1)) + edge1_props = deepcopy(graph.edge_props[edge1]) edge2 = edge_keys[1] edge2_val = graph.topology.edges[edge2] - edge2_props = deepcopy(graph.get_edge_props(edge2)) + edge2_props = deepcopy(graph.edge_props[edge2]) graph.swap_edges(edge1, edge2) assert graph.topology.edges[edge1] == edge2_val assert graph.topology.edges[edge2] == edge1_val - assert graph.get_edge_props(edge1) == edge2_props - assert graph.get_edge_props(edge2) == edge1_props + assert graph.edge_props[edge1] == edge2_props + assert graph.edge_props[edge2] == edge1_props @pytest.mark.parametrize( From 6e5240ae1f3b66a03e2822ee814ffe687120248d Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:31 +0100 Subject: [PATCH 07/34] refactor: remove StateTransitionGraph.evolve() --- src/qrules/_system_control.py | 2 +- src/qrules/topology.py | 29 +---------------------------- tests/unit/test_system_control.py | 6 ++++-- 3 files changed, 6 insertions(+), 31 deletions(-) diff --git a/src/qrules/_system_control.py b/src/qrules/_system_control.py index acae3754..7eeb4b93 100644 --- a/src/qrules/_system_control.py +++ b/src/qrules/_system_control.py @@ -238,7 +238,7 @@ def _remove_qns_from_graph( # pylint: disable=too-many-branches node_props, **{x.__name__: None for x in qn_list} ) - return graph.evolve(node_props=new_node_props) + return attrs.evolve(graph, node_props=new_node_props) def _check_equal_ignoring_qns( diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 88bcb602..c294801e 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -680,34 +680,6 @@ class StateTransitionGraph(Generic[EdgeType]): node_props: Dict[int, InteractionProperties] = field(converter=_cast_nodes) edge_props: Dict[int, EdgeType] = field(converter=_cast_edges) - def evolve( - self, - node_props: Optional[Dict[int, InteractionProperties]] = None, - edge_props: Optional[Dict[int, EdgeType]] = None, - ) -> "StateTransitionGraph[EdgeType]": - """Changes the node and edge properties of a graph instance. - - Since a `.StateTransitionGraph` is frozen (cannot be modified), the - evolve function will also create a shallow copy the properties. - """ - new_node_props = copy.copy(self.node_props) - if node_props: - _assert_not_overdefined(self.topology.nodes, node_props) - for node_id, node_prop in node_props.items(): - new_node_props[node_id] = node_prop - - new_edge_props = copy.copy(self.edge_props) - if edge_props: - _assert_not_overdefined(self.topology.edges, edge_props) - for edge_id, edge_prop in edge_props.items(): - new_edge_props[edge_id] = edge_prop - - return StateTransitionGraph[EdgeType]( - topology=self.topology, - node_props=new_node_props, - edge_props=new_edge_props, - ) - def compare( self, other: "StateTransitionGraph", @@ -757,6 +729,7 @@ def _assert_all_defined(items: Iterable, properties: Iterable) -> None: ) +# pyright: reportUnusedFunction=false def _assert_not_overdefined(items: Iterable, properties: Iterable) -> None: existing = set(items) defined = set(properties) diff --git a/tests/unit/test_system_control.py b/tests/unit/test_system_control.py index 6c0d2cd5..dec8b9d5 100644 --- a/tests/unit/test_system_control.py +++ b/tests/unit/test_system_control.py @@ -2,6 +2,7 @@ from copy import deepcopy from typing import List +import attrs import pytest from qrules import InteractionType, ProblemSet, StateTransitionManager @@ -324,13 +325,14 @@ def test_filter_graphs_for_interaction_qns( for value in input_values: tempgraph = make_ls_test_graph(value[1][0], value[1][1], pi0) - tempgraph = tempgraph.evolve( + tempgraph = attrs.evolve( + tempgraph, edge_props={ 0: ( Particle(name=value[0], pid=0, mass=1.0, spin=1.0), 0.0, ) - } + }, ) graphs.append(tempgraph) From 7b00d5cd06f3dcc0aed1a3876253abc4e6a55611 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:32 +0100 Subject: [PATCH 08/34] refactor: make NodeType of StateTransitionGraph generic --- src/qrules/_system_control.py | 25 +++++--- src/qrules/combinatorics.py | 17 ++++-- src/qrules/io/_dict.py | 8 ++- src/qrules/io/_dot.py | 40 +++++++----- src/qrules/topology.py | 18 +++--- src/qrules/transition.py | 30 +++++---- tests/channels/test_jpsi_to_gamma_pi0_pi0.py | 4 +- tests/unit/test_system_control.py | 64 +++++++++++--------- 8 files changed, 121 insertions(+), 85 deletions(-) diff --git a/src/qrules/_system_control.py b/src/qrules/_system_control.py index 7eeb4b93..df872470 100644 --- a/src/qrules/_system_control.py +++ b/src/qrules/_system_control.py @@ -197,10 +197,12 @@ def check( def remove_duplicate_solutions( - solutions: List[StateTransitionGraph[ParticleWithSpin]], + solutions: List[ + StateTransitionGraph[ParticleWithSpin, InteractionProperties] + ], remove_qns_list: Optional[Set[Type[NodeQuantumNumber]]] = None, ignore_qns_list: Optional[Set[Type[NodeQuantumNumber]]] = None, -) -> List[StateTransitionGraph[ParticleWithSpin]]: +) -> List[StateTransitionGraph[ParticleWithSpin, InteractionProperties]]: if remove_qns_list is None: remove_qns_list = set() if ignore_qns_list is None: @@ -209,7 +211,9 @@ def remove_duplicate_solutions( logging.info(f"removing these qns from graphs: {remove_qns_list}") logging.info(f"ignoring qns in graph comparison: {ignore_qns_list}") - filtered_solutions: List[StateTransitionGraph[ParticleWithSpin]] = [] + filtered_solutions: List[ + StateTransitionGraph[ParticleWithSpin, InteractionProperties] + ] = [] remove_counter = 0 for sol_graph in solutions: sol_graph = _remove_qns_from_graph(sol_graph, remove_qns_list) @@ -228,9 +232,9 @@ def remove_duplicate_solutions( def _remove_qns_from_graph( # pylint: disable=too-many-branches - graph: StateTransitionGraph[ParticleWithSpin], + graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], qn_list: Set[Type[NodeQuantumNumber]], -) -> StateTransitionGraph[ParticleWithSpin]: +) -> StateTransitionGraph[ParticleWithSpin, InteractionProperties]: new_node_props = {} for node_id in graph.topology.nodes: node_props = graph.node_props[node_id] @@ -326,7 +330,9 @@ def require_interaction_property( ingoing_particle_name: str, interaction_qn: Type[NodeQuantumNumber], allowed_values: List, -) -> Callable[[StateTransitionGraph[ParticleWithSpin]], bool]: +) -> Callable[ + [StateTransitionGraph[ParticleWithSpin, InteractionProperties]], bool +]: """Filter function. Closure, which can be used as a filter function in :func:`.filter_graphs`. @@ -351,7 +357,9 @@ def require_interaction_property( - *False* otherwise """ - def check(graph: StateTransitionGraph[ParticleWithSpin]) -> bool: + def check( + graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties] + ) -> bool: node_ids = _find_node_ids_with_ingoing_particle_name( graph, ingoing_particle_name ) @@ -369,7 +377,8 @@ def check(graph: StateTransitionGraph[ParticleWithSpin]) -> bool: def _find_node_ids_with_ingoing_particle_name( - graph: StateTransitionGraph[ParticleWithSpin], ingoing_particle_name: str + graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], + ingoing_particle_name: str, ) -> List[int]: topology = graph.topology found_node_ids = [] diff --git a/src/qrules/combinatorics.py b/src/qrules/combinatorics.py index 5674d005..733152bb 100644 --- a/src/qrules/combinatorics.py +++ b/src/qrules/combinatorics.py @@ -406,19 +406,21 @@ def populate_edge_with_spin_projections( def __get_initial_state_edge_ids( - graph: StateTransitionGraph[ParticleWithSpin], + graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], ) -> Iterable[int]: return graph.topology.incoming_edge_ids def __get_final_state_edge_ids( - graph: StateTransitionGraph[ParticleWithSpin], + graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], ) -> Iterable[int]: return graph.topology.outgoing_edge_ids def match_external_edges( - graphs: List[StateTransitionGraph[ParticleWithSpin]], + graphs: List[ + StateTransitionGraph[ParticleWithSpin, InteractionProperties] + ], ) -> None: if not isinstance(graphs, list): raise TypeError("graphs argument is not of type list") @@ -432,7 +434,9 @@ def match_external_edges( def _match_external_edge_ids( # pylint: disable=too-many-locals - graphs: List[StateTransitionGraph[ParticleWithSpin]], + graphs: List[ + StateTransitionGraph[ParticleWithSpin, InteractionProperties] + ], ref_graph_id: int, external_edge_getter_function: Callable[ [StateTransitionGraph], Iterable[int] @@ -497,7 +501,7 @@ def perform_external_edge_identical_particle_combinatorics( def _external_edge_identical_particle_combinatorics( - graph: StateTransitionGraph[ParticleWithSpin], + graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], external_edge_getter_function: Callable[ [StateTransitionGraph], Iterable[int] ], @@ -557,6 +561,7 @@ def _calculate_swappings(id_mapping: Dict[int, int]) -> OrderedDict: def _create_edge_id_particle_mapping( - graph: StateTransitionGraph[ParticleWithSpin], edge_ids: Iterable[int] + graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], + edge_ids: Iterable[int], ) -> Dict[int, str]: return {i: graph.edge_props[i][0].name for i in edge_ids} diff --git a/src/qrules/io/_dict.py b/src/qrules/io/_dict.py index 4746b977..7af6b402 100644 --- a/src/qrules/io/_dict.py +++ b/src/qrules/io/_dict.py @@ -33,7 +33,9 @@ def from_particle(particle: Particle) -> dict: ) -def from_stg(graph: StateTransitionGraph[ParticleWithSpin]) -> dict: +def from_stg( + graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties] +) -> dict: topology = graph.topology edge_props_def = {} for i in topology.edges: @@ -116,7 +118,9 @@ def build_reaction_info(definition: dict) -> ReactionInfo: return ReactionInfo(transitions, formalism=definition["formalism"]) -def build_stg(definition: dict) -> StateTransitionGraph[ParticleWithSpin]: +def build_stg( + definition: dict, +) -> StateTransitionGraph[ParticleWithSpin, InteractionProperties]: topology = build_topology(definition["topology"]) edge_props_def: Dict[int, dict] = definition["edge_props"] edge_props: Dict[int, ParticleWithSpin] = {} diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index c63ecfdd..56ccb68e 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -466,8 +466,10 @@ def __extract_priority(description: str) -> int: def _get_particle_graphs( - graphs: Iterable[StateTransitionGraph[ParticleWithSpin]], -) -> List[StateTransitionGraph[Particle]]: + graphs: Iterable[ + StateTransitionGraph[ParticleWithSpin, InteractionProperties] + ], +) -> List[StateTransitionGraph[Particle, InteractionProperties]]: """Strip `list` of `.StateTransitionGraph` s of the spin projections. Extract a `list` of `.StateTransitionGraph` instances with only @@ -475,7 +477,7 @@ def _get_particle_graphs( .. seealso:: :doc:`/usage/visualize` """ - inventory: List[StateTransitionGraph[Particle]] = [] + inventory: List[StateTransitionGraph[Particle, InteractionProperties]] = [] for transition in graphs: if isinstance(transition, StateTransition): transition = transition.to_graph() @@ -498,8 +500,8 @@ def _get_particle_graphs( def _strip_projections( - graph: StateTransitionGraph[ParticleWithSpin], -) -> StateTransitionGraph[Particle]: + graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], +) -> StateTransitionGraph[Particle, InteractionProperties]: if isinstance(graph, StateTransition): graph = graph.to_graph() new_edge_props = {} @@ -514,7 +516,7 @@ def _strip_projections( new_node_props[node_id] = attrs.evolve( node_props, l_projection=None, s_projection=None ) - return StateTransitionGraph[Particle]( + return StateTransitionGraph[Particle, InteractionProperties]( topology=graph.topology, node_props=new_node_props, edge_props=new_edge_props, @@ -522,11 +524,15 @@ def _strip_projections( def _collapse_graphs( - graphs: Iterable[StateTransitionGraph[ParticleWithSpin]], -) -> List[StateTransitionGraph[ParticleCollection]]: + graphs: Iterable[ + StateTransitionGraph[ParticleWithSpin, InteractionProperties] + ], +) -> List[StateTransitionGraph[ParticleCollection, InteractionProperties]]: def merge_into( - graph: StateTransitionGraph[Particle], - merged_graph: StateTransitionGraph[ParticleCollection], + graph: StateTransitionGraph[Particle, InteractionProperties], + merged_graph: StateTransitionGraph[ + ParticleCollection, InteractionProperties + ], ) -> None: if ( graph.topology.intermediate_edge_ids @@ -542,8 +548,10 @@ def merge_into( other_particles += particle def is_same_shape( - graph: StateTransitionGraph[Particle], - merged_graph: StateTransitionGraph[ParticleCollection], + graph: StateTransitionGraph[Particle, InteractionProperties], + merged_graph: StateTransitionGraph[ + ParticleCollection, InteractionProperties + ], ) -> bool: if graph.topology.edges != merged_graph.topology.edges: return False @@ -559,7 +567,9 @@ def is_same_shape( return True particle_graphs = _get_particle_graphs(graphs) - inventory: List[StateTransitionGraph[ParticleCollection]] = [] + inventory: List[ + StateTransitionGraph[ParticleCollection, InteractionProperties] + ] = [] for graph in particle_graphs: append_to_inventory = True for merged_graph in inventory: @@ -573,7 +583,9 @@ def is_same_shape( for edge_id in graph.topology.edges } inventory.append( - StateTransitionGraph[ParticleCollection]( + StateTransitionGraph[ + ParticleCollection, InteractionProperties + ]( topology=graph.topology, node_props={ i: graph.node_props[i] for i in graph.topology.nodes diff --git a/src/qrules/topology.py b/src/qrules/topology.py index c294801e..0f5b61e4 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -42,8 +42,6 @@ from qrules._implementers import implement_pretty_repr -from .quantum_numbers import InteractionProperties - if sys.version_info >= (3, 8): from typing import Protocol else: @@ -653,21 +651,21 @@ def _attach_node_to_edges( EdgeType = TypeVar("EdgeType") """A `~typing.TypeVar` representing the type of edge properties.""" +NodeType = TypeVar("NodeType") +"""A `~typing.TypeVar` representing the type of node properties.""" def _cast_edges(obj: Mapping[int, EdgeType]) -> Dict[int, EdgeType]: return dict(obj) -def _cast_nodes( - obj: Mapping[int, InteractionProperties] -) -> Dict[int, InteractionProperties]: +def _cast_nodes(obj: Mapping[int, NodeType]) -> Dict[int, NodeType]: return dict(obj) -@implement_pretty_repr() +@implement_pretty_repr @define -class StateTransitionGraph(Generic[EdgeType]): +class StateTransitionGraph(Generic[EdgeType, NodeType]): """Graph class that resembles a frozen `.Topology` with properties. This class should contain the full information of a state transition from a @@ -677,16 +675,14 @@ class StateTransitionGraph(Generic[EdgeType]): """ topology: Topology = field(validator=instance_of(Topology)) - node_props: Dict[int, InteractionProperties] = field(converter=_cast_nodes) + node_props: Dict[int, NodeType] = field(converter=_cast_nodes) edge_props: Dict[int, EdgeType] = field(converter=_cast_edges) def compare( self, other: "StateTransitionGraph", edge_comparator: Optional[Callable[[EdgeType, EdgeType], bool]] = None, - node_comparator: Optional[ - Callable[[InteractionProperties, InteractionProperties], bool] - ] = None, + node_comparator: Optional[Callable[[NodeType, NodeType], bool]] = None, ) -> bool: if self.topology != other.topology: return False diff --git a/src/qrules/transition.py b/src/qrules/transition.py index 3a66625a..5385cdf5 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -137,9 +137,9 @@ def clear(self) -> None: class _SolutionContainer: """Defines a result of a `.ProblemSet`.""" - solutions: List[StateTransitionGraph[ParticleWithSpin]] = field( - factory=list - ) + solutions: List[ + StateTransitionGraph[ParticleWithSpin, InteractionProperties] + ] = field(factory=list) execution_info: ExecutionInfo = field(default=ExecutionInfo()) def __attrs_post_init__(self) -> None: @@ -653,7 +653,9 @@ def __convert_result( """ solutions = [] for solution in qn_result.solutions: - graph = StateTransitionGraph[ParticleWithSpin]( + graph = StateTransitionGraph[ + ParticleWithSpin, InteractionProperties + ]( topology=topology, node_props={ i: create_interaction_properties(x) @@ -691,9 +693,9 @@ def _safe_wrap_list( def _match_final_state_ids( - graph: StateTransitionGraph[ParticleWithSpin], + graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], state_definition: Sequence[StateDefinition], -) -> StateTransitionGraph[ParticleWithSpin]: +) -> StateTransitionGraph[ParticleWithSpin, InteractionProperties]: """Temporary fix to https://github.com/ComPWA/qrules/issues/143.""" particle_names = _strip_spin(state_definition) name_to_id = {name: i for i, name in enumerate(particle_names)} @@ -746,7 +748,7 @@ def __attrs_post_init__(self) -> None: @staticmethod def from_graph( - graph: StateTransitionGraph[ParticleWithSpin], + graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], ) -> "StateTransition": return StateTransition( topology=graph.topology, @@ -758,8 +760,10 @@ def from_graph( ), ) - def to_graph(self) -> StateTransitionGraph[ParticleWithSpin]: - return StateTransitionGraph[ParticleWithSpin]( + def to_graph( + self, + ) -> StateTransitionGraph[ParticleWithSpin, InteractionProperties]: + return StateTransitionGraph[ParticleWithSpin, InteractionProperties]( topology=self.topology, edge_props={ i: (state.particle, state.spin_projection) @@ -831,13 +835,17 @@ def get_intermediate_particles(self) -> ParticleCollection: @staticmethod def from_graphs( - graphs: Iterable[StateTransitionGraph[ParticleWithSpin]], + graphs: Iterable[ + StateTransitionGraph[ParticleWithSpin, InteractionProperties] + ], formalism: str, ) -> "ReactionInfo": transitions = [StateTransition.from_graph(g) for g in graphs] return ReactionInfo(transitions, formalism) - def to_graphs(self) -> List[StateTransitionGraph[ParticleWithSpin]]: + def to_graphs( + self, + ) -> List[StateTransitionGraph[ParticleWithSpin, InteractionProperties]]: return [transition.to_graph() for transition in self.transitions] def group_by_topology(self) -> Dict[Topology, List[StateTransition]]: diff --git a/tests/channels/test_jpsi_to_gamma_pi0_pi0.py b/tests/channels/test_jpsi_to_gamma_pi0_pi0.py index 7ec5dc7e..54a8cbec 100644 --- a/tests/channels/test_jpsi_to_gamma_pi0_pi0.py +++ b/tests/channels/test_jpsi_to_gamma_pi0_pi0.py @@ -2,8 +2,6 @@ import qrules from qrules.combinatorics import _create_edge_id_particle_mapping -from qrules.particle import ParticleWithSpin -from qrules.topology import StateTransitionGraph @pytest.mark.parametrize( @@ -62,7 +60,7 @@ def test_id_to_particle_mappings(particle_database): assert len(reaction.transitions) == 4 iter_transitions = iter(reaction.transitions) first_transition = next(iter_transitions) - graph: StateTransitionGraph[ParticleWithSpin] = first_transition.to_graph() + graph = first_transition.to_graph() ref_mapping_fs = _create_edge_id_particle_mapping( graph, graph.topology.outgoing_edge_ids ) diff --git a/tests/unit/test_system_control.py b/tests/unit/test_system_control.py index dec8b9d5..c030cc66 100644 --- a/tests/unit/test_system_control.py +++ b/tests/unit/test_system_control.py @@ -1,6 +1,6 @@ # pylint: disable=protected-access from copy import deepcopy -from typing import List +from typing import Dict, List import attrs import pytest @@ -208,38 +208,36 @@ def test_create_edge_properties( def make_ls_test_graph( angular_momentum_magnitude, coupled_spin_magnitude, particle ): - graph = StateTransitionGraph[ParticleWithSpin]( - topology=Topology( - nodes={0}, - edges={0: Edge(None, 0)}, - ), - node_props={ - 0: InteractionProperties( - s_magnitude=coupled_spin_magnitude, - l_magnitude=angular_momentum_magnitude, - ) - }, - edge_props={0: (particle, 0)}, + topology = Topology( + nodes={0}, + edges={0: Edge(None, 0)}, ) + node_props = { + 0: InteractionProperties( + s_magnitude=coupled_spin_magnitude, + l_magnitude=angular_momentum_magnitude, + ) + } + edge_props: Dict[int, ParticleWithSpin] = {0: (particle, 0)} + graph = StateTransitionGraph(topology, node_props, edge_props) return graph def make_ls_test_graph_scrambled( angular_momentum_magnitude, coupled_spin_magnitude, particle ): - graph = StateTransitionGraph[ParticleWithSpin]( - topology=Topology( - nodes={0}, - edges={0: Edge(None, 0)}, - ), - node_props={ - 0: InteractionProperties( - l_magnitude=angular_momentum_magnitude, - s_magnitude=coupled_spin_magnitude, - ) - }, - edge_props={0: (particle, 0)}, + topology = Topology( + nodes={0}, + edges={0: Edge(None, 0)}, ) + node_props = { + 0: InteractionProperties( + l_magnitude=angular_momentum_magnitude, + s_magnitude=coupled_spin_magnitude, + ) + } + edge_props: Dict[int, ParticleWithSpin] = {0: (particle, 0)} + graph = StateTransitionGraph(topology, node_props, edge_props) return graph @@ -343,8 +341,8 @@ def test_filter_graphs_for_interaction_qns( def _create_graph( problem_set: ProblemSet, -) -> StateTransitionGraph[ParticleWithSpin]: - return StateTransitionGraph[ParticleWithSpin]( +) -> StateTransitionGraph[ParticleWithSpin, InteractionProperties]: + return StateTransitionGraph( topology=problem_set.topology, node_props=problem_set.initial_facts.node_props, edge_props=problem_set.initial_facts.edge_props, @@ -371,7 +369,9 @@ def test_edge_swap(particle_database, initial_state, final_state): stm.set_allowed_interaction_types([InteractionType.STRONG]) problem_sets = stm.create_problem_sets() - init_graphs: List[StateTransitionGraph[ParticleWithSpin]] = [] + init_graphs: List[ + StateTransitionGraph[ParticleWithSpin, InteractionProperties] + ] = [] for _, problem_set_list in problem_sets.items(): init_graphs.extend([_create_graph(x) for x in problem_set_list]) @@ -417,7 +417,9 @@ def test_match_external_edges(particle_database, initial_state, final_state): stm.set_allowed_interaction_types([InteractionType.STRONG]) problem_sets = stm.create_problem_sets() - init_graphs: List[StateTransitionGraph[ParticleWithSpin]] = [] + init_graphs: List[ + StateTransitionGraph[ParticleWithSpin, InteractionProperties] + ] = [] for _, problem_set_list in problem_sets.items(): init_graphs.extend([_create_graph(x) for x in problem_set_list]) @@ -506,7 +508,9 @@ def test_external_edge_identical_particle_combinatorics( match_external_edges(init_graphs) - comb_graphs: List[StateTransitionGraph[ParticleWithSpin]] = [] + comb_graphs: List[ + StateTransitionGraph[ParticleWithSpin, InteractionProperties] + ] = [] for group in init_graphs: comb_graphs.extend( perform_external_edge_identical_particle_combinatorics(group) From 9d1a3fbc185ce68e35290e5498d153120501b0ba Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:32 +0100 Subject: [PATCH 09/34] refactor: rename STG to MutableTransition --- docs/conf.py | 2 +- docs/index.md | 4 +-- docs/usage/conservation.ipynb | 2 +- docs/usage/reaction.ipynb | 6 ++-- docs/usage/visualize.ipynb | 4 +-- src/qrules/__init__.py | 4 +-- src/qrules/_system_control.py | 44 +++++++++++++------------- src/qrules/combinatorics.py | 44 ++++++++++++-------------- src/qrules/io/__init__.py | 12 +++---- src/qrules/io/_dict.py | 8 ++--- src/qrules/io/_dot.py | 52 +++++++++++++++---------------- src/qrules/particle.py | 2 +- src/qrules/quantum_numbers.py | 2 +- src/qrules/solving.py | 8 ++--- src/qrules/topology.py | 8 ++--- src/qrules/transition.py | 28 ++++++++--------- tests/unit/io/test_io.py | 6 ++-- tests/unit/test_system_control.py | 16 +++++----- tests/unit/test_transition.py | 2 +- 19 files changed, 122 insertions(+), 132 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 59faa962..bd356814 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -223,7 +223,7 @@ def fetch_logo(url: str, output_path: str) -> None: nitpick_ignore = [ ("py:class", "EdgeType"), ("py:class", "NoneType"), - ("py:class", "StateTransitionGraph"), + ("py:class", "MutableTransition"), ("py:class", "ValueType"), ("py:class", "json.encoder.JSONEncoder"), ("py:class", "typing_extensions.Protocol"), diff --git a/docs/index.md b/docs/index.md index 49bee108..e50e88b1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -47,7 +47,7 @@ QRules consists of three major components: 1. **State transition graphs** - A {class}`.StateTransitionGraph` is a + A {class}`.MutableTransition` is a [directed graph](https://en.wikipedia.org/wiki/Directed_graph) that consists of **nodes** and **edges**. In a directed graph, each edge must be connected to at least one node (in correspondence to @@ -93,7 +93,7 @@ The main solver used by {mod}`qrules` is the 1. **Preparation** 1.1. Build all possible topologies. A **topology** is represented by a - {class}`.StateTransitionGraph`, in which the edges and nodes are empty (no + {class}`.MutableTransition`, in which the edges and nodes are empty (no particle information). 1.2. Fill the topology graphs with the user provided information. Typically diff --git a/docs/usage/conservation.ipynb b/docs/usage/conservation.ipynb index 6c6b64f7..ec709090 100644 --- a/docs/usage/conservation.ipynb +++ b/docs/usage/conservation.ipynb @@ -97,7 +97,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "QRules generates {class}`.StateTransitionGraph`s, populates them with quantum numbers (edge properties representing states and nodes properties representing interactions), then checks whether the generated {class}`.StateTransitionGraph`s comply with the rules formulated in the {mod}`.conservation_rules` module.\n", + "QRules generates {class}`.MutableTransition`s, populates them with quantum numbers (edge properties representing states and nodes properties representing interactions), then checks whether the generated {class}`.MutableTransition`s comply with the rules formulated in the {mod}`.conservation_rules` module.\n", "\n", "The {mod}`.conservation_rules` module can also be used separately. In this notebook, we will illustrate this by checking spin and parity conservation." ] diff --git a/docs/usage/reaction.ipynb b/docs/usage/reaction.ipynb index 41830dd5..2b1b026e 100644 --- a/docs/usage/reaction.ipynb +++ b/docs/usage/reaction.ipynb @@ -252,7 +252,7 @@ "1. One the one hand, the {attr}`.EdgeSettings.qn_domains` and {attr}`.NodeSettings.qn_domains` contained in the {class}`~.GraphSettings` define the **domain** over which quantum number sets can be generated.\n", "2. On the other, the {attr}`.EdgeSettings.rule_priorities` and {attr}`.NodeSettings.rule_priorities` in {class}`~.GraphSettings` define which **{mod}`.conservation_rules`** are used to determine which of the sets of generated quantum numbers are valid.\n", "\n", - "Together, these two constraints allow the {class}`.StateTransitionManager` to generate a number of {class}`.StateTransitionGraph`s that comply with the selected {mod}`.conservation_rules`." + "Together, these two constraints allow the {class}`.StateTransitionManager` to generate a number of {class}`.MutableTransition`s that comply with the selected {mod}`.conservation_rules`." ] }, { @@ -313,7 +313,7 @@ "class: dropdown\n", "----\n", "\n", - "The \"number of {attr}`~.ReactionInfo.transitions`\" is the total number of allowed {obj}`.StateTransitionGraph` instances that the {class}`.StateTransitionManager` has found. This also includes all allowed **spin projection combinations**. In this channel, we for example consider a $J/\\psi$ with spin projection $\\pm1$ that decays into a $\\gamma$ with spin projection $\\pm1$, which already gives us four possibilities.\n", + "The \"number of {attr}`~.ReactionInfo.transitions`\" is the total number of allowed {obj}`.MutableTransition` instances that the {class}`.StateTransitionManager` has found. This also includes all allowed **spin projection combinations**. In this channel, we for example consider a $J/\\psi$ with spin projection $\\pm1$ that decays into a $\\gamma$ with spin projection $\\pm1$, which already gives us four possibilities.\n", "\n", "On the other hand, the intermediate state names that was extracted with {meth}`.ReactionInfo.get_intermediate_particles`, is just a {obj}`set` of the state names on the intermediate edges of the list of {attr}`~.ReactionInfo.transitions`, regardless of spin projection.\n", "````" @@ -457,7 +457,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The {class}`.ReactionInfo`, {class}`.StateTransitionGraph`, and {class}`.Topology` can be serialized to and from a {obj}`dict` with {func}`.io.asdict` and {func}`.io.fromdict`:" + "The {class}`.ReactionInfo`, {class}`.MutableTransition`, and {class}`.Topology` can be serialized to and from a {obj}`dict` with {func}`.io.asdict` and {func}`.io.fromdict`:" ] }, { diff --git a/docs/usage/visualize.ipynb b/docs/usage/visualize.ipynb index eb663fb3..061aa1ff 100644 --- a/docs/usage/visualize.ipynb +++ b/docs/usage/visualize.ipynb @@ -81,7 +81,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The {mod}`~qrules.io` module allows you to convert {class}`.StateTransitionGraph`, {class}`.Topology` instances, and {class}`.ProblemSet`s to [DOT language](https://graphviz.org/doc/info/lang.html) with {func}`.asdot`. You can visualize its output with third-party libraries, such as [Graphviz](https://graphviz.org). This is particularly useful after running {meth}`~.StateTransitionManager.find_solutions`, which produces a {class}`.ReactionInfo` object with a {class}`.list` of {class}`.StateTransitionGraph` instances (see {doc}`/usage/reaction`)." + "The {mod}`~qrules.io` module allows you to convert {class}`.MutableTransition`, {class}`.Topology` instances, and {class}`.ProblemSet`s to [DOT language](https://graphviz.org/doc/info/lang.html) with {func}`.asdot`. You can visualize its output with third-party libraries, such as [Graphviz](https://graphviz.org). This is particularly useful after running {meth}`~.StateTransitionManager.find_solutions`, which produces a {class}`.ReactionInfo` object with a {class}`.list` of {class}`.MutableTransition` instances (see {doc}`/usage/reaction`)." ] }, { @@ -300,7 +300,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This {class}`str` of [DOT language](https://graphviz.org/doc/info/lang.html) for the list of {class}`.StateTransitionGraph` instances can then be visualized with a third-party library, for instance, with {class}`graphviz.Source`:" + "This {class}`str` of [DOT language](https://graphviz.org/doc/info/lang.html) for the list of {class}`.MutableTransition` instances can then be visualized with a third-party library, for instance, with {class}`graphviz.Source`:" ] }, { diff --git a/src/qrules/__init__.py b/src/qrules/__init__.py index 71c5c990..31226fcf 100644 --- a/src/qrules/__init__.py +++ b/src/qrules/__init__.py @@ -7,9 +7,9 @@ and allowed interactions. The core of `qrules` computes which transitions (represented by a -`.StateTransitionGraph`) are allowed between a certain initial and final state. +`.MutableTransition`) are allowed between a certain initial and final state. Internally, the system propagates the quantum numbers defined by the -`particle` module through the `.StateTransitionGraph`, while +`particle` module through the `.MutableTransition`, while satisfying the rules define by the :mod:`.conservation_rules` module. See :doc:`/usage/reaction` and :doc:`/usage/particle`. diff --git a/src/qrules/_system_control.py b/src/qrules/_system_control.py index df872470..974be04e 100644 --- a/src/qrules/_system_control.py +++ b/src/qrules/_system_control.py @@ -17,12 +17,12 @@ ) from .settings import InteractionType from .solving import GraphEdgePropertyMap, GraphNodePropertyMap, GraphSettings -from .topology import StateTransitionGraph +from .topology import MutableTransition Strength = float GraphSettingsGroups = Dict[ - Strength, List[Tuple[StateTransitionGraph, GraphSettings]] + Strength, List[Tuple[MutableTransition, GraphSettings]] ] @@ -198,11 +198,11 @@ def check( def remove_duplicate_solutions( solutions: List[ - StateTransitionGraph[ParticleWithSpin, InteractionProperties] + MutableTransition[ParticleWithSpin, InteractionProperties] ], remove_qns_list: Optional[Set[Type[NodeQuantumNumber]]] = None, ignore_qns_list: Optional[Set[Type[NodeQuantumNumber]]] = None, -) -> List[StateTransitionGraph[ParticleWithSpin, InteractionProperties]]: +) -> List[MutableTransition[ParticleWithSpin, InteractionProperties]]: if remove_qns_list is None: remove_qns_list = set() if ignore_qns_list is None: @@ -212,7 +212,7 @@ def remove_duplicate_solutions( logging.info(f"ignoring qns in graph comparison: {ignore_qns_list}") filtered_solutions: List[ - StateTransitionGraph[ParticleWithSpin, InteractionProperties] + MutableTransition[ParticleWithSpin, InteractionProperties] ] = [] remove_counter = 0 for sol_graph in solutions: @@ -232,9 +232,9 @@ def remove_duplicate_solutions( def _remove_qns_from_graph( # pylint: disable=too-many-branches - graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], + graph: MutableTransition[ParticleWithSpin, InteractionProperties], qn_list: Set[Type[NodeQuantumNumber]], -) -> StateTransitionGraph[ParticleWithSpin, InteractionProperties]: +) -> MutableTransition[ParticleWithSpin, InteractionProperties]: new_node_props = {} for node_id in graph.topology.nodes: node_props = graph.node_props[node_id] @@ -246,19 +246,17 @@ def _remove_qns_from_graph( # pylint: disable=too-many-branches def _check_equal_ignoring_qns( - ref_graph: StateTransitionGraph, - solutions: List[StateTransitionGraph], + ref_graph: MutableTransition, + solutions: List[MutableTransition], ignored_qn_list: Set[Type[NodeQuantumNumber]], -) -> Optional[StateTransitionGraph]: +) -> Optional[MutableTransition]: """Define equal operator for graphs, ignoring certain quantum numbers.""" - if not isinstance(ref_graph, StateTransitionGraph): - raise TypeError( - "Reference graph has to be of type StateTransitionGraph" - ) + if not isinstance(ref_graph, MutableTransition): + raise TypeError("Reference graph has to be of type MutableTransition") found_graph = None node_comparator = NodePropertyComparator(ignored_qn_list) for graph in solutions: - if isinstance(graph, StateTransitionGraph): + if isinstance(graph, MutableTransition): if graph.compare( ref_graph, edge_comparator=lambda e1, e2: e1 == e2, @@ -293,13 +291,13 @@ def __call__( def filter_graphs( - graphs: List[StateTransitionGraph], - filters: Iterable[Callable[[StateTransitionGraph], bool]], -) -> List[StateTransitionGraph]: - r"""Implement filtering of a list of `.StateTransitionGraph` 's. + graphs: List[MutableTransition], + filters: Iterable[Callable[[MutableTransition], bool]], +) -> List[MutableTransition]: + r"""Implement filtering of a list of `.MutableTransition` 's. This function can be used to select a subset of - `.StateTransitionGraph` 's from a list. Only the graphs passing + `.MutableTransition` 's from a list. Only the graphs passing all supplied filters will be returned. Note: @@ -331,7 +329,7 @@ def require_interaction_property( interaction_qn: Type[NodeQuantumNumber], allowed_values: List, ) -> Callable[ - [StateTransitionGraph[ParticleWithSpin, InteractionProperties]], bool + [MutableTransition[ParticleWithSpin, InteractionProperties]], bool ]: """Filter function. @@ -358,7 +356,7 @@ def require_interaction_property( """ def check( - graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties] + graph: MutableTransition[ParticleWithSpin, InteractionProperties] ) -> bool: node_ids = _find_node_ids_with_ingoing_particle_name( graph, ingoing_particle_name @@ -377,7 +375,7 @@ def check( def _find_node_ids_with_ingoing_particle_name( - graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], + graph: MutableTransition[ParticleWithSpin, InteractionProperties], ingoing_particle_name: str, ) -> List[int]: topology = graph.topology diff --git a/src/qrules/combinatorics.py b/src/qrules/combinatorics.py index 733152bb..567f08db 100644 --- a/src/qrules/combinatorics.py +++ b/src/qrules/combinatorics.py @@ -1,8 +1,8 @@ -"""Perform permutations on the edges of a `.StateTransitionGraph`. +"""Perform permutations on the edges of a `.MutableTransition`. -In a `.StateTransitionGraph`, the edges represent quantum states, while the -nodes represent interactions. This module provides tools to permutate, modify -or extract these edge and node properties. +In a `.MutableTransition`, the edges represent quantum states, while the nodes +represent interactions. This module provides tools to permutate, modify or +extract these edge and node properties. """ from collections import OrderedDict @@ -31,7 +31,7 @@ from .particle import ParticleWithSpin from .quantum_numbers import InteractionProperties, arange -from .topology import StateTransitionGraph, Topology, get_originating_node_list +from .topology import MutableTransition, Topology, get_originating_node_list StateWithSpins = Tuple[str, Sequence[float]] StateDefinition = Union[str, StateWithSpins] @@ -168,7 +168,7 @@ def _get_kinematic_representation( r"""Group final or initial states by node, sorted by length of the group. The resulting sorted groups can be used to check whether two - `.StateTransitionGraph` instances are kinematically identical. For + `.MutableTransition` instances are kinematically identical. For instance, the following two graphs: .. code-block:: @@ -406,21 +406,19 @@ def populate_edge_with_spin_projections( def __get_initial_state_edge_ids( - graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], + graph: MutableTransition[ParticleWithSpin, InteractionProperties], ) -> Iterable[int]: return graph.topology.incoming_edge_ids def __get_final_state_edge_ids( - graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], + graph: MutableTransition[ParticleWithSpin, InteractionProperties], ) -> Iterable[int]: return graph.topology.outgoing_edge_ids def match_external_edges( - graphs: List[ - StateTransitionGraph[ParticleWithSpin, InteractionProperties] - ], + graphs: List[MutableTransition[ParticleWithSpin, InteractionProperties]], ) -> None: if not isinstance(graphs, list): raise TypeError("graphs argument is not of type list") @@ -434,12 +432,10 @@ def match_external_edges( def _match_external_edge_ids( # pylint: disable=too-many-locals - graphs: List[ - StateTransitionGraph[ParticleWithSpin, InteractionProperties] - ], + graphs: List[MutableTransition[ParticleWithSpin, InteractionProperties]], ref_graph_id: int, external_edge_getter_function: Callable[ - [StateTransitionGraph], Iterable[int] + [MutableTransition], Iterable[int] ], ) -> None: ref_graph = graphs[ref_graph_id] @@ -475,17 +471,17 @@ def _match_external_edge_ids( # pylint: disable=too-many-locals def perform_external_edge_identical_particle_combinatorics( - graph: StateTransitionGraph, -) -> List[StateTransitionGraph]: - """Create combinatorics clones of the `.StateTransitionGraph`. + graph: MutableTransition, +) -> List[MutableTransition]: + """Create combinatorics clones of the `.MutableTransition`. In case of identical particles in the initial or final state. Only identical particles, which do not enter or exit the same node allow for combinatorics! """ - if not isinstance(graph, StateTransitionGraph): + if not isinstance(graph, MutableTransition): raise TypeError( - f"graph argument is not of type {StateTransitionGraph.__class__}" + f"graph argument is not of type {MutableTransition.__class__}" ) temp_new_graphs = _external_edge_identical_particle_combinatorics( graph, __get_final_state_edge_ids @@ -501,11 +497,11 @@ def perform_external_edge_identical_particle_combinatorics( def _external_edge_identical_particle_combinatorics( - graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], + graph: MutableTransition[ParticleWithSpin, InteractionProperties], external_edge_getter_function: Callable[ - [StateTransitionGraph], Iterable[int] + [MutableTransition], Iterable[int] ], -) -> List[StateTransitionGraph]: +) -> List[MutableTransition]: # pylint: disable=too-many-locals new_graphs = [graph] edge_particle_mapping = _create_edge_id_particle_mapping( @@ -561,7 +557,7 @@ def _calculate_swappings(id_mapping: Dict[int, int]) -> OrderedDict: def _create_edge_id_particle_mapping( - graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], + graph: MutableTransition[ParticleWithSpin, InteractionProperties], edge_ids: Iterable[int], ) -> Dict[int, str]: return {i: graph.edge_props[i][0].name for i in edge_ids} diff --git a/src/qrules/io/__init__.py b/src/qrules/io/__init__.py index 60b672f0..e4059b83 100644 --- a/src/qrules/io/__init__.py +++ b/src/qrules/io/__init__.py @@ -15,7 +15,7 @@ import yaml from qrules.particle import Particle, ParticleCollection -from qrules.topology import StateTransitionGraph, Topology +from qrules.topology import MutableTransition, Topology from qrules.transition import ProblemSet, ReactionInfo, State, StateTransition from . import _dict, _dot @@ -34,7 +34,7 @@ def asdict(instance: object) -> dict: filter=lambda a, _: a.init, value_serializer=_dict._value_serializer, ) - if isinstance(instance, StateTransitionGraph): + if isinstance(instance, MutableTransition): return _dict.from_stg(instance) if isinstance(instance, Topology): return _dict.from_topology(instance) @@ -87,13 +87,13 @@ def asdot( """Convert a `object` to a DOT language `str`. Only works for objects that can be represented as a graph, particularly a - `.StateTransitionGraph` or a `list` of `.StateTransitionGraph` instances. + `.MutableTransition` or a `list` of `.MutableTransition` instances. Args: instance: the input `object` that is to be rendered as DOT (graphviz) language. - strip_spin: Normally, each `.StateTransitionGraph` has a `.Particle` + strip_spin: Normally, each `.MutableTransition` has a `.Particle` with a spin projection on its edges. This option hides the projections, leaving only `.Particle` names on edges. @@ -102,7 +102,7 @@ def asdot( render_node: Whether or not to render node ID (in the case of a `.Topology`) and/or node properties (in the case of a - `.StateTransitionGraph`). Meaning of the labels: + `.MutableTransition`). Meaning of the labels: - :math:`P`: parity prefactor - :math:`s`: tuple of **coupled spin** magnitude and its @@ -131,7 +131,7 @@ def asdot( node_style = {} if isinstance(instance, StateTransition): instance = instance.to_graph() - if isinstance(instance, (ProblemSet, StateTransitionGraph, Topology)): + if isinstance(instance, (ProblemSet, MutableTransition, Topology)): dot = _dot.graph_to_dot( instance, render_node=render_node, diff --git a/src/qrules/io/_dict.py b/src/qrules/io/_dict.py index 7af6b402..2dbb3ec9 100644 --- a/src/qrules/io/_dict.py +++ b/src/qrules/io/_dict.py @@ -16,7 +16,7 @@ Spin, ) from qrules.quantum_numbers import InteractionProperties -from qrules.topology import Edge, StateTransitionGraph, Topology +from qrules.topology import Edge, MutableTransition, Topology from qrules.transition import ReactionInfo, State, StateTransition @@ -34,7 +34,7 @@ def from_particle(particle: Particle) -> dict: def from_stg( - graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties] + graph: MutableTransition[ParticleWithSpin, InteractionProperties] ) -> dict: topology = graph.topology edge_props_def = {} @@ -120,7 +120,7 @@ def build_reaction_info(definition: dict) -> ReactionInfo: def build_stg( definition: dict, -) -> StateTransitionGraph[ParticleWithSpin, InteractionProperties]: +) -> MutableTransition[ParticleWithSpin, InteractionProperties]: topology = build_topology(definition["topology"]) edge_props_def: Dict[int, dict] = definition["edge_props"] edge_props: Dict[int, ParticleWithSpin] = {} @@ -135,7 +135,7 @@ def build_stg( int(i): InteractionProperties(**node_def) for i, node_def in node_props_def.items() } - return StateTransitionGraph( + return MutableTransition( topology=topology, edge_props=edge_props, node_props=node_props, diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index 56ccb68e..8be48c69 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -24,7 +24,7 @@ from qrules.particle import Particle, ParticleCollection, ParticleWithSpin from qrules.quantum_numbers import InteractionProperties, _to_fraction from qrules.solving import EdgeSettings, GraphSettings, NodeSettings -from qrules.topology import StateTransitionGraph, Topology +from qrules.topology import MutableTransition, Topology from qrules.transition import ProblemSet, StateTransition _DOT_HEAD = """digraph { @@ -138,7 +138,7 @@ def __create_graphviz_assignments(graphviz_attrs: Dict[str, Any]) -> List[str]: @embed_dot def graph_list_to_dot( - graphs: Iterable[StateTransitionGraph], + graphs: Iterable[MutableTransition], *, render_node: bool, render_final_state_id: bool, @@ -188,7 +188,7 @@ def graph_list_to_dot( @embed_dot def graph_to_dot( - graph: StateTransitionGraph, + graph: MutableTransition, *, render_node: bool, render_final_state_id: bool, @@ -212,7 +212,7 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals graph: Union[ ProblemSet, StateTransition, - StateTransitionGraph, + MutableTransition, Topology, Tuple[Topology, InitialFacts], Tuple[Topology, GraphSettings], @@ -234,13 +234,13 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals InitialFacts, ProblemSet, StateTransition, - StateTransitionGraph, + MutableTransition, Topology, ] = graph[1] elif isinstance(graph, ProblemSet): rendered_graph = graph topology = graph.topology - elif isinstance(graph, (StateTransition, StateTransitionGraph)): + elif isinstance(graph, (StateTransition, MutableTransition)): rendered_graph = graph topology = graph.topology elif isinstance(graph, Topology): @@ -289,7 +289,7 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals label=node_label, graphviz_attrs=node_style, ) - if isinstance(graph, (StateTransition, StateTransitionGraph)): + if isinstance(graph, (StateTransition, MutableTransition)): if isinstance(graph, StateTransition): interactions: Mapping[ int, InteractionProperties @@ -337,7 +337,7 @@ def __get_edge_label( InitialFacts, ProblemSet, StateTransition, - StateTransitionGraph, + MutableTransition, Topology, ], edge_id: int, @@ -360,7 +360,7 @@ def __get_edge_label( return ___render_edge_with_id(edge_id, edge_property, render_edge_id) if isinstance(graph, StateTransition): graph = graph.to_graph() - if isinstance(graph, StateTransitionGraph): + if isinstance(graph, MutableTransition): edge_prop = graph.edge_props[edge_id] return ___render_edge_with_id(edge_id, edge_prop, render_edge_id) if isinstance(graph, Topology): @@ -467,17 +467,17 @@ def __extract_priority(description: str) -> int: def _get_particle_graphs( graphs: Iterable[ - StateTransitionGraph[ParticleWithSpin, InteractionProperties] + MutableTransition[ParticleWithSpin, InteractionProperties] ], -) -> List[StateTransitionGraph[Particle, InteractionProperties]]: - """Strip `list` of `.StateTransitionGraph` s of the spin projections. +) -> List[MutableTransition[Particle, InteractionProperties]]: + """Strip `list` of `.MutableTransition` s of the spin projections. - Extract a `list` of `.StateTransitionGraph` instances with only + Extract a `list` of `.MutableTransition` instances with only `.Particle` instances on the edges. .. seealso:: :doc:`/usage/visualize` """ - inventory: List[StateTransitionGraph[Particle, InteractionProperties]] = [] + inventory: List[MutableTransition[Particle, InteractionProperties]] = [] for transition in graphs: if isinstance(transition, StateTransition): transition = transition.to_graph() @@ -500,8 +500,8 @@ def _get_particle_graphs( def _strip_projections( - graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], -) -> StateTransitionGraph[Particle, InteractionProperties]: + graph: MutableTransition[ParticleWithSpin, InteractionProperties], +) -> MutableTransition[Particle, InteractionProperties]: if isinstance(graph, StateTransition): graph = graph.to_graph() new_edge_props = {} @@ -516,7 +516,7 @@ def _strip_projections( new_node_props[node_id] = attrs.evolve( node_props, l_projection=None, s_projection=None ) - return StateTransitionGraph[Particle, InteractionProperties]( + return MutableTransition[Particle, InteractionProperties]( topology=graph.topology, node_props=new_node_props, edge_props=new_edge_props, @@ -525,12 +525,12 @@ def _strip_projections( def _collapse_graphs( graphs: Iterable[ - StateTransitionGraph[ParticleWithSpin, InteractionProperties] + MutableTransition[ParticleWithSpin, InteractionProperties] ], -) -> List[StateTransitionGraph[ParticleCollection, InteractionProperties]]: +) -> List[MutableTransition[ParticleCollection, InteractionProperties]]: def merge_into( - graph: StateTransitionGraph[Particle, InteractionProperties], - merged_graph: StateTransitionGraph[ + graph: MutableTransition[Particle, InteractionProperties], + merged_graph: MutableTransition[ ParticleCollection, InteractionProperties ], ) -> None: @@ -548,8 +548,8 @@ def merge_into( other_particles += particle def is_same_shape( - graph: StateTransitionGraph[Particle, InteractionProperties], - merged_graph: StateTransitionGraph[ + graph: MutableTransition[Particle, InteractionProperties], + merged_graph: MutableTransition[ ParticleCollection, InteractionProperties ], ) -> bool: @@ -568,7 +568,7 @@ def is_same_shape( particle_graphs = _get_particle_graphs(graphs) inventory: List[ - StateTransitionGraph[ParticleCollection, InteractionProperties] + MutableTransition[ParticleCollection, InteractionProperties] ] = [] for graph in particle_graphs: append_to_inventory = True @@ -583,9 +583,7 @@ def is_same_shape( for edge_id in graph.topology.edges } inventory.append( - StateTransitionGraph[ - ParticleCollection, InteractionProperties - ]( + MutableTransition[ParticleCollection, InteractionProperties]( topology=graph.topology, node_props={ i: graph.node_props[i] for i in graph.topology.nodes diff --git a/src/qrules/particle.py b/src/qrules/particle.py index 6f2c1ae8..f2140a26 100644 --- a/src/qrules/particle.py +++ b/src/qrules/particle.py @@ -8,7 +8,7 @@ :doc:`/usage/particle`). The `.transition` module uses the properties of `Particle` instances when it -computes which `.StateTransitionGraph` s are allowed between an initial state +computes which `.MutableTransition` s are allowed between an initial state and final state. """ diff --git a/src/qrules/quantum_numbers.py b/src/qrules/quantum_numbers.py index e76c48f5..4385fd2c 100644 --- a/src/qrules/quantum_numbers.py +++ b/src/qrules/quantum_numbers.py @@ -169,7 +169,7 @@ def _to_optional_int(optional_int: Optional[int]) -> Optional[int]: class InteractionProperties: """Immutable data structure containing interaction properties. - Interactions are represented by a node on a `.StateTransitionGraph`. This + Interactions are represented by a node on a `.MutableTransition`. This class represents the properties that are carried collectively by the edges that this node connects. diff --git a/src/qrules/solving.py b/src/qrules/solving.py index 9502bb51..7fcf851a 100644 --- a/src/qrules/solving.py +++ b/src/qrules/solving.py @@ -2,10 +2,10 @@ """Functions to solve a particle reaction problem. This module is responsible for solving a particle reaction problem stated by a -`.StateTransitionGraph` and corresponding `.GraphSettings`. The `.Solver` -classes (e.g. :class:`.CSPSolver`) generate new quantum numbers (for example -belonging to an intermediate state) and validate the decay processes with the -rules formulated by the :mod:`.conservation_rules` module. +`.MutableTransition` and corresponding `.GraphSettings`. The `.Solver` classes +(e.g. :class:`.CSPSolver`) generate new quantum numbers (for example belonging +to an intermediate state) and validate the decay processes with the rules +formulated by the :mod:`.conservation_rules` module. """ diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 0f5b61e4..981ad97f 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -5,7 +5,7 @@ - number of initial state particles - number of final state particles -The main interface is the `.StateTransitionGraph`. +The main interface is the `.MutableTransition`. """ import copy @@ -170,7 +170,7 @@ def _to_frozenset(iterable: Iterable[int]) -> FrozenSet[int]: class Topology: """Directed Feynman-like graph without edge or node properties. - Forms the underlying topology of `StateTransitionGraph`. The graphs are + Forms the underlying topology of `MutableTransition`. The graphs are directed, meaning the edges are ingoing and outgoing to specific nodes (since feynman graphs also have a time axis). Note that a `Topology` is not strictly speaking a graph from graph theory, because it allows open edges, @@ -665,7 +665,7 @@ def _cast_nodes(obj: Mapping[int, NodeType]) -> Dict[int, NodeType]: @implement_pretty_repr @define -class StateTransitionGraph(Generic[EdgeType, NodeType]): +class MutableTransition(Generic[EdgeType, NodeType]): """Graph class that resembles a frozen `.Topology` with properties. This class should contain the full information of a state transition from a @@ -680,7 +680,7 @@ class StateTransitionGraph(Generic[EdgeType, NodeType]): def compare( self, - other: "StateTransitionGraph", + other: "MutableTransition", edge_comparator: Optional[Callable[[EdgeType, EdgeType], bool]] = None, node_comparator: Optional[Callable[[NodeType, NodeType], bool]] = None, ) -> bool: diff --git a/src/qrules/transition.py b/src/qrules/transition.py index 5385cdf5..06a434b8 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -72,7 +72,7 @@ ) from .topology import ( FrozenDict, - StateTransitionGraph, + MutableTransition, Topology, _assert_all_defined, create_isobar_topologies, @@ -138,7 +138,7 @@ class _SolutionContainer: """Defines a result of a `.ProblemSet`.""" solutions: List[ - StateTransitionGraph[ParticleWithSpin, InteractionProperties] + MutableTransition[ParticleWithSpin, InteractionProperties] ] = field(factory=list) execution_info: ExecutionInfo = field(default=ExecutionInfo()) @@ -550,7 +550,7 @@ def find_solutions( # pylint: disable=too-many-branches qn_problems = [x.to_qn_problem_set() for x in problems] # Because of pickling problems of Generic classes (in this case - # StateTransitionGraph), multithreaded code has to work with + # MutableTransition), multithreaded code has to work with # QNProblemSet's and QNResult's. So the appropriate conversions # have to be done before and after temp_qn_results: List[Tuple[QNProblemSet, QNResult]] = [] @@ -653,9 +653,7 @@ def __convert_result( """ solutions = [] for solution in qn_result.solutions: - graph = StateTransitionGraph[ - ParticleWithSpin, InteractionProperties - ]( + graph = MutableTransition[ParticleWithSpin, InteractionProperties]( topology=topology, node_props={ i: create_interaction_properties(x) @@ -693,9 +691,9 @@ def _safe_wrap_list( def _match_final_state_ids( - graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], + graph: MutableTransition[ParticleWithSpin, InteractionProperties], state_definition: Sequence[StateDefinition], -) -> StateTransitionGraph[ParticleWithSpin, InteractionProperties]: +) -> MutableTransition[ParticleWithSpin, InteractionProperties]: """Temporary fix to https://github.com/ComPWA/qrules/issues/143.""" particle_names = _strip_spin(state_definition) name_to_id = {name: i for i, name in enumerate(particle_names)} @@ -704,7 +702,7 @@ def _match_final_state_ids( for i in graph.topology.outgoing_edge_ids } new_topology = graph.topology.relabel_edges(id_remapping) - return StateTransitionGraph( + return MutableTransition( new_topology, edge_props={ i: graph.edge_props[id_remapping.get(i, i)] @@ -734,7 +732,7 @@ class State: @implement_pretty_repr @frozen(order=True) class StateTransition: - """Frozen instance of a `.StateTransitionGraph` of a particle with spin.""" + """Frozen instance of a `.MutableTransition` of a particle with spin.""" topology: Topology = field(validator=instance_of(Topology)) states: FrozenDict[int, State] = field(converter=FrozenDict) @@ -748,7 +746,7 @@ def __attrs_post_init__(self) -> None: @staticmethod def from_graph( - graph: StateTransitionGraph[ParticleWithSpin, InteractionProperties], + graph: MutableTransition[ParticleWithSpin, InteractionProperties], ) -> "StateTransition": return StateTransition( topology=graph.topology, @@ -762,8 +760,8 @@ def from_graph( def to_graph( self, - ) -> StateTransitionGraph[ParticleWithSpin, InteractionProperties]: - return StateTransitionGraph[ParticleWithSpin, InteractionProperties]( + ) -> MutableTransition[ParticleWithSpin, InteractionProperties]: + return MutableTransition[ParticleWithSpin, InteractionProperties]( topology=self.topology, edge_props={ i: (state.particle, state.spin_projection) @@ -836,7 +834,7 @@ def get_intermediate_particles(self) -> ParticleCollection: @staticmethod def from_graphs( graphs: Iterable[ - StateTransitionGraph[ParticleWithSpin, InteractionProperties] + MutableTransition[ParticleWithSpin, InteractionProperties] ], formalism: str, ) -> "ReactionInfo": @@ -845,7 +843,7 @@ def from_graphs( def to_graphs( self, - ) -> List[StateTransitionGraph[ParticleWithSpin, InteractionProperties]]: + ) -> List[MutableTransition[ParticleWithSpin, InteractionProperties]]: return [transition.to_graph() for transition in self.transitions] def group_by_topology(self) -> Dict[Topology, List[StateTransition]]: diff --git a/tests/unit/io/test_io.py b/tests/unit/io/test_io.py index be7b8638..db3c322d 100644 --- a/tests/unit/io/test_io.py +++ b/tests/unit/io/test_io.py @@ -5,7 +5,7 @@ from qrules import io from qrules.particle import Particle, ParticleCollection from qrules.topology import ( - StateTransitionGraph, + MutableTransition, Topology, create_isobar_topologies, create_n_body_topology, @@ -44,10 +44,10 @@ def test_asdict_fromdict(particle_selection: ParticleCollection): def test_asdict_fromdict_reaction(reaction: ReactionInfo): - # StateTransitionGraph + # MutableTransition for graph in reaction.to_graphs(): fromdict = through_dict(graph) - assert isinstance(fromdict, StateTransitionGraph) + assert isinstance(fromdict, MutableTransition) assert graph == fromdict # ReactionInfo fromdict = through_dict(reaction) diff --git a/tests/unit/test_system_control.py b/tests/unit/test_system_control.py index c030cc66..e5228521 100644 --- a/tests/unit/test_system_control.py +++ b/tests/unit/test_system_control.py @@ -24,7 +24,7 @@ InteractionProperties, NodeQuantumNumbers, ) -from qrules.topology import Edge, StateTransitionGraph, Topology +from qrules.topology import Edge, MutableTransition, Topology @pytest.mark.parametrize( @@ -219,7 +219,7 @@ def make_ls_test_graph( ) } edge_props: Dict[int, ParticleWithSpin] = {0: (particle, 0)} - graph = StateTransitionGraph(topology, node_props, edge_props) + graph = MutableTransition(topology, node_props, edge_props) return graph @@ -237,7 +237,7 @@ def make_ls_test_graph_scrambled( ) } edge_props: Dict[int, ParticleWithSpin] = {0: (particle, 0)} - graph = StateTransitionGraph(topology, node_props, edge_props) + graph = MutableTransition(topology, node_props, edge_props) return graph @@ -341,8 +341,8 @@ def test_filter_graphs_for_interaction_qns( def _create_graph( problem_set: ProblemSet, -) -> StateTransitionGraph[ParticleWithSpin, InteractionProperties]: - return StateTransitionGraph( +) -> MutableTransition[ParticleWithSpin, InteractionProperties]: + return MutableTransition( topology=problem_set.topology, node_props=problem_set.initial_facts.node_props, edge_props=problem_set.initial_facts.edge_props, @@ -370,7 +370,7 @@ def test_edge_swap(particle_database, initial_state, final_state): problem_sets = stm.create_problem_sets() init_graphs: List[ - StateTransitionGraph[ParticleWithSpin, InteractionProperties] + MutableTransition[ParticleWithSpin, InteractionProperties] ] = [] for _, problem_set_list in problem_sets.items(): init_graphs.extend([_create_graph(x) for x in problem_set_list]) @@ -418,7 +418,7 @@ def test_match_external_edges(particle_database, initial_state, final_state): problem_sets = stm.create_problem_sets() init_graphs: List[ - StateTransitionGraph[ParticleWithSpin, InteractionProperties] + MutableTransition[ParticleWithSpin, InteractionProperties] ] = [] for _, problem_set_list in problem_sets.items(): init_graphs.extend([_create_graph(x) for x in problem_set_list]) @@ -509,7 +509,7 @@ def test_external_edge_identical_particle_combinatorics( match_external_edges(init_graphs) comb_graphs: List[ - StateTransitionGraph[ParticleWithSpin, InteractionProperties] + MutableTransition[ParticleWithSpin, InteractionProperties] ] = [] for group in init_graphs: comb_graphs.extend( diff --git a/tests/unit/test_transition.py b/tests/unit/test_transition.py index cdc929af..55275b92 100644 --- a/tests/unit/test_transition.py +++ b/tests/unit/test_transition.py @@ -17,7 +17,7 @@ from qrules.topology import ( # noqa: F401 Edge, FrozenDict, - StateTransitionGraph, + MutableTransition, Topology, ) from qrules.transition import State # noqa: F401 From fee55c32a7dba41476ac6ede47380891f51ca17e Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:33 +0100 Subject: [PATCH 10/34] feat: define (frozen) Transition --- src/qrules/topology.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 981ad97f..bd812ac9 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -655,6 +655,20 @@ def _attach_node_to_edges( """A `~typing.TypeVar` representing the type of node properties.""" +@implement_pretty_repr() +@frozen(order=True) +class FrozenTransition(Generic[EdgeType, NodeType]): + """Defines a frozen mapping of edge and node properties on a `Topology`.""" + + topology: Topology = field(validator=instance_of(Topology)) + edge_props: FrozenDict[int, NodeType] = field(converter=FrozenDict) + node_props: FrozenDict[int, EdgeType] = field(converter=FrozenDict) + + def __attrs_post_init__(self) -> None: + _assert_all_defined(self.topology.nodes, self.node_props) + _assert_all_defined(self.topology.edges, self.edge_props) + + def _cast_edges(obj: Mapping[int, EdgeType]) -> Dict[int, EdgeType]: return dict(obj) @@ -714,7 +728,6 @@ def swap_edges(self, edge_id1: int, edge_id2: int) -> None: self.edge_props[edge_id1] = value2 -# pyright: reportUnusedFunction=false def _assert_all_defined(items: Iterable, properties: Iterable) -> None: existing = set(items) defined = set(properties) From f050ace2c3ba322ed01e41219003cf9ea584b633 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:33 +0100 Subject: [PATCH 11/34] refactor: rename edge/node_props to states/interactions --- src/qrules/__init__.py | 2 +- src/qrules/_system_control.py | 66 +++++++++++++++---------------- src/qrules/argument_handling.py | 12 +++--- src/qrules/combinatorics.py | 8 ++-- src/qrules/io/__init__.py | 2 - src/qrules/io/_dict.py | 40 ++++--------------- src/qrules/io/_dot.py | 56 +++++++++++++------------- src/qrules/solving.py | 58 +++++++++++++-------------- src/qrules/topology.py | 50 ++++++++++++----------- src/qrules/transition.py | 46 ++++++++++----------- tests/unit/io/test_dot.py | 18 ++++----- tests/unit/io/test_io.py | 7 ++-- tests/unit/test_combinatorics.py | 4 +- tests/unit/test_system_control.py | 26 ++++++------ 14 files changed, 185 insertions(+), 210 deletions(-) diff --git a/src/qrules/__init__.py b/src/qrules/__init__.py index 31226fcf..95ade360 100644 --- a/src/qrules/__init__.py +++ b/src/qrules/__init__.py @@ -239,7 +239,7 @@ def check_edge_qn_conservation() -> Set[FrozenSet[str]]: for facts_combination in initial_facts: new_facts = attrs.evolve( facts_combination, - node_props={node_id: ls_combi}, + interactions={node_id: ls_combi}, ) initial_facts_list.append(new_facts) diff --git a/src/qrules/_system_control.py b/src/qrules/_system_control.py index 974be04e..26935d4d 100644 --- a/src/qrules/_system_control.py +++ b/src/qrules/_system_control.py @@ -59,7 +59,7 @@ def create_edge_properties( def create_node_properties( - node_props: InteractionProperties, + interactions: InteractionProperties, ) -> GraphNodePropertyMap: node_qn_mapping: Dict[str, Type[NodeQuantumNumber]] = { qn_name: qn_type @@ -67,7 +67,7 @@ def create_node_properties( if not qn_name.startswith("__") } # Note using attrs.fields does not work here because init=False property_map: GraphNodePropertyMap = {} - for qn_name, value in attrs.asdict(node_props).items(): + for qn_name, value in attrs.asdict(interactions).items(): if value is None: continue if qn_name in node_qn_mapping: @@ -82,7 +82,7 @@ def create_node_properties( def create_particle( - edge_props: GraphEdgePropertyMap, particle_db: ParticleCollection + states: GraphEdgePropertyMap, particle_db: ParticleCollection ) -> ParticleWithSpin: """Create a Particle with spin projection from a qn dictionary. @@ -90,7 +90,7 @@ def create_particle( particle inside the `.ParticleCollection`. Args: - edge_props: The quantum number dictionary. + states: The quantum number dictionary. particle_db: A `.ParticleCollection` which is used to retrieve a reference `.particle` to lower the memory footprint. @@ -100,13 +100,13 @@ def create_particle( ValueError: If the edge properties do not contain spin projection info. """ - particle = particle_db.find(int(edge_props[EdgeQuantumNumbers.pid])) - if EdgeQuantumNumbers.spin_projection not in edge_props: + particle = particle_db.find(int(states[EdgeQuantumNumbers.pid])) + if EdgeQuantumNumbers.spin_projection not in states: raise ValueError( f"{GraphEdgePropertyMap.__name__} does not contain a spin" " projection" ) - spin_projection = edge_props[EdgeQuantumNumbers.spin_projection] + spin_projection = states[EdgeQuantumNumbers.spin_projection] return (particle, spin_projection) @@ -150,9 +150,9 @@ class InteractionDeterminator(ABC): @abstractmethod def check( self, - in_edge_props: List[ParticleWithSpin], - out_edge_props: List[ParticleWithSpin], - node_props: InteractionProperties, + in_states: List[ParticleWithSpin], + out_states: List[ParticleWithSpin], + interactions: InteractionProperties, ) -> List[InteractionType]: pass @@ -162,12 +162,12 @@ class GammaCheck(InteractionDeterminator): def check( self, - in_edge_props: List[ParticleWithSpin], - out_edge_props: List[ParticleWithSpin], - node_props: InteractionProperties, + in_states: List[ParticleWithSpin], + out_states: List[ParticleWithSpin], + interactions: InteractionProperties, ) -> List[InteractionType]: int_types = list(InteractionType) - for particle, _ in in_edge_props + out_edge_props: + for particle, _ in in_states + out_states: if "gamma" in particle.name: int_types = [InteractionType.EM] break @@ -179,12 +179,12 @@ class LeptonCheck(InteractionDeterminator): def check( self, - in_edge_props: List[ParticleWithSpin], - out_edge_props: List[ParticleWithSpin], - node_props: InteractionProperties, + in_states: List[ParticleWithSpin], + out_states: List[ParticleWithSpin], + interactions: InteractionProperties, ) -> List[InteractionType]: node_interaction_types = list(InteractionType) - for particle, _ in in_edge_props + out_edge_props: + for particle, _ in in_states + out_states: if particle.is_lepton(): if particle.name.startswith("nu("): node_interaction_types = [InteractionType.WEAK] @@ -235,14 +235,14 @@ def _remove_qns_from_graph( # pylint: disable=too-many-branches graph: MutableTransition[ParticleWithSpin, InteractionProperties], qn_list: Set[Type[NodeQuantumNumber]], ) -> MutableTransition[ParticleWithSpin, InteractionProperties]: - new_node_props = {} + new_interactions = {} for node_id in graph.topology.nodes: - node_props = graph.node_props[node_id] - new_node_props[node_id] = attrs.evolve( - node_props, **{x.__name__: None for x in qn_list} + interactions = graph.interactions[node_id] + new_interactions[node_id] = attrs.evolve( + interactions, **{x.__name__: None for x in qn_list} ) - return attrs.evolve(graph, node_props=new_node_props) + return attrs.evolve(graph, interactions=new_interactions) def _check_equal_ignoring_qns( @@ -254,13 +254,13 @@ def _check_equal_ignoring_qns( if not isinstance(ref_graph, MutableTransition): raise TypeError("Reference graph has to be of type MutableTransition") found_graph = None - node_comparator = NodePropertyComparator(ignored_qn_list) + interaction_comparator = NodePropertyComparator(ignored_qn_list) for graph in solutions: if isinstance(graph, MutableTransition): if graph.compare( ref_graph, - edge_comparator=lambda e1, e2: e1 == e2, - node_comparator=node_comparator, + state_comparator=lambda e1, e2: e1 == e2, + interaction_comparator=interaction_comparator, ): found_graph = graph break @@ -278,14 +278,14 @@ def __init__( def __call__( self, - node_props1: InteractionProperties, - node_props2: InteractionProperties, + interactions1: InteractionProperties, + interactions2: InteractionProperties, ) -> bool: return attrs.evolve( - node_props1, + interactions1, **{x.__name__: None for x in self.__ignored_qn_list}, ) == attrs.evolve( - node_props2, + interactions2, **{x.__name__: None for x in self.__ignored_qn_list}, ) @@ -365,7 +365,7 @@ def check( return False for i in node_ids: if ( - getattr(graph.node_props[i], interaction_qn.__name__) + getattr(graph.interactions[i], interaction_qn.__name__) not in allowed_values ): return False @@ -382,8 +382,8 @@ def _find_node_ids_with_ingoing_particle_name( found_node_ids = [] for node_id in topology.nodes: for edge_id in topology.get_edge_ids_ingoing_to_node(node_id): - edge_props = graph.edge_props[edge_id] - edge_particle_name = edge_props[0].name + states = graph.states[edge_id] + edge_particle_name = states[0].name if str(ingoing_particle_name) in str(edge_particle_name): found_node_ids.append(node_id) break diff --git a/src/qrules/argument_handling.py b/src/qrules/argument_handling.py index 8c200f19..8b1a25f9 100644 --- a/src/qrules/argument_handling.py +++ b/src/qrules/argument_handling.py @@ -93,11 +93,11 @@ def wrapper(props: GraphElementPropertyMap) -> bool: def _sequence_input_check(func: Callable) -> Callable[[Sequence], bool]: - def wrapper(edge_props_list: Sequence[Any]) -> bool: - if not isinstance(edge_props_list, (list, tuple)): + def wrapper(states_list: Sequence[Any]) -> bool: + if not isinstance(states_list, (list, tuple)): raise TypeError("Rule evaluated with invalid argument type...") - return all(func(x) for x in edge_props_list) + return all(func(x) for x in states_list) return wrapper @@ -170,11 +170,11 @@ def __call__( def _sequence_arg_builder(func: Callable) -> Callable[[Sequence], List[Any]]: - def wrapper(edge_props_list: Sequence[Any]) -> List[Any]: - if not isinstance(edge_props_list, (list, tuple)): + def wrapper(states_list: Sequence[Any]) -> List[Any]: + if not isinstance(states_list, (list, tuple)): raise TypeError("Rule evaluated with invalid argument type...") - return [func(x) for x in edge_props_list if x] + return [func(x) for x in states_list if x] return wrapper diff --git a/src/qrules/combinatorics.py b/src/qrules/combinatorics.py index 567f08db..e210e973 100644 --- a/src/qrules/combinatorics.py +++ b/src/qrules/combinatorics.py @@ -40,8 +40,8 @@ @implement_pretty_repr @frozen class InitialFacts: - edge_props: Dict[int, ParticleWithSpin] = field(factory=dict) - node_props: Dict[int, InteractionProperties] = field(factory=dict) + states: Dict[int, ParticleWithSpin] = field(factory=dict) + interactions: Dict[int, InteractionProperties] = field(factory=dict) class _KinematicRepresentation: @@ -268,7 +268,7 @@ def embed_in_list(some_list: List[Any]) -> List[List[Any]]: kinematic_permutation, particle_db ) edge_initial_facts.extend( - [InitialFacts(edge_props=x) for x in spin_permutations] + [InitialFacts(states=x) for x in spin_permutations] ) return edge_initial_facts @@ -560,4 +560,4 @@ def _create_edge_id_particle_mapping( graph: MutableTransition[ParticleWithSpin, InteractionProperties], edge_ids: Iterable[int], ) -> Dict[int, str]: - return {i: graph.edge_props[i][0].name for i in edge_ids} + return {i: graph.states[i][0].name for i in edge_ids} diff --git a/src/qrules/io/__init__.py b/src/qrules/io/__init__.py index e4059b83..4f46a4fb 100644 --- a/src/qrules/io/__init__.py +++ b/src/qrules/io/__init__.py @@ -54,8 +54,6 @@ def fromdict(definition: dict) -> object: return _dict.build_reaction_info(definition) if keys == {"topology", "states", "interactions"}: return _dict.build_state_transition(definition) - if keys == {"topology", "edge_props", "node_props"}: - return _dict.build_stg(definition) if keys == __REQUIRED_TOPOLOGY_FIELDS: return _dict.build_topology(definition) raise NotImplementedError(f"Could not determine type from keys {keys}") diff --git a/src/qrules/io/_dict.py b/src/qrules/io/_dict.py index 2dbb3ec9..7d06674b 100644 --- a/src/qrules/io/_dict.py +++ b/src/qrules/io/_dict.py @@ -37,25 +37,25 @@ def from_stg( graph: MutableTransition[ParticleWithSpin, InteractionProperties] ) -> dict: topology = graph.topology - edge_props_def = {} + states_def = {} for i in topology.edges: - particle, spin_projection = graph.edge_props[i] + particle, spin_projection = graph.states[i] if isinstance(spin_projection, float) and spin_projection.is_integer(): spin_projection = int(spin_projection) - edge_props_def[i] = { + states_def[i] = { "particle": from_particle(particle), "spin_projection": spin_projection, } - node_props_def = {} + interactions_def = {} for i in topology.nodes: - node_prop = graph.node_props[i] - node_props_def[i] = attrs.asdict( + node_prop = graph.interactions[i] + interactions_def[i] = attrs.asdict( node_prop, filter=lambda a, v: a.init and a.default != v ) return { "topology": from_topology(topology), - "edge_props": edge_props_def, - "node_props": node_props_def, + "states": states_def, + "interactions": interactions_def, } @@ -118,30 +118,6 @@ def build_reaction_info(definition: dict) -> ReactionInfo: return ReactionInfo(transitions, formalism=definition["formalism"]) -def build_stg( - definition: dict, -) -> MutableTransition[ParticleWithSpin, InteractionProperties]: - topology = build_topology(definition["topology"]) - edge_props_def: Dict[int, dict] = definition["edge_props"] - edge_props: Dict[int, ParticleWithSpin] = {} - for i, edge_def in edge_props_def.items(): - particle = build_particle(edge_def["particle"]) - spin_projection = float(edge_def["spin_projection"]) - if spin_projection.is_integer(): - spin_projection = int(spin_projection) - edge_props[int(i)] = (particle, spin_projection) - node_props_def: Dict[int, dict] = definition["node_props"] - node_props = { - int(i): InteractionProperties(**node_def) - for i, node_def in node_props_def.items() - } - return MutableTransition( - topology=topology, - edge_props=edge_props, - node_props=node_props, - ) - - def build_state_transition(definition: dict) -> StateTransition: topology = build_topology(definition["topology"]) states = { diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index 8be48c69..f0f396f1 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -279,8 +279,8 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals from_node, to_node, label=label, graphviz_attrs=edge_style ) if isinstance(graph, ProblemSet): - node_props = graph.solving_settings.node_settings - for node_id, settings in node_props.items(): + node_settings = graph.solving_settings.node_settings + for node_id, settings in node_settings.items(): node_label = "" if render_node: node_label = __node_label(settings) @@ -295,7 +295,7 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals int, InteractionProperties ] = graph.interactions else: - interactions = {i: graph.node_props[i] for i in topology.nodes} + interactions = {i: graph.interactions[i] for i in topology.nodes} for node_id, node_prop in interactions.items(): node_label = "" if render_node: @@ -347,11 +347,11 @@ def __get_edge_label( edge_setting = graph.edge_settings.get(edge_id) return ___render_edge_with_id(edge_id, edge_setting, render_edge_id) if isinstance(graph, InitialFacts): - initial_fact = graph.edge_props.get(edge_id) + initial_fact = graph.states.get(edge_id) return ___render_edge_with_id(edge_id, initial_fact, render_edge_id) if isinstance(graph, ProblemSet): edge_setting = graph.solving_settings.edge_settings.get(edge_id) - initial_fact = graph.initial_facts.edge_props.get(edge_id) + initial_fact = graph.initial_facts.states.get(edge_id) edge_property: Optional[Union[EdgeSettings, ParticleWithSpin]] = None if edge_setting: edge_property = edge_setting @@ -361,7 +361,7 @@ def __get_edge_label( if isinstance(graph, StateTransition): graph = graph.to_graph() if isinstance(graph, MutableTransition): - edge_prop = graph.edge_props[edge_id] + edge_prop = graph.states[edge_id] return ___render_edge_with_id(edge_id, edge_prop, render_edge_id) if isinstance(graph, Topology): if render_edge_id: @@ -483,7 +483,7 @@ def _get_particle_graphs( transition = transition.to_graph() if any( transition.compare( - other, edge_comparator=lambda e1, e2: e1[0] == e2 + other, state_comparator=lambda e1, e2: e1[0] == e2 ) for other in inventory ): @@ -493,7 +493,7 @@ def _get_particle_graphs( inventory = sorted( inventory, key=lambda g: [ - g.edge_props[i].mass for i in g.topology.intermediate_edge_ids + g.states[i].mass for i in g.topology.intermediate_edge_ids ], ) return inventory @@ -504,22 +504,22 @@ def _strip_projections( ) -> MutableTransition[Particle, InteractionProperties]: if isinstance(graph, StateTransition): graph = graph.to_graph() - new_edge_props = {} + new_states = {} for edge_id in graph.topology.edges: - edge_props = graph.edge_props[edge_id] - if edge_props: - new_edge_props[edge_id] = edge_props[0] - new_node_props = {} + states = graph.states[edge_id] + if states: + new_states[edge_id] = states[0] + new_interactions = {} for node_id in graph.topology.nodes: - node_props = graph.node_props[node_id] - if node_props: - new_node_props[node_id] = attrs.evolve( - node_props, l_projection=None, s_projection=None + interactions = graph.interactions[node_id] + if interactions: + new_interactions[node_id] = attrs.evolve( + interactions, l_projection=None, s_projection=None ) return MutableTransition[Particle, InteractionProperties]( topology=graph.topology, - node_props=new_node_props, - edge_props=new_edge_props, + interactions=new_interactions, + states=new_states, ) @@ -542,8 +542,8 @@ def merge_into( "Cannot merge graphs that don't have the same edge IDs" ) for i in graph.topology.edges: - particle = graph.edge_props[i] - other_particles = merged_graph.edge_props[i] + particle = graph.states[i] + other_particles = merged_graph.states[i] if particle not in other_particles: other_particles += particle @@ -558,11 +558,11 @@ def is_same_shape( for edge_id in ( graph.topology.incoming_edge_ids | graph.topology.outgoing_edge_ids ): - edge_prop = merged_graph.edge_props[edge_id] + edge_prop = merged_graph.states[edge_id] if len(edge_prop) != 1: return False other_particle = next(iter(edge_prop)) - if other_particle != graph.edge_props[edge_id]: + if other_particle != graph.states[edge_id]: return False return True @@ -578,17 +578,17 @@ def is_same_shape( append_to_inventory = False break if append_to_inventory: - new_edge_props = { - edge_id: ParticleCollection({graph.edge_props[edge_id]}) + new_states = { + edge_id: ParticleCollection({graph.states[edge_id]}) for edge_id in graph.topology.edges } inventory.append( MutableTransition[ParticleCollection, InteractionProperties]( topology=graph.topology, - node_props={ - i: graph.node_props[i] for i in graph.topology.nodes + interactions={ + i: graph.interactions[i] for i in graph.topology.nodes }, - edge_props=new_edge_props, + states=new_states, ) ) return inventory diff --git a/src/qrules/solving.py b/src/qrules/solving.py index 7fcf851a..c4566495 100644 --- a/src/qrules/solving.py +++ b/src/qrules/solving.py @@ -99,8 +99,8 @@ class GraphSettings: @implement_pretty_repr @define class GraphElementProperties: - edge_props: Dict[int, GraphEdgePropertyMap] = field(factory=dict) - node_props: Dict[int, GraphNodePropertyMap] = field(factory=dict) + states: Dict[int, GraphEdgePropertyMap] = field(factory=dict) + interactions: Dict[int, GraphNodePropertyMap] = field(factory=dict) @implement_pretty_repr @@ -326,12 +326,12 @@ def _create_node_variables( ) -> Dict[Type[NodeQuantumNumber], Scalar]: """Create variables for the quantum numbers of the specified node.""" variables = {} - if node_id in problem_set.initial_facts.node_props: - node_props = problem_set.initial_facts.node_props[node_id] - variables = node_props + if node_id in problem_set.initial_facts.interactions: + interactions = problem_set.initial_facts.interactions[node_id] + variables = interactions for qn_type in qn_list: - if qn_type in node_props: - variables[qn_type] = node_props[qn_type] + if qn_type in interactions: + variables[qn_type] = interactions[qn_type] return variables def _create_edge_variables( @@ -346,12 +346,12 @@ def _create_edge_variables( """ variables = [] for edge_id in edge_ids: - if edge_id in problem_set.initial_facts.edge_props: - edge_props = problem_set.initial_facts.edge_props[edge_id] + if edge_id in problem_set.initial_facts.states: + states = problem_set.initial_facts.states[edge_id] edge_vars = {} for qn_type in qn_list: - if qn_type in edge_props: - edge_vars[qn_type] = edge_props[qn_type] + if qn_type in states: + edge_vars[qn_type] = states[qn_type] variables.append(edge_vars) return variables @@ -444,8 +444,8 @@ def _create_variable_containers( return QNResult( [ QuantumNumberSolution( - edge_quantum_numbers=problem_set.initial_facts.edge_props, - node_quantum_numbers=problem_set.initial_facts.node_props, + edge_quantum_numbers=problem_set.initial_facts.states, + node_quantum_numbers=problem_set.initial_facts.interactions, ) ], ) @@ -548,8 +548,8 @@ def find_solutions(self, problem_set: QNProblemSet) -> QNResult: else: full_particle_solutions = [ QuantumNumberSolution( - node_quantum_numbers=problem_set.initial_facts.node_props, - edge_quantum_numbers=problem_set.initial_facts.edge_props, + node_quantum_numbers=problem_set.initial_facts.interactions, + edge_quantum_numbers=problem_set.initial_facts.states, ) ] @@ -560,17 +560,17 @@ def find_solutions(self, problem_set: QNProblemSet) -> QNResult: # and combine results result = QNResult() for full_particle_solution in full_particle_solutions: - node_props = full_particle_solution.node_quantum_numbers - edge_props = full_particle_solution.edge_quantum_numbers - node_props.update(problem_set.initial_facts.node_props) - edge_props.update(problem_set.initial_facts.edge_props) + interactions = full_particle_solution.node_quantum_numbers + states = full_particle_solution.edge_quantum_numbers + interactions.update(problem_set.initial_facts.interactions) + states.update(problem_set.initial_facts.states) result.extend( validate_full_solution( QNProblemSet( topology=problem_set.topology, initial_facts=GraphElementProperties( - node_props=node_props, - edge_props=edge_props, + interactions=interactions, + states=states, ), solving_settings=GraphSettings( node_settings={ @@ -755,11 +755,11 @@ def __create_node_variables( {}, ) - if node_id in problem_set.initial_facts.node_props: - node_props = problem_set.initial_facts.node_props[node_id] + if node_id in problem_set.initial_facts.interactions: + interactions = problem_set.initial_facts.interactions[node_id] for qn_type in qn_list: - if qn_type in node_props: - variables[1].update({qn_type: node_props[qn_type]}) + if qn_type in interactions: + variables[1].update({qn_type: interactions[qn_type]}) else: node_settings = problem_set.solving_settings.node_settings[node_id] for qn_type in qn_list: @@ -793,12 +793,12 @@ def __create_edge_variables( for edge_id in edge_ids: variables[1][edge_id] = {} - if edge_id in problem_set.initial_facts.edge_props: - edge_props = problem_set.initial_facts.edge_props[edge_id] + if edge_id in problem_set.initial_facts.states: + states = problem_set.initial_facts.states[edge_id] for qn_type in qn_list: - if qn_type in edge_props: + if qn_type in states: variables[1][edge_id].update( - {qn_type: edge_props[qn_type]} + {qn_type: states[qn_type]} ) else: edge_settings = problem_set.solving_settings.edge_settings[ diff --git a/src/qrules/topology.py b/src/qrules/topology.py index bd812ac9..6c3b288b 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -655,25 +655,25 @@ def _attach_node_to_edges( """A `~typing.TypeVar` representing the type of node properties.""" -@implement_pretty_repr() +@implement_pretty_repr @frozen(order=True) class FrozenTransition(Generic[EdgeType, NodeType]): """Defines a frozen mapping of edge and node properties on a `Topology`.""" topology: Topology = field(validator=instance_of(Topology)) - edge_props: FrozenDict[int, NodeType] = field(converter=FrozenDict) - node_props: FrozenDict[int, EdgeType] = field(converter=FrozenDict) + states: FrozenDict[int, EdgeType] = field(converter=FrozenDict) + interactions: FrozenDict[int, NodeType] = field(converter=FrozenDict) def __attrs_post_init__(self) -> None: - _assert_all_defined(self.topology.nodes, self.node_props) - _assert_all_defined(self.topology.edges, self.edge_props) + _assert_all_defined(self.topology.nodes, self.interactions) + _assert_all_defined(self.topology.edges, self.states) -def _cast_edges(obj: Mapping[int, EdgeType]) -> Dict[int, EdgeType]: +def _cast_states(obj: Mapping[int, EdgeType]) -> Dict[int, EdgeType]: return dict(obj) -def _cast_nodes(obj: Mapping[int, NodeType]) -> Dict[int, NodeType]: +def _cast_interactions(obj: Mapping[int, NodeType]) -> Dict[int, NodeType]: return dict(obj) @@ -689,27 +689,29 @@ class MutableTransition(Generic[EdgeType, NodeType]): """ topology: Topology = field(validator=instance_of(Topology)) - node_props: Dict[int, NodeType] = field(converter=_cast_nodes) - edge_props: Dict[int, EdgeType] = field(converter=_cast_edges) + states: Dict[int, EdgeType] = field(converter=_cast_states) + interactions: Dict[int, NodeType] = field(converter=_cast_interactions) def compare( self, other: "MutableTransition", - edge_comparator: Optional[Callable[[EdgeType, EdgeType], bool]] = None, - node_comparator: Optional[Callable[[NodeType, NodeType], bool]] = None, + state_comparator: Optional[ + Callable[[EdgeType, EdgeType], bool] + ] = None, + interaction_comparator: Optional[ + Callable[[NodeType, NodeType], bool] + ] = None, ) -> bool: if self.topology != other.topology: return False - if edge_comparator is not None: + if state_comparator is not None: for i in self.topology.edges: - if not edge_comparator( - self.edge_props[i], other.edge_props[i] - ): + if not state_comparator(self.states[i], other.states[i]): return False - if node_comparator is not None: + if interaction_comparator is not None: for i in self.topology.nodes: - if not node_comparator( - self.node_props[i], other.node_props[i] + if not interaction_comparator( + self.interactions[i], other.interactions[i] ): return False return True @@ -718,14 +720,14 @@ def swap_edges(self, edge_id1: int, edge_id2: int) -> None: self.topology = self.topology.swap_edges(edge_id1, edge_id2) value1: Optional[EdgeType] = None value2: Optional[EdgeType] = None - if edge_id1 in self.edge_props: - value1 = self.edge_props.pop(edge_id1) - if edge_id2 in self.edge_props: - value2 = self.edge_props.pop(edge_id2) + if edge_id1 in self.states: + value1 = self.states.pop(edge_id1) + if edge_id2 in self.states: + value2 = self.states.pop(edge_id2) if value1 is not None: - self.edge_props[edge_id2] = value1 + self.states[edge_id2] = value1 if value2 is not None: - self.edge_props[edge_id1] = value2 + self.states[edge_id1] = value2 def _assert_all_defined(items: Iterable, properties: Iterable) -> None: diff --git a/src/qrules/transition.py b/src/qrules/transition.py index 06a434b8..4642f78e 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -184,18 +184,18 @@ class ProblemSet: solving_settings: GraphSettings def to_qn_problem_set(self) -> QNProblemSet: - node_props = { + interactions = { k: create_node_properties(v) - for k, v in self.initial_facts.node_props.items() + for k, v in self.initial_facts.interactions.items() } - edge_props = { + states = { k: create_edge_properties(v[0], v[1]) - for k, v in self.initial_facts.edge_props.items() + for k, v in self.initial_facts.states.items() } return QNProblemSet( topology=self.topology, initial_facts=GraphElementProperties( - node_props=node_props, edge_props=edge_props + interactions=interactions, states=states ), solving_settings=self.solving_settings, ) @@ -476,24 +476,24 @@ def create_edge_settings(edge_id: int) -> EdgeSettings: interaction_types: List[InteractionType] = [] out_edge_ids = topology.get_edge_ids_outgoing_from_node(node_id) in_edge_ids = topology.get_edge_ids_outgoing_from_node(node_id) - in_edge_props = [ - initial_facts.edge_props[edge_id] + in_states = [ + initial_facts.states[edge_id] for edge_id in [ x for x in in_edge_ids if x in initial_state_edges ] ] - out_edge_props = [ - initial_facts.edge_props[edge_id] + out_states = [ + initial_facts.states[edge_id] for edge_id in [ x for x in out_edge_ids if x in final_state_edges ] ] - node_props = InteractionProperties() - if node_id in initial_facts.node_props: - node_props = initial_facts.node_props[node_id] + interactions = InteractionProperties() + if node_id in initial_facts.interactions: + interactions = initial_facts.interactions[node_id] for int_det in self.interaction_determinators: determined_interactions = int_det.check( - in_edge_props, out_edge_props, node_props + in_states, out_states, interactions ) if interaction_types: interaction_types = list( @@ -655,11 +655,11 @@ def __convert_result( for solution in qn_result.solutions: graph = MutableTransition[ParticleWithSpin, InteractionProperties]( topology=topology, - node_props={ + interactions={ i: create_interaction_properties(x) for i, x in solution.node_quantum_numbers.items() }, - edge_props={ + states={ i: create_particle(x, self.__particles) for i, x in solution.edge_quantum_numbers.items() }, @@ -698,17 +698,17 @@ def _match_final_state_ids( particle_names = _strip_spin(state_definition) name_to_id = {name: i for i, name in enumerate(particle_names)} id_remapping = { - name_to_id[graph.edge_props[i][0].name]: i + name_to_id[graph.states[i][0].name]: i for i in graph.topology.outgoing_edge_ids } new_topology = graph.topology.relabel_edges(id_remapping) return MutableTransition( new_topology, - edge_props={ - i: graph.edge_props[id_remapping.get(i, i)] + states={ + i: graph.states[id_remapping.get(i, i)] for i in graph.topology.edges }, - node_props={i: graph.node_props[i] for i in graph.topology.nodes}, + interactions={i: graph.interactions[i] for i in graph.topology.nodes}, ) @@ -751,10 +751,10 @@ def from_graph( return StateTransition( topology=graph.topology, states=FrozenDict( - {i: State(*graph.edge_props[i]) for i in graph.topology.edges} + {i: State(*graph.states[i]) for i in graph.topology.edges} ), interactions=FrozenDict( - {i: graph.node_props[i] for i in graph.topology.nodes} + {i: graph.interactions[i] for i in graph.topology.nodes} ), ) @@ -763,11 +763,11 @@ def to_graph( ) -> MutableTransition[ParticleWithSpin, InteractionProperties]: return MutableTransition[ParticleWithSpin, InteractionProperties]( topology=self.topology, - edge_props={ + states={ i: (state.particle, state.spin_projection) for i, state in self.states.items() }, - node_props=self.interactions, + interactions=self.interactions, ) @property diff --git a/tests/unit/io/test_dot.py b/tests/unit/io/test_dot.py index 79257ff3..aca343ca 100644 --- a/tests/unit/io/test_dot.py +++ b/tests/unit/io/test_dot.py @@ -215,7 +215,7 @@ def test_collapse_graphs( graph = next(iter(collapsed_graphs)) edge_id = next(iter(graph.topology.intermediate_edge_ids)) f_resonances = pdg.filter(lambda p: p.name in ["f(0)(980)", "f(0)(1500)"]) - intermediate_states = graph.edge_props[edge_id] + intermediate_states = graph.states[edge_id] assert isinstance(intermediate_states, ParticleCollection) assert intermediate_states == f_resonances @@ -226,11 +226,11 @@ def test_get_particle_graphs( pdg = particle_database graphs = _get_particle_graphs(reaction.to_graphs()) assert len(graphs) == 2 - assert graphs[0].edge_props[3] == pdg["f(0)(980)"] - assert graphs[1].edge_props[3] == pdg["f(0)(1500)"] + assert graphs[0].states[3] == pdg["f(0)(980)"] + assert graphs[1].states[3] == pdg["f(0)(1500)"] assert len(graphs[0].topology.edges) == 5 for i in range(-1, 3): - assert graphs[0].edge_props[i] is graphs[1].edge_props[i] + assert graphs[0].states[i] is graphs[1].states[i] def test_strip_projections(): @@ -254,8 +254,8 @@ def test_strip_projections(): assert transition.interactions[1].l_projection == 0 stripped_transition = _strip_projections(transition) # type: ignore[arg-type] - assert stripped_transition.edge_props[3].name == resonance - assert stripped_transition.node_props[0].s_projection is None - assert stripped_transition.node_props[0].l_projection is None - assert stripped_transition.node_props[1].s_projection is None - assert stripped_transition.node_props[1].l_projection is None + assert stripped_transition.states[3].name == resonance + assert stripped_transition.interactions[0].s_projection is None + assert stripped_transition.interactions[0].l_projection is None + assert stripped_transition.interactions[1].s_projection is None + assert stripped_transition.interactions[1].l_projection is None diff --git a/tests/unit/io/test_io.py b/tests/unit/io/test_io.py index db3c322d..f34b8b6e 100644 --- a/tests/unit/io/test_io.py +++ b/tests/unit/io/test_io.py @@ -5,12 +5,11 @@ from qrules import io from qrules.particle import Particle, ParticleCollection from qrules.topology import ( - MutableTransition, Topology, create_isobar_topologies, create_n_body_topology, ) -from qrules.transition import ReactionInfo +from qrules.transition import ReactionInfo, StateTransition def through_dict(instance): @@ -47,8 +46,8 @@ def test_asdict_fromdict_reaction(reaction: ReactionInfo): # MutableTransition for graph in reaction.to_graphs(): fromdict = through_dict(graph) - assert isinstance(fromdict, MutableTransition) - assert graph == fromdict + assert isinstance(fromdict, StateTransition) + assert graph == fromdict.to_graph() # ReactionInfo fromdict = through_dict(reaction) assert isinstance(fromdict, ReactionInfo) diff --git a/tests/unit/test_combinatorics.py b/tests/unit/test_combinatorics.py index a2339770..258eda37 100644 --- a/tests/unit/test_combinatorics.py +++ b/tests/unit/test_combinatorics.py @@ -98,14 +98,14 @@ def test_constructor(self): def test_from_topology(self, three_body_decay: Topology): pi0 = ("pi0", [0]) gamma = ("gamma", [-1, 1]) - edge_props = { + states = { -1: ("J/psi", [-1, +1]), 0: pi0, 1: pi0, 2: gamma, } kinematic_representation1 = _get_kinematic_representation( - three_body_decay, edge_props + three_body_decay, states ) assert kinematic_representation1.initial_state == [ ["J/psi"], diff --git a/tests/unit/test_system_control.py b/tests/unit/test_system_control.py index e5228521..8bab192c 100644 --- a/tests/unit/test_system_control.py +++ b/tests/unit/test_system_control.py @@ -212,14 +212,14 @@ def make_ls_test_graph( nodes={0}, edges={0: Edge(None, 0)}, ) - node_props = { + interactions = { 0: InteractionProperties( s_magnitude=coupled_spin_magnitude, l_magnitude=angular_momentum_magnitude, ) } - edge_props: Dict[int, ParticleWithSpin] = {0: (particle, 0)} - graph = MutableTransition(topology, node_props, edge_props) + states: Dict[int, ParticleWithSpin] = {0: (particle, 0)} + graph = MutableTransition(topology, states, interactions) return graph @@ -230,14 +230,14 @@ def make_ls_test_graph_scrambled( nodes={0}, edges={0: Edge(None, 0)}, ) - node_props = { + interactions = { 0: InteractionProperties( l_magnitude=angular_momentum_magnitude, s_magnitude=coupled_spin_magnitude, ) } - edge_props: Dict[int, ParticleWithSpin] = {0: (particle, 0)} - graph = MutableTransition(topology, node_props, edge_props) + states: Dict[int, ParticleWithSpin] = {0: (particle, 0)} + graph = MutableTransition(topology, states, interactions) return graph @@ -325,7 +325,7 @@ def test_filter_graphs_for_interaction_qns( tempgraph = make_ls_test_graph(value[1][0], value[1][1], pi0) tempgraph = attrs.evolve( tempgraph, - edge_props={ + states={ 0: ( Particle(name=value[0], pid=0, mass=1.0, spin=1.0), 0.0, @@ -344,8 +344,8 @@ def _create_graph( ) -> MutableTransition[ParticleWithSpin, InteractionProperties]: return MutableTransition( topology=problem_set.topology, - node_props=problem_set.initial_facts.node_props, - edge_props=problem_set.initial_facts.edge_props, + interactions=problem_set.initial_facts.interactions, + states=problem_set.initial_facts.states, ) @@ -382,15 +382,15 @@ def test_edge_swap(particle_database, initial_state, final_state): edge_keys = list(ref_mapping.keys()) edge1 = edge_keys[0] edge1_val = graph.topology.edges[edge1] - edge1_props = deepcopy(graph.edge_props[edge1]) + edge1_props = deepcopy(graph.states[edge1]) edge2 = edge_keys[1] edge2_val = graph.topology.edges[edge2] - edge2_props = deepcopy(graph.edge_props[edge2]) + edge2_props = deepcopy(graph.states[edge2]) graph.swap_edges(edge1, edge2) assert graph.topology.edges[edge1] == edge2_val assert graph.topology.edges[edge2] == edge1_val - assert graph.edge_props[edge1] == edge2_props - assert graph.edge_props[edge2] == edge1_props + assert graph.states[edge1] == edge2_props + assert graph.states[edge2] == edge1_props @pytest.mark.parametrize( From 0bcf9c449172d2dfcbeacd19929191d44f2354a5 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:34 +0100 Subject: [PATCH 12/34] refactor: rename GraphSettings attrs to states/interactions --- src/qrules/__init__.py | 4 ++-- src/qrules/io/_dot.py | 6 +++--- src/qrules/solving.py | 22 ++++++++++------------ src/qrules/transition.py | 8 ++++---- 4 files changed, 19 insertions(+), 21 deletions(-) diff --git a/src/qrules/__init__.py b/src/qrules/__init__.py index 95ade360..478a3d02 100644 --- a/src/qrules/__init__.py +++ b/src/qrules/__init__.py @@ -140,11 +140,11 @@ def _check_violations( topology=topology, initial_facts=facts, solving_settings=GraphSettings( - node_settings={ + interactions={ i: NodeSettings(conservation_rules=rules) for i, rules in node_rules.items() }, - edge_settings={ + states={ i: EdgeSettings(conservation_rules=rules) for i, rules in edge_rules.items() }, diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index f0f396f1..2a2c0d2e 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -279,7 +279,7 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals from_node, to_node, label=label, graphviz_attrs=edge_style ) if isinstance(graph, ProblemSet): - node_settings = graph.solving_settings.node_settings + node_settings = graph.solving_settings.interactions for node_id, settings in node_settings.items(): node_label = "" if render_node: @@ -344,13 +344,13 @@ def __get_edge_label( render_edge_id: bool, ) -> str: if isinstance(graph, GraphSettings): - edge_setting = graph.edge_settings.get(edge_id) + edge_setting = graph.states.get(edge_id) return ___render_edge_with_id(edge_id, edge_setting, render_edge_id) if isinstance(graph, InitialFacts): initial_fact = graph.states.get(edge_id) return ___render_edge_with_id(edge_id, initial_fact, render_edge_id) if isinstance(graph, ProblemSet): - edge_setting = graph.solving_settings.edge_settings.get(edge_id) + edge_setting = graph.solving_settings.states.get(edge_id) initial_fact = graph.initial_facts.states.get(edge_id) edge_property: Optional[Union[EdgeSettings, ParticleWithSpin]] = None if edge_setting: diff --git a/src/qrules/solving.py b/src/qrules/solving.py index c4566495..08de95e3 100644 --- a/src/qrules/solving.py +++ b/src/qrules/solving.py @@ -92,8 +92,8 @@ class NodeSettings: @implement_pretty_repr @define class GraphSettings: - edge_settings: Dict[int, EdgeSettings] = field(factory=dict) - node_settings: Dict[int, NodeSettings] = field(factory=dict) + states: Dict[int, EdgeSettings] = field(factory=dict) + interactions: Dict[int, NodeSettings] = field(factory=dict) @implement_pretty_repr @@ -380,7 +380,7 @@ def _create_variable_containers( for ( edge_id, edge_settings, - ) in problem_set.solving_settings.edge_settings.items(): + ) in problem_set.solving_settings.states.items(): edge_rules = edge_settings.conservation_rules for edge_rule in edge_rules: # get the needed qns for this conservation law @@ -407,7 +407,7 @@ def _create_variable_containers( for ( node_id, node_settings, - ) in problem_set.solving_settings.node_settings.items(): + ) in problem_set.solving_settings.interactions.items(): node_rules = node_settings.conservation_rules for rule in node_rules: # get the needed qns for this conservation law @@ -573,11 +573,11 @@ def find_solutions(self, problem_set: QNProblemSet) -> QNResult: states=states, ), solving_settings=GraphSettings( - node_settings={ + interactions={ i: NodeSettings(conservation_rules=rules) for i, rules in node_not_executed_rules.items() }, - edge_settings={ + states={ i: EdgeSettings(conservation_rules=rules) for i, rules in edge_not_executed_rules.items() }, @@ -639,7 +639,7 @@ def get_rules_by_priority( arg_handler = RuleArgumentHandler() for edge_id in problem_set.topology.edges: - edge_settings = problem_set.solving_settings.edge_settings[edge_id] + edge_settings = problem_set.solving_settings.states[edge_id] for rule in get_rules_by_priority(edge_settings): variable_mapping = _VariableContainer() # from cons law and graph determine needed var lists @@ -673,7 +673,7 @@ def get_rules_by_priority( for node_id in problem_set.topology.nodes: for rule in get_rules_by_priority( - problem_set.solving_settings.node_settings[node_id] + problem_set.solving_settings.interactions[node_id] ): variable_mapping = _VariableContainer() # from cons law and graph determine needed var lists @@ -761,7 +761,7 @@ def __create_node_variables( if qn_type in interactions: variables[1].update({qn_type: interactions[qn_type]}) else: - node_settings = problem_set.solving_settings.node_settings[node_id] + node_settings = problem_set.solving_settings.interactions[node_id] for qn_type in qn_list: var_info = (node_id, qn_type) if qn_type in node_settings.qn_domains: @@ -801,9 +801,7 @@ def __create_edge_variables( {qn_type: states[qn_type]} ) else: - edge_settings = problem_set.solving_settings.edge_settings[ - edge_id - ] + edge_settings = problem_set.solving_settings.states[edge_id] for qn_type in qn_list: var_info = (edge_id, qn_type) if qn_type in edge_settings.qn_domains: diff --git a/src/qrules/transition.py b/src/qrules/transition.py index 4642f78e..94c31fb3 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -217,7 +217,7 @@ def calculate_strength( ) for problem_set in problem_sets: strength = calculate_strength( - problem_set.solving_settings.node_settings + problem_set.solving_settings.interactions ) strength_sorted_problem_sets[strength].append(problem_set) return strength_sorted_problem_sets @@ -464,11 +464,11 @@ def create_edge_settings(edge_id: int) -> EdgeSettings: graph_settings: List[GraphSettings] = [ GraphSettings( - edge_settings={ + states={ edge_id: create_edge_settings(edge_id) for edge_id in topology.edges }, - node_settings={}, + interactions={}, ) ] @@ -515,7 +515,7 @@ def create_edge_settings(edge_id: int) -> EdgeSettings: for temp_setting in temp_graph_settings: for int_type in interaction_types: updated_setting = deepcopy(temp_setting) - updated_setting.node_settings[node_id] = deepcopy( + updated_setting.interactions[node_id] = deepcopy( self.interaction_type_settings[int_type][1] ) graph_settings.append(updated_setting) From ba6e385dc5c308d05b6024c52b00f768eceb7b47 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:34 +0100 Subject: [PATCH 13/34] refactor: rename QuantumNumberSolution attrs to states/interactions --- src/qrules/solving.py | 40 +++++++++++++++++----------------------- src/qrules/transition.py | 4 ++-- 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/src/qrules/solving.py b/src/qrules/solving.py index 08de95e3..ee0cf70d 100644 --- a/src/qrules/solving.py +++ b/src/qrules/solving.py @@ -126,8 +126,8 @@ class QNProblemSet: @implement_pretty_repr @frozen class QuantumNumberSolution: - node_quantum_numbers: Dict[int, GraphNodePropertyMap] - edge_quantum_numbers: Dict[int, GraphEdgePropertyMap] + states: Dict[int, GraphEdgePropertyMap] + interactions: Dict[int, GraphNodePropertyMap] def _convert_violated_rules_to_names( @@ -259,26 +259,26 @@ def _merge_particle_candidates_with_solutions( current_new_solutions = [solution] for int_edge_id in intermediate_edges: particle_edges = __get_particle_candidates_for_state( - solution.edge_quantum_numbers[int_edge_id], + solution.states[int_edge_id], allowed_particles, ) if len(particle_edges) == 0: logging.debug("Did not find any particle candidates for") logging.debug("edge id: %d", int_edge_id) logging.debug("edge properties:") - logging.debug(solution.edge_quantum_numbers[int_edge_id]) + logging.debug(solution.states[int_edge_id]) new_solutions_temp = [] for current_new_solution in current_new_solutions: for particle_edge in particle_edges: # a "shallow" copy of the nested dicts is needed new_edge_qns = { k: copy(v) - for k, v in current_new_solution.edge_quantum_numbers.items() + for k, v in current_new_solution.states.items() } new_edge_qns[int_edge_id].update(particle_edge) temp_solution = attrs.evolve( current_new_solution, - edge_quantum_numbers=new_edge_qns, + states=new_edge_qns, ) new_solutions_temp.append(temp_solution) current_new_solutions = new_solutions_temp @@ -444,8 +444,8 @@ def _create_variable_containers( return QNResult( [ QuantumNumberSolution( - edge_quantum_numbers=problem_set.initial_facts.states, - node_quantum_numbers=problem_set.initial_facts.interactions, + states=problem_set.initial_facts.states, + interactions=problem_set.initial_facts.interactions, ) ], ) @@ -548,8 +548,8 @@ def find_solutions(self, problem_set: QNProblemSet) -> QNResult: else: full_particle_solutions = [ QuantumNumberSolution( - node_quantum_numbers=problem_set.initial_facts.interactions, - edge_quantum_numbers=problem_set.initial_facts.states, + interactions=problem_set.initial_facts.interactions, + states=problem_set.initial_facts.states, ) ] @@ -560,8 +560,8 @@ def find_solutions(self, problem_set: QNProblemSet) -> QNResult: # and combine results result = QNResult() for full_particle_solution in full_particle_solutions: - interactions = full_particle_solution.node_quantum_numbers - states = full_particle_solution.edge_quantum_numbers + interactions = full_particle_solution.interactions + states = full_particle_solution.states interactions.update(problem_set.initial_facts.interactions) states.update(problem_set.initial_facts.states) result.extend( @@ -828,25 +828,19 @@ def __convert_solution_keys( """Convert keys of CSP solutions from `str` to quantum number types.""" converted_solutions = [] for solution in solutions: - edge_quantum_numbers: Dict[ - int, GraphEdgePropertyMap - ] = defaultdict(dict) - node_quantum_numbers: Dict[ - int, GraphNodePropertyMap - ] = defaultdict(dict) + states: Dict[int, GraphEdgePropertyMap] = defaultdict(dict) + interactions: Dict[int, GraphNodePropertyMap] = defaultdict(dict) for var_string, value in solution.items(): ele_id, qn_type = self.__var_string_to_data[var_string] if qn_type in getattr( # noqa: B009 EdgeQuantumNumber, "__args__" ): - edge_quantum_numbers[ele_id].update({qn_type: value}) # type: ignore[dict-item] + states[ele_id].update({qn_type: value}) # type: ignore[dict-item] else: - node_quantum_numbers[ele_id].update({qn_type: value}) # type: ignore[dict-item] + interactions[ele_id].update({qn_type: value}) # type: ignore[dict-item] converted_solutions.append( - QuantumNumberSolution( - node_quantum_numbers, edge_quantum_numbers - ) + QuantumNumberSolution(states, interactions) ) return converted_solutions diff --git a/src/qrules/transition.py b/src/qrules/transition.py index 94c31fb3..4a1e5891 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -657,11 +657,11 @@ def __convert_result( topology=topology, interactions={ i: create_interaction_properties(x) - for i, x in solution.node_quantum_numbers.items() + for i, x in solution.interactions.items() }, states={ i: create_particle(x, self.__particles) - for i, x in solution.edge_quantum_numbers.items() + for i, x in solution.states.items() }, ) solutions.append(graph) From 6c2f1a26b750113b3f4ad3711ce0e2ab93f8d605 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:35 +0100 Subject: [PATCH 14/34] feat: define Transition Protocol and simplify _dot --- src/qrules/io/_dot.py | 205 +++++++++++++++----------------------- src/qrules/topology.py | 11 +- tests/unit/io/test_dot.py | 6 +- 3 files changed, 95 insertions(+), 127 deletions(-) diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index 2a2c0d2e..a587e42b 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -14,6 +14,7 @@ List, Mapping, Optional, + Set, Tuple, Union, ) @@ -21,10 +22,15 @@ import attrs from qrules.combinatorics import InitialFacts -from qrules.particle import Particle, ParticleCollection, ParticleWithSpin +from qrules.particle import Particle, ParticleWithSpin from qrules.quantum_numbers import InteractionProperties, _to_fraction from qrules.solving import EdgeSettings, GraphSettings, NodeSettings -from qrules.topology import MutableTransition, Topology +from qrules.topology import ( + FrozenTransition, + MutableTransition, + Topology, + Transition, +) from qrules.transition import ProblemSet, StateTransition _DOT_HEAD = """digraph { @@ -138,7 +144,7 @@ def __create_graphviz_assignments(graphviz_attrs: Dict[str, Any]) -> List[str]: @embed_dot def graph_list_to_dot( - graphs: Iterable[MutableTransition], + graphs: Iterable[Transition], *, render_node: bool, render_final_state_id: bool, @@ -156,7 +162,7 @@ def graph_list_to_dot( raise ValueError( "Collapsed graphs cannot be rendered with node properties" ) - graphs = _collapse_graphs(graphs) + graphs = _collapse_graphs(graphs) # type: ignore[assignment] elif strip_spin: if render_node: stripped_graphs = [] @@ -166,9 +172,9 @@ def graph_list_to_dot( stripped_graph = _strip_projections(graph) if stripped_graph not in stripped_graphs: stripped_graphs.append(stripped_graph) - graphs = stripped_graphs + graphs = stripped_graphs # type: ignore[assignment] else: - graphs = _get_particle_graphs(graphs) + graphs = _get_particle_graphs(graphs) # type: ignore[assignment] dot = "" if not isinstance(graphs, abc.Sequence): graphs = list(graphs) @@ -188,7 +194,7 @@ def graph_list_to_dot( @embed_dot def graph_to_dot( - graph: MutableTransition, + graph: Transition, *, render_node: bool, render_final_state_id: bool, @@ -212,10 +218,10 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals graph: Union[ ProblemSet, StateTransition, - MutableTransition, Topology, - Tuple[Topology, InitialFacts], + Transition, Tuple[Topology, GraphSettings], + Tuple[Topology, InitialFacts], ], prefix: str = "", *, @@ -234,13 +240,13 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals InitialFacts, ProblemSet, StateTransition, - MutableTransition, Topology, + Transition, ] = graph[1] elif isinstance(graph, ProblemSet): rendered_graph = graph topology = graph.topology - elif isinstance(graph, (StateTransition, MutableTransition)): + elif isinstance(graph, (StateTransition, Transition)): rendered_graph = graph topology = graph.topology elif isinstance(graph, Topology): @@ -289,7 +295,7 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals label=node_label, graphviz_attrs=node_style, ) - if isinstance(graph, (StateTransition, MutableTransition)): + if isinstance(graph, (StateTransition, Transition)): if isinstance(graph, StateTransition): interactions: Mapping[ int, InteractionProperties @@ -337,8 +343,8 @@ def __get_edge_label( InitialFacts, ProblemSet, StateTransition, - MutableTransition, Topology, + Transition, ], edge_id: int, render_edge_id: bool, @@ -360,7 +366,7 @@ def __get_edge_label( return ___render_edge_with_id(edge_id, edge_property, render_edge_id) if isinstance(graph, StateTransition): graph = graph.to_graph() - if isinstance(graph, MutableTransition): + if isinstance(graph, Transition): edge_prop = graph.states[edge_id] return ___render_edge_with_id(edge_id, edge_prop, render_edge_id) if isinstance(graph, Topology): @@ -375,7 +381,7 @@ def __get_edge_label( def ___render_edge_with_id( edge_id: int, edge_prop: Optional[ - Union[EdgeSettings, ParticleCollection, Particle, ParticleWithSpin] + Union[EdgeSettings, Iterable[Particle], Particle, ParticleWithSpin] ], render_edge_id: bool, ) -> str: @@ -391,19 +397,21 @@ def ___render_edge_with_id( def __render_edge_property( edge_prop: Optional[ - Union[EdgeSettings, ParticleCollection, Particle, ParticleWithSpin] + Union[EdgeSettings, Iterable[Particle], Particle, ParticleWithSpin] ] ) -> str: if isinstance(edge_prop, EdgeSettings): return __render_settings(edge_prop) if isinstance(edge_prop, Particle): return edge_prop.name - if isinstance(edge_prop, tuple): + if isinstance(edge_prop, abc.Iterable) and all( + map(lambda i: isinstance(i, Particle), edge_prop) + ): + return "\n".join(map(lambda p: p.name, edge_prop)) + if isinstance(edge_prop, tuple) and len(edge_prop) == 2: particle, spin_projection = edge_prop projection_label = _to_fraction(spin_projection, render_plus=True) return f"{particle.name}[{projection_label}]" - if isinstance(edge_prop, ParticleCollection): - return "\n".join(sorted(edge_prop.names)) raise NotImplementedError @@ -466,129 +474,80 @@ def __extract_priority(description: str) -> int: def _get_particle_graphs( - graphs: Iterable[ - MutableTransition[ParticleWithSpin, InteractionProperties] - ], -) -> List[MutableTransition[Particle, InteractionProperties]]: - """Strip `list` of `.MutableTransition` s of the spin projections. + graphs: Iterable[Transition[ParticleWithSpin, InteractionProperties]], +) -> List[FrozenTransition[Particle, None]]: + """Strip `list` of `.Transition` s of the spin projections. - Extract a `list` of `.MutableTransition` instances with only - `.Particle` instances on the edges. + Extract a `list` of `.Transition` instances with only `.Particle` instances + on the edges. .. seealso:: :doc:`/usage/visualize` """ - inventory: List[MutableTransition[Particle, InteractionProperties]] = [] + inventory = set() for transition in graphs: if isinstance(transition, StateTransition): transition = transition.to_graph() - if any( - transition.compare( - other, state_comparator=lambda e1, e2: e1[0] == e2 - ) - for other in inventory - ): - continue - stripped_graph = _strip_projections(transition) - inventory.append(stripped_graph) - inventory = sorted( + stripped_transition = _strip_projections(transition) + topology = stripped_transition.topology + particle_transition: FrozenTransition[ + Particle, None + ] = FrozenTransition( + stripped_transition.topology, + states=stripped_transition.states, + interactions={i: None for i in topology.nodes}, + ) + inventory.add(particle_transition) + return sorted( inventory, key=lambda g: [ g.states[i].mass for i in g.topology.intermediate_edge_ids ], ) - return inventory def _strip_projections( - graph: MutableTransition[ParticleWithSpin, InteractionProperties], -) -> MutableTransition[Particle, InteractionProperties]: + graph: Transition[ParticleWithSpin, InteractionProperties], +) -> FrozenTransition[Particle, InteractionProperties]: if isinstance(graph, StateTransition): graph = graph.to_graph() - new_states = {} - for edge_id in graph.topology.edges: - states = graph.states[edge_id] - if states: - new_states[edge_id] = states[0] - new_interactions = {} - for node_id in graph.topology.nodes: - interactions = graph.interactions[node_id] - if interactions: - new_interactions[node_id] = attrs.evolve( - interactions, l_projection=None, s_projection=None - ) - return MutableTransition[Particle, InteractionProperties]( - topology=graph.topology, - interactions=new_interactions, - states=new_states, + return FrozenTransition( + graph.topology, + states={i: particle for i, (particle, _) in graph.states.items()}, + interactions={ + i: attrs.evolve(interaction, l_projection=None, s_projection=None) + for i, interaction in graph.interactions.items() + }, ) def _collapse_graphs( - graphs: Iterable[ - MutableTransition[ParticleWithSpin, InteractionProperties] - ], -) -> List[MutableTransition[ParticleCollection, InteractionProperties]]: - def merge_into( - graph: MutableTransition[Particle, InteractionProperties], - merged_graph: MutableTransition[ - ParticleCollection, InteractionProperties - ], - ) -> None: - if ( - graph.topology.intermediate_edge_ids - != merged_graph.topology.intermediate_edge_ids - ): - raise ValueError( - "Cannot merge graphs that don't have the same edge IDs" - ) - for i in graph.topology.edges: - particle = graph.states[i] - other_particles = merged_graph.states[i] - if particle not in other_particles: - other_particles += particle - - def is_same_shape( - graph: MutableTransition[Particle, InteractionProperties], - merged_graph: MutableTransition[ - ParticleCollection, InteractionProperties - ], - ) -> bool: - if graph.topology.edges != merged_graph.topology.edges: - return False - for edge_id in ( - graph.topology.incoming_edge_ids | graph.topology.outgoing_edge_ids - ): - edge_prop = merged_graph.states[edge_id] - if len(edge_prop) != 1: - return False - other_particle = next(iter(edge_prop)) - if other_particle != graph.states[edge_id]: - return False - return True - - particle_graphs = _get_particle_graphs(graphs) - inventory: List[ - MutableTransition[ParticleCollection, InteractionProperties] - ] = [] - for graph in particle_graphs: - append_to_inventory = True - for merged_graph in inventory: - if is_same_shape(graph, merged_graph): - merge_into(graph, merged_graph) - append_to_inventory = False - break - if append_to_inventory: - new_states = { - edge_id: ParticleCollection({graph.states[edge_id]}) - for edge_id in graph.topology.edges - } - inventory.append( - MutableTransition[ParticleCollection, InteractionProperties]( - topology=graph.topology, - interactions={ - i: graph.interactions[i] for i in graph.topology.nodes - }, - states=new_states, - ) + graphs: Iterable[Transition[ParticleWithSpin, InteractionProperties]], +) -> Tuple[FrozenTransition[Tuple[Particle, ...], None], ...]: + transition_groups = { + g.topology: MutableTransition[Set[Particle], None]( + g.topology, + states={i: set() for i in g.topology.edges}, + interactions={i: None for i in g.topology.nodes}, + ) + for g in graphs + } + for transition in graphs: + topology = transition.topology + group = transition_groups[topology] + for state_id, state in transition.states.items(): + particle, _ = state + group.states[state_id].add(particle) + particle_collection_graphs = [] + for topology in sorted(transition_groups): + group = transition_groups[topology] + particle_collection_graphs.append( + FrozenTransition[Tuple[Particle, ...], None]( + topology, + states={ + i: tuple(sorted(particles, key=lambda p: p.name)) + for i, particles in group.states.items() + }, + interactions=group.interactions, ) - return inventory + ) + return tuple(particle_collection_graphs) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 6c3b288b..da738f92 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -43,9 +43,9 @@ from qrules._implementers import implement_pretty_repr if sys.version_info >= (3, 8): - from typing import Protocol + from typing import Protocol, runtime_checkable else: - from typing_extensions import Protocol + from typing_extensions import Protocol, runtime_checkable if TYPE_CHECKING: try: @@ -655,6 +655,13 @@ def _attach_node_to_edges( """A `~typing.TypeVar` representing the type of node properties.""" +@runtime_checkable +class Transition(Protocol[EdgeType, NodeType]): + topology: Topology + states: Dict[int, EdgeType] + interactions: Dict[int, NodeType] + + @implement_pretty_repr @frozen(order=True) class FrozenTransition(Generic[EdgeType, NodeType]): diff --git a/tests/unit/io/test_dot.py b/tests/unit/io/test_dot.py index aca343ca..313500db 100644 --- a/tests/unit/io/test_dot.py +++ b/tests/unit/io/test_dot.py @@ -9,7 +9,7 @@ _get_particle_graphs, _strip_projections, ) -from qrules.particle import ParticleCollection +from qrules.particle import Particle, ParticleCollection from qrules.topology import ( Edge, Topology, @@ -210,13 +210,15 @@ def test_collapse_graphs( pdg = particle_database particle_graphs = _get_particle_graphs(reaction.to_graphs()) assert len(particle_graphs) == 2 + collapsed_graphs = _collapse_graphs(reaction.to_graphs()) assert len(collapsed_graphs) == 1 graph = next(iter(collapsed_graphs)) edge_id = next(iter(graph.topology.intermediate_edge_ids)) f_resonances = pdg.filter(lambda p: p.name in ["f(0)(980)", "f(0)(1500)"]) intermediate_states = graph.states[edge_id] - assert isinstance(intermediate_states, ParticleCollection) + assert isinstance(intermediate_states, tuple) + assert all(map(lambda i: isinstance(i, Particle), intermediate_states)) assert intermediate_states == f_resonances From 1e9c2c73b389af55c1523ee598df83b9656cbe22 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:35 +0100 Subject: [PATCH 15/34] fix: add default values for MutableTransition --- src/qrules/topology.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index da738f92..71d4ce2e 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -696,8 +696,10 @@ class MutableTransition(Generic[EdgeType, NodeType]): """ topology: Topology = field(validator=instance_of(Topology)) - states: Dict[int, EdgeType] = field(converter=_cast_states) - interactions: Dict[int, NodeType] = field(converter=_cast_interactions) + states: Dict[int, EdgeType] = field(converter=_cast_states, factory=dict) + interactions: Dict[int, NodeType] = field( + converter=_cast_interactions, factory=dict + ) def compare( self, From 17af25f7773de989c9a30f747c86db1837153659 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:35 +0100 Subject: [PATCH 16/34] refactor: remove GraphSettings, GraphElementProperties, QuantumNumberSolution, InitialFacts --- docs/usage/reaction.ipynb | 8 ++-- src/qrules/__init__.py | 1 + src/qrules/combinatorics.py | 14 ++----- src/qrules/io/_dot.py | 59 ++++++----------------------- src/qrules/solving.py | 74 ++++++++++++++++--------------------- src/qrules/transition.py | 7 ++-- 6 files changed, 54 insertions(+), 109 deletions(-) diff --git a/docs/usage/reaction.ipynb b/docs/usage/reaction.ipynb index 2b1b026e..7bf0d4fc 100644 --- a/docs/usage/reaction.ipynb +++ b/docs/usage/reaction.ipynb @@ -179,7 +179,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Create all {class}`.ProblemSet`'s using the boundary conditions of the {class}`.StateTransitionManager` instance. By default it uses the **isobar model** (tree of two-body decays) to build {class}`.Topology`'s. Various {class}`.InitialFacts` are created for each topology based on the initial and final state. Lastly some reasonable default settings for the solving process are chosen. Remember that each interaction node defines its own set of conservation laws." + "Create all {class}`.ProblemSet`'s using the boundary conditions of the {class}`.StateTransitionManager` instance. By default it uses the **isobar model** (tree of two-body decays) to build {class}`.Topology`'s. Various {obj}`.InitialFacts` are created for each topology based on the initial and final state. Lastly some reasonable default settings for the solving process are chosen. Remember that each interaction node defines its own set of conservation laws." ] }, { @@ -247,10 +247,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Each {class}`.ProblemSet` provides a mapping of {attr}`~.ProblemSet.initial_facts` that represent the initial and final states with spin projections. The nodes and edges in between these {attr}`~.ProblemSet.initial_facts` are still to be generated. This will be done from the provided {attr}`~.ProblemSet.solving_settings` ({class}`~.GraphSettings`). There are two mechanisms there:\n", + "Each {class}`.ProblemSet` provides a mapping of {attr}`~.ProblemSet.initial_facts` that represent the initial and final states with spin projections. The nodes and edges in between these {attr}`~.ProblemSet.initial_facts` are still to be generated. This will be done from the provided {attr}`~.ProblemSet.solving_settings`. There are two mechanisms there:\n", "\n", - "1. One the one hand, the {attr}`.EdgeSettings.qn_domains` and {attr}`.NodeSettings.qn_domains` contained in the {class}`~.GraphSettings` define the **domain** over which quantum number sets can be generated.\n", - "2. On the other, the {attr}`.EdgeSettings.rule_priorities` and {attr}`.NodeSettings.rule_priorities` in {class}`~.GraphSettings` define which **{mod}`.conservation_rules`** are used to determine which of the sets of generated quantum numbers are valid.\n", + "1. One the one hand, the {attr}`.EdgeSettings.qn_domains` and {attr}`.NodeSettings.qn_domains` contained in the {attr}`~.ProblemSet.solving_settings` define the **domain** over which quantum number sets can be generated.\n", + "2. On the other, the {attr}`.EdgeSettings.rule_priorities` and {attr}`.NodeSettings.rule_priorities` in {attr}`~.ProblemSet.solving_settings` define which **{mod}`.conservation_rules`** are used to determine which of the sets of generated quantum numbers are valid.\n", "\n", "Together, these two constraints allow the {class}`.StateTransitionManager` to generate a number of {class}`.MutableTransition`s that comply with the selected {mod}`.conservation_rules`." ] diff --git a/src/qrules/__init__.py b/src/qrules/__init__.py index 478a3d02..4a56790c 100644 --- a/src/qrules/__init__.py +++ b/src/qrules/__init__.py @@ -140,6 +140,7 @@ def _check_violations( topology=topology, initial_facts=facts, solving_settings=GraphSettings( + facts.topology, interactions={ i: NodeSettings(conservation_rules=rules) for i, rules in node_rules.items() diff --git a/src/qrules/combinatorics.py b/src/qrules/combinatorics.py index e210e973..16fc4945 100644 --- a/src/qrules/combinatorics.py +++ b/src/qrules/combinatorics.py @@ -24,9 +24,6 @@ Union, ) -from attrs import field, frozen - -from qrules._implementers import implement_pretty_repr from qrules.particle import Particle, ParticleCollection from .particle import ParticleWithSpin @@ -35,13 +32,8 @@ StateWithSpins = Tuple[str, Sequence[float]] StateDefinition = Union[str, StateWithSpins] - - -@implement_pretty_repr -@frozen -class InitialFacts: - states: Dict[int, ParticleWithSpin] = field(factory=dict) - interactions: Dict[int, InteractionProperties] = field(factory=dict) +InitialFacts = MutableTransition[ParticleWithSpin, InteractionProperties] +"""A `.Transition` with only initial and final state information.""" class _KinematicRepresentation: @@ -268,7 +260,7 @@ def embed_in_list(some_list: List[Any]) -> List[List[Any]]: kinematic_permutation, particle_db ) edge_initial_facts.extend( - [InitialFacts(states=x) for x in spin_permutations] + [InitialFacts(topology, states=x) for x in spin_permutations] ) return edge_initial_facts diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index a587e42b..2b06f657 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -21,10 +21,9 @@ import attrs -from qrules.combinatorics import InitialFacts from qrules.particle import Particle, ParticleWithSpin from qrules.quantum_numbers import InteractionProperties, _to_fraction -from qrules.solving import EdgeSettings, GraphSettings, NodeSettings +from qrules.solving import EdgeSettings, NodeSettings from qrules.topology import ( FrozenTransition, MutableTransition, @@ -215,14 +214,7 @@ def graph_to_dot( def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals,too-many-statements - graph: Union[ - ProblemSet, - StateTransition, - Topology, - Transition, - Tuple[Topology, GraphSettings], - Tuple[Topology, InitialFacts], - ], + graph: Union[ProblemSet, StateTransition, Topology, Transition], prefix: str = "", *, render_node: bool, @@ -235,18 +227,8 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals dot = "" if isinstance(graph, tuple) and len(graph) == 2: topology: Topology = graph[0] - rendered_graph: Union[ - GraphSettings, - InitialFacts, - ProblemSet, - StateTransition, - Topology, - Transition, - ] = graph[1] - elif isinstance(graph, ProblemSet): - rendered_graph = graph - topology = graph.topology - elif isinstance(graph, (StateTransition, Transition)): + rendered_graph: Union[ProblemSet, Topology, Transition] = graph[1] + elif isinstance(graph, (ProblemSet, Transition)): rendered_graph = graph topology = graph.topology elif isinstance(graph, Topology): @@ -338,23 +320,14 @@ def __rank_string(node_edge_ids: Iterable[int], prefix: str = "") -> str: def __get_edge_label( - graph: Union[ - GraphSettings, - InitialFacts, - ProblemSet, - StateTransition, - Topology, - Transition, - ], + graph: Union[ProblemSet, Topology, Transition], edge_id: int, render_edge_id: bool, ) -> str: - if isinstance(graph, GraphSettings): - edge_setting = graph.states.get(edge_id) - return ___render_edge_with_id(edge_id, edge_setting, render_edge_id) - if isinstance(graph, InitialFacts): - initial_fact = graph.states.get(edge_id) - return ___render_edge_with_id(edge_id, initial_fact, render_edge_id) + if isinstance(graph, Topology): + if render_edge_id: + return str(edge_id) + return "" if isinstance(graph, ProblemSet): edge_setting = graph.solving_settings.states.get(edge_id) initial_fact = graph.initial_facts.states.get(edge_id) @@ -364,18 +337,8 @@ def __get_edge_label( if initial_fact: edge_property = initial_fact return ___render_edge_with_id(edge_id, edge_property, render_edge_id) - if isinstance(graph, StateTransition): - graph = graph.to_graph() - if isinstance(graph, Transition): - edge_prop = graph.states[edge_id] - return ___render_edge_with_id(edge_id, edge_prop, render_edge_id) - if isinstance(graph, Topology): - if render_edge_id: - return str(edge_id) - return "" - raise NotImplementedError( - f"Cannot render {graph.__class__.__name__} as dot" - ) + edge_prop = graph.states.get(edge_id) + return ___render_edge_with_id(edge_id, edge_prop, render_edge_id) def ___render_edge_with_id( diff --git a/src/qrules/solving.py b/src/qrules/solving.py index ee0cf70d..815c9142 100644 --- a/src/qrules/solving.py +++ b/src/qrules/solving.py @@ -2,10 +2,10 @@ """Functions to solve a particle reaction problem. This module is responsible for solving a particle reaction problem stated by a -`.MutableTransition` and corresponding `.GraphSettings`. The `.Solver` classes -(e.g. :class:`.CSPSolver`) generate new quantum numbers (for example belonging -to an intermediate state) and validate the decay processes with the rules -formulated by the :mod:`.conservation_rules` module. +`.QNProblemSet`. The `.Solver` classes (e.g. :class:`.CSPSolver`) generate new +quantum numbers (for example belonging to an intermediate state) and validate +the decay processes with the rules formulated by the :mod:`.conservation_rules` +module. """ @@ -55,7 +55,7 @@ EdgeQuantumNumbers, NodeQuantumNumber, ) -from .topology import Topology +from .topology import MutableTransition, Topology @implement_pretty_repr @@ -89,18 +89,10 @@ class NodeSettings: interaction_strength: float = 1.0 -@implement_pretty_repr -@define -class GraphSettings: - states: Dict[int, EdgeSettings] = field(factory=dict) - interactions: Dict[int, NodeSettings] = field(factory=dict) - - -@implement_pretty_repr -@define -class GraphElementProperties: - states: Dict[int, GraphEdgePropertyMap] = field(factory=dict) - interactions: Dict[int, GraphNodePropertyMap] = field(factory=dict) +GraphSettings = MutableTransition[EdgeSettings, NodeSettings] +GraphElementProperties = MutableTransition[ + GraphEdgePropertyMap, GraphNodePropertyMap +] @implement_pretty_repr @@ -109,25 +101,22 @@ class QNProblemSet: """Particle reaction problem set, defined as a graph like data structure. Args: - topology (`.Topology`): a topology that represent the structure of the - reaction - initial_facts (`.GraphElementProperties`): all of the known facts quantum - numbers of the problem - solving_settings (`.GraphSettings`): solving specific settings such as - the specific rules and variable domains for nodes and edges of the - topology + initial_facts: all of the known facts quantum numbers of the problem. + solving_settings: solving specific settings, such as the specific rules + and variable domains for nodes and edges of the :attr:`topology`. """ - topology: Topology initial_facts: GraphElementProperties solving_settings: GraphSettings + @property + def topology(self) -> Topology: + return self.initial_facts.topology + -@implement_pretty_repr -@frozen -class QuantumNumberSolution: - states: Dict[int, GraphEdgePropertyMap] - interactions: Dict[int, GraphNodePropertyMap] +QuantumNumberSolution = MutableTransition[ + GraphEdgePropertyMap, GraphNodePropertyMap +] def _convert_violated_rules_to_names( @@ -358,10 +347,9 @@ def _create_edge_variables( def _create_variable_containers( node_id: int, cons_law: Rule ) -> Tuple[List[dict], List[dict], dict]: - in_edges = problem_set.topology.get_edge_ids_ingoing_to_node(node_id) - out_edges = problem_set.topology.get_edge_ids_outgoing_from_node( - node_id - ) + topology = problem_set.topology + in_edges = topology.get_edge_ids_ingoing_to_node(node_id) + out_edges = topology.get_edge_ids_outgoing_from_node(node_id) edge_qns, node_qns = get_required_qns(cons_law) in_edges_vars = _create_edge_variables(in_edges, edge_qns) @@ -444,6 +432,7 @@ def _create_variable_containers( return QNResult( [ QuantumNumberSolution( + topology=problem_set.topology, states=problem_set.initial_facts.states, interactions=problem_set.initial_facts.interactions, ) @@ -534,7 +523,9 @@ def find_solutions(self, problem_set: QNProblemSet) -> QNResult: elif self.__scoresheet.rule_passes[(edge_id, rule)] == 0: edge_not_satisfied_rules[edge_id].add(rule) - solutions = self.__convert_solution_keys(solutions) + solutions = self.__convert_solution_keys( + problem_set.topology, solutions + ) # insert particle instances if self.__node_rules or self.__edge_rules: @@ -548,6 +539,7 @@ def find_solutions(self, problem_set: QNProblemSet) -> QNResult: else: full_particle_solutions = [ QuantumNumberSolution( + topology=problem_set.topology, interactions=problem_set.initial_facts.interactions, states=problem_set.initial_facts.states, ) @@ -558,6 +550,7 @@ def find_solutions(self, problem_set: QNProblemSet) -> QNResult: ): # rerun solver on these graphs using not executed rules # and combine results + topology = problem_set.topology result = QNResult() for full_particle_solution in full_particle_solutions: interactions = full_particle_solution.interactions @@ -567,12 +560,11 @@ def find_solutions(self, problem_set: QNProblemSet) -> QNResult: result.extend( validate_full_solution( QNProblemSet( - topology=problem_set.topology, initial_facts=GraphElementProperties( - interactions=interactions, - states=states, + topology, states, interactions ), solving_settings=GraphSettings( + topology, interactions={ i: NodeSettings(conservation_rules=rules) for i, rules in node_not_executed_rules.items() @@ -822,8 +814,7 @@ def __add_variable( self.__problem.addVariable(var_string, domain) def __convert_solution_keys( - self, - solutions: List[Dict[str, Scalar]], + self, topology: Topology, solutions: List[Dict[str, Scalar]] ) -> List[QuantumNumberSolution]: """Convert keys of CSP solutions from `str` to quantum number types.""" converted_solutions = [] @@ -840,9 +831,8 @@ def __convert_solution_keys( else: interactions[ele_id].update({qn_type: value}) # type: ignore[dict-item] converted_solutions.append( - QuantumNumberSolution(states, interactions) + QuantumNumberSolution(topology, states, interactions) ) - return converted_solutions diff --git a/src/qrules/transition.py b/src/qrules/transition.py index 4a1e5891..9e4f0328 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -173,7 +173,7 @@ class ProblemSet: Args: topology: `.Topology` that contains the structure of the reaction. - initial_facts: `~.InitialFacts` that contain the info of initial and + initial_facts: `.InitialFacts` that contain the info of initial and final state in connection with the topology. solving_settings: Solving related settings such as the conservation rules and the quantum number domains. @@ -193,9 +193,8 @@ def to_qn_problem_set(self) -> QNProblemSet: for k, v in self.initial_facts.states.items() } return QNProblemSet( - topology=self.topology, initial_facts=GraphElementProperties( - interactions=interactions, states=states + self.topology, states, interactions ), solving_settings=self.solving_settings, ) @@ -464,11 +463,11 @@ def create_edge_settings(edge_id: int) -> EdgeSettings: graph_settings: List[GraphSettings] = [ GraphSettings( + topology, states={ edge_id: create_edge_settings(edge_id) for edge_id in topology.edges }, - interactions={}, ) ] From a45c2f967d25d2bc6804b6532714f3474bb64124 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:36 +0100 Subject: [PATCH 17/34] feat: implement initial_states etc in FrozenTransition --- src/qrules/topology.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 71d4ce2e..66b4863f 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -675,6 +675,21 @@ def __attrs_post_init__(self) -> None: _assert_all_defined(self.topology.nodes, self.interactions) _assert_all_defined(self.topology.edges, self.states) + @property + def initial_states(self) -> Dict[int, EdgeType]: + return self.filter_states(self.topology.incoming_edge_ids) + + @property + def final_states(self) -> Dict[int, EdgeType]: + return self.filter_states(self.topology.outgoing_edge_ids) + + @property + def intermediate_states(self) -> Dict[int, EdgeType]: + return self.filter_states(self.topology.intermediate_edge_ids) + + def filter_states(self, edge_ids: Iterable[int]) -> Dict[int, EdgeType]: + return {i: self.states[i] for i in edge_ids} + def _cast_states(obj: Mapping[int, EdgeType]) -> Dict[int, EdgeType]: return dict(obj) From fb68b72510bf79ec498331b4bb56e8d647d7f735 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:36 +0100 Subject: [PATCH 18/34] refactor: remove StateTransition class --- .flake8 | 3 + docs/conf.py | 5 +- docs/usage/conservation.ipynb | 10 +- docs/usage/visualize.ipynb | 2 +- src/qrules/io/__init__.py | 16 ++- src/qrules/io/_dict.py | 90 +++++++-------- src/qrules/io/_dot.py | 31 +++--- src/qrules/topology.py | 38 ++++++- src/qrules/transition.py | 82 +++----------- tests/channels/test_jpsi_to_gamma_pi0_pi0.py | 8 +- tests/unit/io/test_dot.py | 6 +- tests/unit/io/test_io.py | 7 +- tests/unit/test_transition.py | 109 +------------------ 13 files changed, 138 insertions(+), 269 deletions(-) diff --git a/.flake8 b/.flake8 index b138b130..051d7edf 100644 --- a/.flake8 +++ b/.flake8 @@ -32,6 +32,9 @@ ignore = W503 extend-select = TI100 +per-file-ignores = + # casts with generics + src/qrules/topology.py:E731 rst-roles = attr cite diff --git a/docs/conf.py b/docs/conf.py index bd356814..863348e4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -221,10 +221,9 @@ def fetch_logo(url: str, output_path: str) -> None: primary_domain = "py" nitpicky = True # warn if cross-references are missing nitpick_ignore = [ - ("py:class", "EdgeType"), + ("py:class", "qrules.topology.NewEdgeType"), + ("py:class", "qrules.topology.NewNodeType"), ("py:class", "NoneType"), - ("py:class", "MutableTransition"), - ("py:class", "ValueType"), ("py:class", "json.encoder.JSONEncoder"), ("py:class", "typing_extensions.Protocol"), ("py:obj", "qrules.topology._K"), diff --git a/docs/usage/conservation.ipynb b/docs/usage/conservation.ipynb index ec709090..d30d9898 100644 --- a/docs/usage/conservation.ipynb +++ b/docs/usage/conservation.ipynb @@ -215,14 +215,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Example with a {class}`.StateTransition`" + "## Example with a {obj}`.StateTransition`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "First, generate some {class}`.StateTransition`s with {func}`.generate_transitions`, then select one of them:" + "First, generate some {obj}`.StateTransition`s with {func}`.generate_transitions`, then select one of them:" ] }, { @@ -361,14 +361,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Modifying {class}`.StateTransition`s" + "### Modifying {obj}`.StateTransition`s" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "When checking conservation rules, you may want to modify the properties on the {class}`.StateTransition`s. However, a {class}`.StateTransition` is frozen, so it is not possible to modify its {attr}`~.StateTransition.interactions` and {attr}`~.StateTransition.states`. The only way around this is to create a new instance with {func}`attrs.evolve`.\n", + "When checking conservation rules, you may want to modify the properties on the {obj}`.StateTransition`s. However, a {obj}`.StateTransition` is a {class}`.FrozenTransition`, so it is not possible to modify its {attr}`~.FrozenTransition.interactions` and {attr}`~.FrozenTransition.states`. The only way around this is to create a new instance with {func}`attrs.evolve`.\n", "\n", "First, we get the instance (in this case one of the {class}`.InteractionProperties`) and substitute its {attr}`.InteractionProperties.l_magnitude`:" ] @@ -387,7 +387,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We then again use {func}`attrs.evolve` to substitute the {attr}`.StateTransition.interactions` of the original {class}`.StateTransition`:" + "We then again use {func}`attrs.evolve` to substitute the {attr}`.Transition.interactions` of the original {obj}`.StateTransition`:" ] }, { diff --git a/docs/usage/visualize.ipynb b/docs/usage/visualize.ipynb index 061aa1ff..0dc90243 100644 --- a/docs/usage/visualize.ipynb +++ b/docs/usage/visualize.ipynb @@ -255,7 +255,7 @@ "tags": [] }, "source": [ - "## {class}`.StateTransition`s" + "## {obj}`.StateTransition`s" ] }, { diff --git a/src/qrules/io/__init__.py b/src/qrules/io/__init__.py index 4f46a4fb..f9fb6d9b 100644 --- a/src/qrules/io/__init__.py +++ b/src/qrules/io/__init__.py @@ -15,8 +15,8 @@ import yaml from qrules.particle import Particle, ParticleCollection -from qrules.topology import MutableTransition, Topology -from qrules.transition import ProblemSet, ReactionInfo, State, StateTransition +from qrules.topology import Topology, Transition +from qrules.transition import ProblemSet, ReactionInfo, State from . import _dict, _dot @@ -27,15 +27,15 @@ def asdict(instance: object) -> dict: return _dict.from_particle(instance) if isinstance(instance, ParticleCollection): return _dict.from_particle_collection(instance) - if isinstance(instance, (ReactionInfo, State, StateTransition)): + if isinstance(instance, (ReactionInfo, State)): return attrs.asdict( instance, recurse=True, filter=lambda a, _: a.init, value_serializer=_dict._value_serializer, ) - if isinstance(instance, MutableTransition): - return _dict.from_stg(instance) + if isinstance(instance, Transition): + return _dict.from_transition(instance) if isinstance(instance, Topology): return _dict.from_topology(instance) raise NotImplementedError( @@ -53,7 +53,7 @@ def fromdict(definition: dict) -> object: if keys == {"transitions", "formalism"}: return _dict.build_reaction_info(definition) if keys == {"topology", "states", "interactions"}: - return _dict.build_state_transition(definition) + return _dict.build_transition(definition) if keys == __REQUIRED_TOPOLOGY_FIELDS: return _dict.build_topology(definition) raise NotImplementedError(f"Could not determine type from keys {keys}") @@ -127,9 +127,7 @@ def asdot( edge_style = {} if node_style is None: node_style = {} - if isinstance(instance, StateTransition): - instance = instance.to_graph() - if isinstance(instance, (ProblemSet, MutableTransition, Topology)): + if isinstance(instance, (ProblemSet, Topology, Transition)): dot = _dot.graph_to_dot( instance, render_node=render_node, diff --git a/src/qrules/io/_dict.py b/src/qrules/io/_dict.py index 7d06674b..60a2373b 100644 --- a/src/qrules/io/_dict.py +++ b/src/qrules/io/_dict.py @@ -8,16 +8,10 @@ import attrs -from qrules.particle import ( - Parity, - Particle, - ParticleCollection, - ParticleWithSpin, - Spin, -) +from qrules.particle import Parity, Particle, ParticleCollection, Spin from qrules.quantum_numbers import InteractionProperties -from qrules.topology import Edge, MutableTransition, Topology -from qrules.transition import ReactionInfo, State, StateTransition +from qrules.topology import Edge, FrozenTransition, Topology, Transition +from qrules.transition import ReactionInfo, State def from_particle_collection(particles: ParticleCollection) -> dict: @@ -33,30 +27,13 @@ def from_particle(particle: Particle) -> dict: ) -def from_stg( - graph: MutableTransition[ParticleWithSpin, InteractionProperties] -) -> dict: - topology = graph.topology - states_def = {} - for i in topology.edges: - particle, spin_projection = graph.states[i] - if isinstance(spin_projection, float) and spin_projection.is_integer(): - spin_projection = int(spin_projection) - states_def[i] = { - "particle": from_particle(particle), - "spin_projection": spin_projection, - } - interactions_def = {} - for i in topology.nodes: - node_prop = graph.interactions[i] - interactions_def[i] = attrs.asdict( - node_prop, filter=lambda a, v: a.init and a.default != v - ) - return { - "topology": from_topology(topology), - "states": states_def, - "interactions": interactions_def, - } +def from_transition(graph: Transition) -> dict: + return attrs.asdict( + graph, + recurse=True, + value_serializer=_value_serializer, + filter=lambda attribute, value: attribute.default != value, + ) def from_topology(topology: Topology) -> dict: @@ -75,10 +52,9 @@ def _value_serializer( # pylint: disable=unused-argument if all(map(lambda p: isinstance(p, Particle), value.values())): return {k: v.name for k, v in value.items()} return dict(value) - if not isinstance( - inst, (ReactionInfo, State, StateTransition) - ) and isinstance(value, Particle): - return value.name + if not isinstance(inst, (ReactionInfo, State, FrozenTransition)): + if isinstance(value, Particle): + return value.name if isinstance(value, Parity): return {"value": value.value} if isinstance(value, Spin): @@ -112,30 +88,38 @@ def build_particle(definition: dict) -> Particle: def build_reaction_info(definition: dict) -> ReactionInfo: transitions = [ - build_state_transition(transition_def) + build_transition(transition_def) for transition_def in definition["transitions"] ] return ReactionInfo(transitions, formalism=definition["formalism"]) -def build_state_transition(definition: dict) -> StateTransition: +def build_transition( + definition: dict, +) -> FrozenTransition[State, InteractionProperties]: topology = build_topology(definition["topology"]) - states = { - int(i): State( - particle=build_particle(state_def["particle"]), - spin_projection=float(state_def["spin_projection"]), - ) - for i, state_def in definition["states"].items() - } + states_def: Dict[int, dict] = definition["states"] + states: Dict[int, State] = {} + for i, edge_def in states_def.items(): + states[int(i)] = build_state(edge_def) + interactions_def: Dict[int, dict] = definition["interactions"] interactions = { - int(i): InteractionProperties(**interaction_def) - for i, interaction_def in definition["interactions"].items() + int(i): InteractionProperties(**node_def) + for i, node_def in interactions_def.items() } - return StateTransition( - topology=topology, - states=states, - interactions=interactions, - ) + return FrozenTransition(topology, states, interactions) + + +def build_state(definition: Any) -> State: + if isinstance(definition, (list, tuple)) and len(definition) == 2: + particle = build_particle(definition[0]) + spin_projection = float(definition[1]) + return State(particle, spin_projection) + if isinstance(definition, dict): + particle = build_particle(definition["particle"]) + spin_projection = float(definition["spin_projection"]) + return State(particle, spin_projection) + raise NotImplementedError() def build_topology(definition: dict) -> Topology: diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index 2b06f657..ee605ec8 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -12,7 +12,6 @@ Dict, Iterable, List, - Mapping, Optional, Set, Tuple, @@ -30,7 +29,7 @@ Topology, Transition, ) -from qrules.transition import ProblemSet, StateTransition +from qrules.transition import ProblemSet, State, StateTransition _DOT_HEAD = """digraph { rankdir=LR; @@ -166,8 +165,10 @@ def graph_list_to_dot( if render_node: stripped_graphs = [] for graph in graphs: - if isinstance(graph, StateTransition): - graph = graph.to_graph() + if isinstance(graph, FrozenTransition): + graph = graph.convert( + lambda s: (s.particle, s.spin_projection) + ) stripped_graph = _strip_projections(graph) if stripped_graph not in stripped_graphs: stripped_graphs.append(stripped_graph) @@ -277,14 +278,8 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals label=node_label, graphviz_attrs=node_style, ) - if isinstance(graph, (StateTransition, Transition)): - if isinstance(graph, StateTransition): - interactions: Mapping[ - int, InteractionProperties - ] = graph.interactions - else: - interactions = {i: graph.interactions[i] for i in topology.nodes} - for node_id, node_prop in interactions.items(): + if isinstance(graph, Transition): + for node_id, node_prop in graph.interactions.items(): node_label = "" if render_node: node_label = __node_label(node_prop) @@ -371,6 +366,8 @@ def __render_edge_property( map(lambda i: isinstance(i, Particle), edge_prop) ): return "\n".join(map(lambda p: p.name, edge_prop)) + if isinstance(edge_prop, State): + edge_prop = edge_prop.particle, edge_prop.spin_projection if isinstance(edge_prop, tuple) and len(edge_prop) == 2: particle, spin_projection = edge_prop projection_label = _to_fraction(spin_projection, render_plus=True) @@ -448,8 +445,10 @@ def _get_particle_graphs( """ inventory = set() for transition in graphs: - if isinstance(transition, StateTransition): - transition = transition.to_graph() + if isinstance(transition, FrozenTransition): + transition = transition.convert( + lambda s: (s.particle, s.spin_projection) + ) stripped_transition = _strip_projections(transition) topology = stripped_transition.topology particle_transition: FrozenTransition[ @@ -471,8 +470,8 @@ def _get_particle_graphs( def _strip_projections( graph: Transition[ParticleWithSpin, InteractionProperties], ) -> FrozenTransition[Particle, InteractionProperties]: - if isinstance(graph, StateTransition): - graph = graph.to_graph() + if isinstance(graph, FrozenTransition): + graph = graph.convert(lambda s: (s.particle, s.spin_projection)) return FrozenTransition( graph.topology, states={i: particle for i, (particle, _) in graph.states.items()}, diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 66b4863f..bada2177 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -34,6 +34,7 @@ Tuple, TypeVar, ValuesView, + cast, ) import attrs @@ -658,8 +659,12 @@ def _attach_node_to_edges( @runtime_checkable class Transition(Protocol[EdgeType, NodeType]): topology: Topology - states: Dict[int, EdgeType] - interactions: Dict[int, NodeType] + states: Mapping[int, EdgeType] + interactions: Mapping[int, NodeType] + + +NewEdgeType = TypeVar("NewEdgeType") +NewNodeType = TypeVar("NewNodeType") @implement_pretty_repr @@ -690,6 +695,32 @@ def intermediate_states(self) -> Dict[int, EdgeType]: def filter_states(self, edge_ids: Iterable[int]) -> Dict[int, EdgeType]: return {i: self.states[i] for i in edge_ids} + def convert( + self, + state_converter: Optional[Callable[[EdgeType], NewEdgeType]] = None, + interaction_converter: Optional[ + Callable[[NodeType], NewNodeType] + ] = None, + ) -> "FrozenTransition[NewEdgeType, NodeType]": + # pylint: disable=unnecessary-lambda + if state_converter is None: + state_converter = lambda _: cast(NewEdgeType, _) + if interaction_converter is None: + interaction_converter = lambda _: cast(NewNodeType, _) + return FrozenTransition[NewEdgeType, NodeType]( + self.topology, + states={ + i: state_converter(state) for i, state in self.states.items() + }, + interactions={ + i: interaction_converter(interaction) + for i, interaction in self.interactions.items() + }, + ) + + def unfreeze(self) -> "MutableTransition[EdgeType, NodeType]": + return MutableTransition(self.topology, self.states, self.interactions) + def _cast_states(obj: Mapping[int, EdgeType]) -> Dict[int, EdgeType]: return dict(obj) @@ -753,6 +784,9 @@ def swap_edges(self, edge_id1: int, edge_id2: int) -> None: if value2 is not None: self.states[edge_id1] = value2 + def freeze(self) -> FrozenTransition[EdgeType, NodeType]: + return FrozenTransition(self.topology, self.states, self.interactions) + def _assert_all_defined(items: Iterable, properties: Iterable) -> None: existing = set(items) diff --git a/src/qrules/transition.py b/src/qrules/transition.py index 9e4f0328..ece03e56 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -72,9 +72,9 @@ ) from .topology import ( FrozenDict, + FrozenTransition, MutableTransition, Topology, - _assert_all_defined, create_isobar_topologies, create_n_body_topology, ) @@ -728,74 +728,13 @@ class State: spin_projection: float = field(converter=_to_float) -@implement_pretty_repr -@frozen(order=True) -class StateTransition: - """Frozen instance of a `.MutableTransition` of a particle with spin.""" - - topology: Topology = field(validator=instance_of(Topology)) - states: FrozenDict[int, State] = field(converter=FrozenDict) - interactions: FrozenDict[int, InteractionProperties] = field( - converter=FrozenDict - ) - - def __attrs_post_init__(self) -> None: - _assert_all_defined(self.topology.edges, self.states) - _assert_all_defined(self.topology.nodes, self.interactions) - - @staticmethod - def from_graph( - graph: MutableTransition[ParticleWithSpin, InteractionProperties], - ) -> "StateTransition": - return StateTransition( - topology=graph.topology, - states=FrozenDict( - {i: State(*graph.states[i]) for i in graph.topology.edges} - ), - interactions=FrozenDict( - {i: graph.interactions[i] for i in graph.topology.nodes} - ), - ) - - def to_graph( - self, - ) -> MutableTransition[ParticleWithSpin, InteractionProperties]: - return MutableTransition[ParticleWithSpin, InteractionProperties]( - topology=self.topology, - states={ - i: (state.particle, state.spin_projection) - for i, state in self.states.items() - }, - interactions=self.interactions, - ) - - @property - def initial_states(self) -> Dict[int, State]: - return self.filter_states(self.topology.incoming_edge_ids) - - @property - def final_states(self) -> Dict[int, State]: - return self.filter_states(self.topology.outgoing_edge_ids) - - @property - def intermediate_states(self) -> Dict[int, State]: - return self.filter_states(self.topology.intermediate_edge_ids) - - def filter_states(self, edge_ids: Iterable[int]) -> Dict[int, State]: - return {i: self.states[i] for i in edge_ids} - - @property - def particles(self) -> Dict[int, Particle]: - return {i: edge_prop.particle for i, edge_prop in self.states.items()} +StateTransition = FrozenTransition[State, InteractionProperties] +"""Transition of some initial `.State` to a final `.State`.""" def _sort_tuple( iterable: Iterable[StateTransition], ) -> Tuple[StateTransition, ...]: - if not all(map(lambda t: isinstance(t, StateTransition), iterable)): - raise TypeError( - f"Not all instances are of type {StateTransition.__name__}" - ) return tuple(sorted(iterable)) @@ -837,13 +776,24 @@ def from_graphs( ], formalism: str, ) -> "ReactionInfo": - transitions = [StateTransition.from_graph(g) for g in graphs] + transitions = [ + g.freeze().convert(state_converter=lambda state: State(*state)) + for g in graphs + ] return ReactionInfo(transitions, formalism) def to_graphs( self, ) -> List[MutableTransition[ParticleWithSpin, InteractionProperties]]: - return [transition.to_graph() for transition in self.transitions] + return [ + transition.convert( + state_converter=lambda state: ( + state.particle, + state.spin_projection, + ) + ).unfreeze() + for transition in self.transitions + ] def group_by_topology(self) -> Dict[Topology, List[StateTransition]]: groupings = defaultdict(list) diff --git a/tests/channels/test_jpsi_to_gamma_pi0_pi0.py b/tests/channels/test_jpsi_to_gamma_pi0_pi0.py index 54a8cbec..ef3726d9 100644 --- a/tests/channels/test_jpsi_to_gamma_pi0_pi0.py +++ b/tests/channels/test_jpsi_to_gamma_pi0_pi0.py @@ -60,7 +60,9 @@ def test_id_to_particle_mappings(particle_database): assert len(reaction.transitions) == 4 iter_transitions = iter(reaction.transitions) first_transition = next(iter_transitions) - graph = first_transition.to_graph() + graph = first_transition.convert( + lambda s: (s.particle, s.spin_projection) + ).unfreeze() ref_mapping_fs = _create_edge_id_particle_mapping( graph, graph.topology.outgoing_edge_ids ) @@ -68,7 +70,9 @@ def test_id_to_particle_mappings(particle_database): graph, graph.topology.incoming_edge_ids ) for transition in iter_transitions: - graph = transition.to_graph() + graph = transition.convert( + lambda s: (s.particle, s.spin_projection) + ).unfreeze() assert ref_mapping_fs == _create_edge_id_particle_mapping( graph, graph.topology.outgoing_edge_ids ) diff --git a/tests/unit/io/test_dot.py b/tests/unit/io/test_dot.py index 313500db..a47d72d5 100644 --- a/tests/unit/io/test_dot.py +++ b/tests/unit/io/test_dot.py @@ -208,10 +208,10 @@ def test_collapse_graphs( particle_database: ParticleCollection, ): pdg = particle_database - particle_graphs = _get_particle_graphs(reaction.to_graphs()) + particle_graphs = _get_particle_graphs(reaction.to_graphs()) # type: ignore[arg-type] assert len(particle_graphs) == 2 - collapsed_graphs = _collapse_graphs(reaction.to_graphs()) + collapsed_graphs = _collapse_graphs(reaction.to_graphs()) # type: ignore[arg-type] assert len(collapsed_graphs) == 1 graph = next(iter(collapsed_graphs)) edge_id = next(iter(graph.topology.intermediate_edge_ids)) @@ -226,7 +226,7 @@ def test_get_particle_graphs( reaction: ReactionInfo, particle_database: ParticleCollection ): pdg = particle_database - graphs = _get_particle_graphs(reaction.to_graphs()) + graphs = _get_particle_graphs(reaction.to_graphs()) # type: ignore[arg-type] assert len(graphs) == 2 assert graphs[0].states[3] == pdg["f(0)(980)"] assert graphs[1].states[3] == pdg["f(0)(1500)"] diff --git a/tests/unit/io/test_io.py b/tests/unit/io/test_io.py index f34b8b6e..58e0c282 100644 --- a/tests/unit/io/test_io.py +++ b/tests/unit/io/test_io.py @@ -5,11 +5,12 @@ from qrules import io from qrules.particle import Particle, ParticleCollection from qrules.topology import ( + FrozenTransition, Topology, create_isobar_topologies, create_n_body_topology, ) -from qrules.transition import ReactionInfo, StateTransition +from qrules.transition import ReactionInfo, State def through_dict(instance): @@ -46,8 +47,8 @@ def test_asdict_fromdict_reaction(reaction: ReactionInfo): # MutableTransition for graph in reaction.to_graphs(): fromdict = through_dict(graph) - assert isinstance(fromdict, StateTransition) - assert graph == fromdict.to_graph() + assert isinstance(fromdict, FrozenTransition) + assert graph.freeze().convert(lambda s: State(*s)) == fromdict # ReactionInfo fromdict = through_dict(reaction) assert isinstance(fromdict, ReactionInfo) diff --git a/tests/unit/test_transition.py b/tests/unit/test_transition.py index 55275b92..2a617ef8 100644 --- a/tests/unit/test_transition.py +++ b/tests/unit/test_transition.py @@ -1,8 +1,5 @@ # pyright: reportUnusedImport=false # pylint: disable=eval-used, no-self-use -from operator import itemgetter -from typing import List - import pytest from IPython.lib.pretty import pretty @@ -10,22 +7,17 @@ Parity, Particle, ParticleCollection, - ParticleWithSpin, Spin, ) from qrules.quantum_numbers import InteractionProperties # noqa: F401 from qrules.topology import ( # noqa: F401 Edge, FrozenDict, + FrozenTransition, MutableTransition, Topology, ) -from qrules.transition import State # noqa: F401 -from qrules.transition import ( - ReactionInfo, - StateTransition, - StateTransitionManager, -) +from qrules.transition import ReactionInfo, State, StateTransitionManager class TestReactionInfo: @@ -40,7 +32,7 @@ def test_properties(self, reaction: ReactionInfo): else: assert len(reaction.transitions) == 8 for transition in reaction.transitions: - assert isinstance(transition, StateTransition) + assert isinstance(transition, FrozenTransition) @pytest.mark.parametrize("repr_method", [repr, pretty]) def test_repr(self, repr_method, reaction: ReactionInfo): @@ -81,101 +73,6 @@ def create_state(state_def) -> State: assert state2 >= state1 -class TestStateTransition: - def test_ordering(self, reaction: ReactionInfo): - sorted_transitions: List[StateTransition] = sorted( - reaction.transitions - ) - if reaction.formalism.startswith("cano"): - first = sorted_transitions[0] - second = sorted_transitions[1] - assert first.interactions[0].l_magnitude == 0.0 - assert second.interactions[0].l_magnitude == 2.0 - assert first.interactions[1] == second.interactions[1] - transition_selection = sorted_transitions[::2] - else: - transition_selection = sorted_transitions - - simplified_rendering = [ - tuple( - ( - transition.states[state_id].particle.name, - int(transition.states[state_id].spin_projection), - ) - for state_id in sorted(transition.states) - ) - for transition in transition_selection - ] - - assert simplified_rendering[:3] == [ - ( - ("J/psi(1S)", -1), - ("gamma", -1), - ("pi0", 0), - ("pi0", 0), - ("f(0)(980)", 0), - ), - ( - ("J/psi(1S)", -1), - ("gamma", -1), - ("pi0", 0), - ("pi0", 0), - ("f(0)(1500)", 0), - ), - ( - ("J/psi(1S)", -1), - ("gamma", +1), - ("pi0", 0), - ("pi0", 0), - ("f(0)(980)", 0), - ), - ] - assert simplified_rendering[-1] == ( - ("J/psi(1S)", +1), - ("gamma", +1), - ("pi0", 0), - ("pi0", 0), - ("f(0)(1500)", 0), - ) - - # J/psi - first_half = slice(0, int(len(simplified_rendering) / 2)) - for item in simplified_rendering[first_half]: - assert item[0] == ("J/psi(1S)", -1) - second_half = slice(int(len(simplified_rendering) / 2), None) - for item in simplified_rendering[second_half]: - assert item[0] == ("J/psi(1S)", +1) - second_half = slice(int(len(simplified_rendering) / 2), None) - # gamma - for item in itemgetter(0, 1, 4, 5)(simplified_rendering): - assert item[1] == ("gamma", -1) - for item in itemgetter(2, 3, 6, 7)(simplified_rendering): - assert item[1] == ("gamma", +1) - # pi0 - for item in simplified_rendering: - assert item[2] == ("pi0", 0) - assert item[3] == ("pi0", 0) - # f0 - for item in simplified_rendering[::2]: - assert item[4] == ("f(0)(980)", 0) - for item in simplified_rendering[1::2]: - assert item[4] == ("f(0)(1500)", 0) - - @pytest.mark.parametrize("repr_method", [repr, pretty]) - def test_repr(self, repr_method, reaction: ReactionInfo): - for instance in reaction.transitions: - from_repr = eval(repr_method(instance)) - assert from_repr == instance - - def test_from_to_graph(self, reaction: ReactionInfo): - assert len(reaction.group_by_topology()) == 1 - assert len(reaction.transitions) in {8, 16} - for transition in reaction.transitions: - graph = transition.to_graph() - from_graph = StateTransition.from_graph(graph) - assert transition == from_graph - - class TestStateTransitionManager: def test_allowed_intermediate_particles(self): stm = StateTransitionManager( From f0cd48d5ef16b35422af3eab7cd20f5a2598ddcb Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:37 +0100 Subject: [PATCH 19/34] refactor: remove ReactionInfo.from/to_graphs() --- src/qrules/io/__init__.py | 2 +- src/qrules/io/_dot.py | 35 ++++++++++++++++++++++++----------- src/qrules/transition.py | 32 +++++--------------------------- tests/unit/io/test_dot.py | 6 +++--- tests/unit/io/test_io.py | 8 ++++---- tests/unit/test_transition.py | 11 +++-------- 6 files changed, 40 insertions(+), 54 deletions(-) diff --git a/src/qrules/io/__init__.py b/src/qrules/io/__init__.py index f9fb6d9b..aa45d5a0 100644 --- a/src/qrules/io/__init__.py +++ b/src/qrules/io/__init__.py @@ -139,7 +139,7 @@ def asdot( ) return _dot.insert_graphviz_styling(dot, graphviz_attrs=figure_style) if isinstance(instance, ReactionInfo): - instance = instance.to_graphs() + instance = instance.transitions if isinstance(instance, abc.Iterable): dot = _dot.graph_list_to_dot( instance, diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index ee605ec8..58f99d2a 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -16,6 +16,7 @@ Set, Tuple, Union, + cast, ) import attrs @@ -468,17 +469,26 @@ def _get_particle_graphs( def _strip_projections( - graph: Transition[ParticleWithSpin, InteractionProperties], + graph: Transition[Any, InteractionProperties], ) -> FrozenTransition[Particle, InteractionProperties]: - if isinstance(graph, FrozenTransition): - graph = graph.convert(lambda s: (s.particle, s.spin_projection)) - return FrozenTransition( - graph.topology, - states={i: particle for i, (particle, _) in graph.states.items()}, - interactions={ - i: attrs.evolve(interaction, l_projection=None, s_projection=None) - for i, interaction in graph.interactions.items() - }, + if isinstance(graph, MutableTransition): + transition = graph.freeze() + transition = cast(FrozenTransition[Any, InteractionProperties], graph) + return transition.convert( + state_converter=__to_particle, + interaction_converter=lambda i: attrs.evolve( + i, l_projection=None, s_projection=None + ), + ) + + +def __to_particle(state: Any) -> Particle: + if isinstance(state, State): + return state.particle + if isinstance(state, tuple) and len(state) == 2: + return state[0] + raise NotImplementedError( + f"Cannot extract a particle from type {type(state).__name__}" ) @@ -497,7 +507,10 @@ def _collapse_graphs( topology = transition.topology group = transition_groups[topology] for state_id, state in transition.states.items(): - particle, _ = state + if isinstance(state, State): + particle = state.particle + else: + particle, _ = state group.states[state_id].add(particle) particle_collection_graphs = [] for topology in sorted(transition_groups): diff --git a/src/qrules/transition.py b/src/qrules/transition.py index ece03e56..56bfae43 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -633,7 +633,11 @@ def find_solutions( # pylint: disable=too-many-branches _match_final_state_ids(graph, self.final_state) for graph in final_solutions ] - return ReactionInfo.from_graphs(final_solutions, self.formalism) + transitions = [ + graph.freeze().convert(lambda s: State(*s)) + for graph in final_solutions + ] + return ReactionInfo(transitions, self.formalism) def _solve( self, qn_problem_set: QNProblemSet @@ -769,32 +773,6 @@ def get_intermediate_particles(self) -> ParticleCollection: } return ParticleCollection(particles) - @staticmethod - def from_graphs( - graphs: Iterable[ - MutableTransition[ParticleWithSpin, InteractionProperties] - ], - formalism: str, - ) -> "ReactionInfo": - transitions = [ - g.freeze().convert(state_converter=lambda state: State(*state)) - for g in graphs - ] - return ReactionInfo(transitions, formalism) - - def to_graphs( - self, - ) -> List[MutableTransition[ParticleWithSpin, InteractionProperties]]: - return [ - transition.convert( - state_converter=lambda state: ( - state.particle, - state.spin_projection, - ) - ).unfreeze() - for transition in self.transitions - ] - def group_by_topology(self) -> Dict[Topology, List[StateTransition]]: groupings = defaultdict(list) for transition in self.transitions: diff --git a/tests/unit/io/test_dot.py b/tests/unit/io/test_dot.py index a47d72d5..040824e9 100644 --- a/tests/unit/io/test_dot.py +++ b/tests/unit/io/test_dot.py @@ -208,10 +208,10 @@ def test_collapse_graphs( particle_database: ParticleCollection, ): pdg = particle_database - particle_graphs = _get_particle_graphs(reaction.to_graphs()) # type: ignore[arg-type] + particle_graphs = _get_particle_graphs(reaction.transitions) # type: ignore[arg-type] assert len(particle_graphs) == 2 - collapsed_graphs = _collapse_graphs(reaction.to_graphs()) # type: ignore[arg-type] + collapsed_graphs = _collapse_graphs(reaction.transitions) # type: ignore[arg-type] assert len(collapsed_graphs) == 1 graph = next(iter(collapsed_graphs)) edge_id = next(iter(graph.topology.intermediate_edge_ids)) @@ -226,7 +226,7 @@ def test_get_particle_graphs( reaction: ReactionInfo, particle_database: ParticleCollection ): pdg = particle_database - graphs = _get_particle_graphs(reaction.to_graphs()) # type: ignore[arg-type] + graphs = _get_particle_graphs(reaction.transitions) # type: ignore[arg-type] assert len(graphs) == 2 assert graphs[0].states[3] == pdg["f(0)(980)"] assert graphs[1].states[3] == pdg["f(0)(1500)"] diff --git a/tests/unit/io/test_io.py b/tests/unit/io/test_io.py index 58e0c282..8bc97ec5 100644 --- a/tests/unit/io/test_io.py +++ b/tests/unit/io/test_io.py @@ -10,7 +10,7 @@ create_isobar_topologies, create_n_body_topology, ) -from qrules.transition import ReactionInfo, State +from qrules.transition import ReactionInfo def through_dict(instance): @@ -44,11 +44,11 @@ def test_asdict_fromdict(particle_selection: ParticleCollection): def test_asdict_fromdict_reaction(reaction: ReactionInfo): - # MutableTransition - for graph in reaction.to_graphs(): + # FrozenTransition + for graph in reaction.transitions: fromdict = through_dict(graph) assert isinstance(fromdict, FrozenTransition) - assert graph.freeze().convert(lambda s: State(*s)) == fromdict + assert graph == fromdict # ReactionInfo fromdict = through_dict(reaction) assert isinstance(fromdict, ReactionInfo) diff --git a/tests/unit/test_transition.py b/tests/unit/test_transition.py index 2a617ef8..435e4c4d 100644 --- a/tests/unit/test_transition.py +++ b/tests/unit/test_transition.py @@ -1,5 +1,7 @@ # pyright: reportUnusedImport=false # pylint: disable=eval-used, no-self-use +from copy import deepcopy + import pytest from IPython.lib.pretty import pretty @@ -40,15 +42,8 @@ def test_repr(self, repr_method, reaction: ReactionInfo): from_repr = eval(repr_method(instance)) assert from_repr == instance - def test_from_to_graphs(self, reaction: ReactionInfo): - graphs = reaction.to_graphs() - from_graphs = ReactionInfo.from_graphs(graphs, reaction.formalism) - assert from_graphs == reaction - def test_hash(self, reaction: ReactionInfo): - graphs = reaction.to_graphs() - from_graphs = ReactionInfo.from_graphs(graphs, reaction.formalism) - assert hash(from_graphs) == hash(reaction) + assert hash(deepcopy(reaction)) == hash(reaction) class TestState: From 4e3281835e38c0be33350ca190565165f6935592 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:00:37 +0100 Subject: [PATCH 20/34] docs: improve API rendering with autodoc_type_aliases --- .pylintrc | 1 + docs/_relink_references.py | 104 ++++++++++++++++++++++++++++++------- docs/conf.py | 6 +++ src/qrules/__init__.py | 2 +- src/qrules/solving.py | 6 ++- src/qrules/transition.py | 19 +++---- 6 files changed, 104 insertions(+), 34 deletions(-) diff --git a/.pylintrc b/.pylintrc index fa0c26f6..4d1ca265 100644 --- a/.pylintrc +++ b/.pylintrc @@ -22,6 +22,7 @@ disable= missing-function-docstring, # pydocstyle missing-module-docstring, # pydocstyle no-member, # conflicts with attrs.field + no-name-in-module, # already checked by mypy not-an-iterable, # conflicts with attrs.field not-callable, # conflicts with attrs.field redefined-builtin, # flake8-built diff --git a/docs/_relink_references.py b/docs/_relink_references.py index 65cb1ac0..261a572b 100644 --- a/docs/_relink_references.py +++ b/docs/_relink_references.py @@ -8,13 +8,13 @@ See also https://github.com/sphinx-doc/sphinx/issues/5868. """ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List if TYPE_CHECKING: + from docutils import nodes from sphinx.addnodes import pending_xref from sphinx.environment import BuildEnvironment - __TARGET_SUBSTITUTIONS = { "a set-like object providing a view on D's items": "typing.ItemsView", "a set-like object providing a view on D's keys": "typing.KeysView", @@ -23,6 +23,38 @@ } __REF_TYPE_SUBSTITUTIONS = { "None": "obj", + "qrules.combinatorics.InitialFacts": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.baryon_number": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.bottomness": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.c_parity": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.charge": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.charmness": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.electron_lepton_number": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.g_parity": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.isospin_magnitude": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.isospin_projection": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.mass": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.muon_lepton_number": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.parity": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.pid": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.spin_magnitude": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.spin_projection": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.strangeness": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.tau_lepton_number": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.topness": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.width": "obj", + "qrules.quantum_numbers.NodeQuantumNumbers.l_magnitude": "obj", + "qrules.quantum_numbers.NodeQuantumNumbers.l_projection": "obj", + "qrules.quantum_numbers.NodeQuantumNumbers.parity_prefactor": "obj", + "qrules.quantum_numbers.NodeQuantumNumbers.s_magnitude": "obj", + "qrules.quantum_numbers.NodeQuantumNumbers.s_projection": "obj", + "qrules.solving.GraphElementProperties": "obj", + "qrules.solving.GraphSettings": "obj", + "qrules.topology.EdgeType": "obj", + "qrules.topology.KeyType": "obj", + "qrules.topology.NodeType": "obj", + "qrules.topology.ValueType": "obj", + "qrules.transition.StateTransition": "obj", } @@ -31,34 +63,68 @@ def _new_type_to_xref( env: "BuildEnvironment" = None, suppress_prefix: bool = False, ) -> "pending_xref": - if env: - kwargs = { - "py:module": env.ref_context.get("py:module"), - "py:class": env.ref_context.get("py:class"), - } - else: - kwargs = {} + import sphinx + from sphinx.addnodes import pending_xref - target = __TARGET_SUBSTITUTIONS.get(target, target) - reftype = __REF_TYPE_SUBSTITUTIONS.get(target, "class") - if suppress_prefix: - short_text = target.split(".")[-1] - else: - short_text = target + if sphinx.version_info >= (4, 4): + # https://github.com/sphinx-doc/sphinx/blob/v4.4.0/sphinx/domains/python.py#L110-L133 + from sphinx.domains.python import ( # type: ignore[attr-defined] + parse_reftarget, + ) - from docutils.nodes import Text - from sphinx.addnodes import pending_xref + reftype, target, title, refspecific = parse_reftarget( + target, suppress_prefix + ) + target = __TARGET_SUBSTITUTIONS.get(target, target) + reftype = __REF_TYPE_SUBSTITUTIONS.get(target, reftype) + assert env is not None + return pending_xref( + "", + *__create_nodes(env, title), + refdomain="py", + reftype=reftype, + reftarget=target, + refspecific=refspecific, + **__get_env_kwargs(env), + ) + # Sphinx <4.4.0 + # https://github.com/sphinx-doc/sphinx/blob/v4.3.2/sphinx/domains/python.py#L83-L107 + target = __TARGET_SUBSTITUTIONS.get(target, target) + reftype = __REF_TYPE_SUBSTITUTIONS.get(target, "class") + assert env is not None return pending_xref( "", - Text(short_text), + *__create_nodes(env, target), refdomain="py", reftype=reftype, reftarget=target, - **kwargs, + **__get_env_kwargs(env), ) +def __get_env_kwargs(env: "BuildEnvironment") -> dict: + if env: + return { + "py:module": env.ref_context.get("py:module"), + "py:class": env.ref_context.get("py:class"), + } + return {} + + +def __create_nodes(env: "BuildEnvironment", title: str) -> "List[nodes.Node]": + from docutils import nodes + from sphinx.addnodes import pending_xref_condition + + short_name = title.split(".")[-1] + if env.config.python_use_unqualified_type_names: + return [ + pending_xref_condition("", short_name, condition="resolved"), + pending_xref_condition("", title, condition="*"), + ] + return [nodes.Text(short_name)] + + def relink_references() -> None: import sphinx.domains.python diff --git a/docs/conf.py b/docs/conf.py index 863348e4..2021a0d3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -175,6 +175,12 @@ def fetch_logo(url: str, output_path: str) -> None: ), } autodoc_member_order = "bysource" +autodoc_type_aliases = { + "GraphElementProperties": "qrules.solving.GraphElementProperties", + "GraphSettings": "qrules.solving.GraphSettings", + "InitialFacts": "qrules.combinatorics.InitialFacts", + "StateTransition": "qrules.transition.StateTransition", +} autodoc_typehints_format = "short" codeautolink_concat_default = True AUTODOC_INSERT_SIGNATURE_LINEBREAKS = True diff --git a/src/qrules/__init__.py b/src/qrules/__init__.py index 4a56790c..973a586b 100644 --- a/src/qrules/__init__.py +++ b/src/qrules/__init__.py @@ -132,7 +132,7 @@ def check_reaction_violations( # pylint: disable=too-many-arguments particle_db = load_pdg() def _check_violations( - facts: InitialFacts, + facts: "InitialFacts", node_rules: Dict[int, Set[Rule]], edge_rules: Dict[int, Set[GraphElementRule]], ) -> QNResult: diff --git a/src/qrules/solving.py b/src/qrules/solving.py index 815c9142..35338264 100644 --- a/src/qrules/solving.py +++ b/src/qrules/solving.py @@ -90,9 +90,11 @@ class NodeSettings: GraphSettings = MutableTransition[EdgeSettings, NodeSettings] +"""(Mutable) mapping of settings on a `.Topology`.""" GraphElementProperties = MutableTransition[ GraphEdgePropertyMap, GraphNodePropertyMap ] +"""(Mutable) mapping of edge and node properties on a `.Topology`.""" @implement_pretty_repr @@ -106,8 +108,8 @@ class QNProblemSet: and variable domains for nodes and edges of the :attr:`topology`. """ - initial_facts: GraphElementProperties - solving_settings: GraphSettings + initial_facts: "GraphElementProperties" + solving_settings: "GraphSettings" @property def topology(self) -> Topology: diff --git a/src/qrules/transition.py b/src/qrules/transition.py index 56bfae43..de35e9e1 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -169,19 +169,14 @@ def extend( @implement_pretty_repr @define class ProblemSet: - """Particle reaction problem set, defined as a graph like data structure. - - Args: - topology: `.Topology` that contains the structure of the reaction. - initial_facts: `.InitialFacts` that contain the info of initial and - final state in connection with the topology. - solving_settings: Solving related settings such as the conservation - rules and the quantum number domains. - """ + """Particle reaction problem set as a graph-like data structure.""" topology: Topology - initial_facts: InitialFacts - solving_settings: GraphSettings + """`.Topology` over which the problem set is defined.""" + initial_facts: "InitialFacts" + """Information about the initial and final state.""" + solving_settings: "GraphSettings" + """Solving settings, such as conservation rules and QN-domains.""" def to_qn_problem_set(self) -> QNProblemSet: interactions = { @@ -412,7 +407,7 @@ def create_problem_sets(self) -> Dict[float, List[ProblemSet]]: return _group_by_strength(problem_sets) def __determine_graph_settings( - self, topology: Topology, initial_facts: InitialFacts + self, topology: Topology, initial_facts: "InitialFacts" ) -> List[GraphSettings]: # pylint: disable=too-many-locals def create_intermediate_edge_qn_domains() -> Dict: From b5a4524de24abc546e877aa351d8535e53377799 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:09:13 +0100 Subject: [PATCH 21/34] refactor: write Transition as ABC --- src/qrules/io/_dot.py | 6 +++--- src/qrules/topology.py | 30 ++++++++++++++++++++---------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index 58f99d2a..c5a0131e 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -161,7 +161,7 @@ def graph_list_to_dot( raise ValueError( "Collapsed graphs cannot be rendered with node properties" ) - graphs = _collapse_graphs(graphs) # type: ignore[assignment] + graphs = _collapse_graphs(graphs) elif strip_spin: if render_node: stripped_graphs = [] @@ -173,9 +173,9 @@ def graph_list_to_dot( stripped_graph = _strip_projections(graph) if stripped_graph not in stripped_graphs: stripped_graphs.append(stripped_graph) - graphs = stripped_graphs # type: ignore[assignment] + graphs = stripped_graphs else: - graphs = _get_particle_graphs(graphs) # type: ignore[assignment] + graphs = _get_particle_graphs(graphs) dot = "" if not isinstance(graphs, abc.Sequence): graphs = list(graphs) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index bada2177..084b7635 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -12,7 +12,7 @@ import itertools import logging import sys -from abc import abstractmethod +from abc import ABC, abstractmethod from collections import abc from functools import total_ordering from typing import ( @@ -44,9 +44,9 @@ from qrules._implementers import implement_pretty_repr if sys.version_info >= (3, 8): - from typing import Protocol, runtime_checkable + from typing import Protocol else: - from typing_extensions import Protocol, runtime_checkable + from typing_extensions import Protocol if TYPE_CHECKING: try: @@ -656,11 +656,21 @@ def _attach_node_to_edges( """A `~typing.TypeVar` representing the type of node properties.""" -@runtime_checkable -class Transition(Protocol[EdgeType, NodeType]): - topology: Topology - states: Mapping[int, EdgeType] - interactions: Mapping[int, NodeType] +class Transition(ABC, Generic[EdgeType, NodeType]): + @property + @abstractmethod + def topology(self) -> Topology: + ... + + @property + @abstractmethod + def states(self) -> Mapping[int, EdgeType]: + ... + + @property + @abstractmethod + def interactions(self) -> Mapping[int, NodeType]: + ... NewEdgeType = TypeVar("NewEdgeType") @@ -669,7 +679,7 @@ class Transition(Protocol[EdgeType, NodeType]): @implement_pretty_repr @frozen(order=True) -class FrozenTransition(Generic[EdgeType, NodeType]): +class FrozenTransition(Transition, Generic[EdgeType, NodeType]): """Defines a frozen mapping of edge and node properties on a `Topology`.""" topology: Topology = field(validator=instance_of(Topology)) @@ -732,7 +742,7 @@ def _cast_interactions(obj: Mapping[int, NodeType]) -> Dict[int, NodeType]: @implement_pretty_repr @define -class MutableTransition(Generic[EdgeType, NodeType]): +class MutableTransition(Transition, Generic[EdgeType, NodeType]): """Graph class that resembles a frozen `.Topology` with properties. This class should contain the full information of a state transition from a From 2e5be8b0915af2f3cbfcd06c334aec734bf0de25 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:22:07 +0100 Subject: [PATCH 22/34] refactor: change initial_states etc into mixin methods --- src/qrules/topology.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 084b7635..2bd5c5ae 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -656,6 +656,10 @@ def _attach_node_to_edges( """A `~typing.TypeVar` representing the type of node properties.""" +NewEdgeType = TypeVar("NewEdgeType") +NewNodeType = TypeVar("NewNodeType") + + class Transition(ABC, Generic[EdgeType, NodeType]): @property @abstractmethod @@ -672,9 +676,20 @@ def states(self) -> Mapping[int, EdgeType]: def interactions(self) -> Mapping[int, NodeType]: ... + @property + def initial_states(self) -> Dict[int, EdgeType]: + return self.filter_states(self.topology.incoming_edge_ids) -NewEdgeType = TypeVar("NewEdgeType") -NewNodeType = TypeVar("NewNodeType") + @property + def final_states(self) -> Dict[int, EdgeType]: + return self.filter_states(self.topology.outgoing_edge_ids) + + @property + def intermediate_states(self) -> Dict[int, EdgeType]: + return self.filter_states(self.topology.intermediate_edge_ids) + + def filter_states(self, edge_ids: Iterable[int]) -> Dict[int, EdgeType]: + return {i: self.states[i] for i in edge_ids} @implement_pretty_repr @@ -690,20 +705,8 @@ def __attrs_post_init__(self) -> None: _assert_all_defined(self.topology.nodes, self.interactions) _assert_all_defined(self.topology.edges, self.states) - @property - def initial_states(self) -> Dict[int, EdgeType]: - return self.filter_states(self.topology.incoming_edge_ids) - - @property - def final_states(self) -> Dict[int, EdgeType]: - return self.filter_states(self.topology.outgoing_edge_ids) - - @property - def intermediate_states(self) -> Dict[int, EdgeType]: - return self.filter_states(self.topology.intermediate_edge_ids) - - def filter_states(self, edge_ids: Iterable[int]) -> Dict[int, EdgeType]: - return {i: self.states[i] for i in edge_ids} + def unfreeze(self) -> "MutableTransition[EdgeType, NodeType]": + return MutableTransition(self.topology, self.states, self.interactions) def convert( self, @@ -717,7 +720,7 @@ def convert( state_converter = lambda _: cast(NewEdgeType, _) if interaction_converter is None: interaction_converter = lambda _: cast(NewNodeType, _) - return FrozenTransition[NewEdgeType, NodeType]( + return FrozenTransition( self.topology, states={ i: state_converter(state) for i, state in self.states.items() @@ -728,9 +731,6 @@ def convert( }, ) - def unfreeze(self) -> "MutableTransition[EdgeType, NodeType]": - return MutableTransition(self.topology, self.states, self.interactions) - def _cast_states(obj: Mapping[int, EdgeType]) -> Dict[int, EdgeType]: return dict(obj) From 4c51903130c1d8a1e2366ffb6238070bd449e19b Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 12:29:57 +0100 Subject: [PATCH 23/34] fix: return NewNodeType in FrozenTransition convert() --- docs/_relink_references.py | 2 ++ docs/conf.py | 6 ++++-- src/qrules/topology.py | 34 ++++++++++++++++++++++++++-------- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/docs/_relink_references.py b/docs/_relink_references.py index 261a572b..39bc2941 100644 --- a/docs/_relink_references.py +++ b/docs/_relink_references.py @@ -16,6 +16,8 @@ from sphinx.environment import BuildEnvironment __TARGET_SUBSTITUTIONS = { + "EdgeType": "qrules.topology.EdgeType", + "NodeType": "qrules.topology.NodeType", "a set-like object providing a view on D's items": "typing.ItemsView", "a set-like object providing a view on D's keys": "typing.KeysView", "an object providing a view on D's values": "typing.ValuesView", diff --git a/docs/conf.py b/docs/conf.py index 2021a0d3..c5f196de 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -227,10 +227,12 @@ def fetch_logo(url: str, output_path: str) -> None: primary_domain = "py" nitpicky = True # warn if cross-references are missing nitpick_ignore = [ - ("py:class", "qrules.topology.NewEdgeType"), - ("py:class", "qrules.topology.NewNodeType"), + ("py:class", "NewEdgeType"), + ("py:class", "NewNodeType"), ("py:class", "NoneType"), ("py:class", "json.encoder.JSONEncoder"), + ("py:class", "qrules.topology.NewEdgeType"), + ("py:class", "qrules.topology.NewNodeType"), ("py:class", "typing_extensions.Protocol"), ("py:obj", "qrules.topology._K"), ("py:obj", "qrules.topology._V"), diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 2bd5c5ae..c850cab1 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -34,7 +34,7 @@ Tuple, TypeVar, ValuesView, - cast, + overload, ) import attrs @@ -708,18 +708,36 @@ def __attrs_post_init__(self) -> None: def unfreeze(self) -> "MutableTransition[EdgeType, NodeType]": return MutableTransition(self.topology, self.states, self.interactions) + @overload + def convert(self) -> "FrozenTransition[EdgeType, NodeType]": + ... + + @overload def convert( - self, - state_converter: Optional[Callable[[EdgeType], NewEdgeType]] = None, - interaction_converter: Optional[ - Callable[[NodeType], NewNodeType] - ] = None, + self, state_converter: Callable[[EdgeType], NewEdgeType] ) -> "FrozenTransition[NewEdgeType, NodeType]": + ... + + @overload + def convert( + self, *, interaction_converter: Callable[[NodeType], NewNodeType] + ) -> "FrozenTransition[EdgeType, NewNodeType]": + ... + + @overload + def convert( + self, + state_converter: Callable[[EdgeType], NewEdgeType], + interaction_converter: Callable[[NodeType], NewNodeType], + ) -> "FrozenTransition[NewEdgeType, NewNodeType]": + ... + + def convert(self, state_converter=None, interaction_converter=None): # type: ignore[no-untyped-def] # pylint: disable=unnecessary-lambda if state_converter is None: - state_converter = lambda _: cast(NewEdgeType, _) + state_converter = lambda _: _ if interaction_converter is None: - interaction_converter = lambda _: cast(NewNodeType, _) + interaction_converter = lambda _: _ return FrozenTransition( self.topology, states={ From 195151e2d79d4b86b3413d783d1682d88e803aa6 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 13:03:53 +0100 Subject: [PATCH 24/34] ci: do not fast-fail test jobs --- .github/workflows/ci-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index d1d61680..b3972c5e 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -46,6 +46,7 @@ jobs: name: Unit tests runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: - macos-11 From 140843add650084808f461684dcd2acbbe05caca Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 14:02:27 +0100 Subject: [PATCH 25/34] fix: support generic Transition class for py3.6 See https://github.com/ComPWA/qrules/pull/156#issuecomment-1046853982 --- docs/_relink_references.py | 2 ++ src/qrules/__init__.py | 12 +++-------- src/qrules/_system_control.py | 16 ++++++-------- src/qrules/combinatorics.py | 30 +++++++++++++++----------- src/qrules/io/_dict.py | 2 +- src/qrules/io/_dot.py | 20 ++++++++--------- src/qrules/solving.py | 28 ++++++++++++++---------- src/qrules/topology.py | 2 +- src/qrules/transition.py | 36 +++++++++++++++++++++---------- tests/unit/test_system_control.py | 2 +- 10 files changed, 85 insertions(+), 65 deletions(-) diff --git a/docs/_relink_references.py b/docs/_relink_references.py index 39bc2941..f541664f 100644 --- a/docs/_relink_references.py +++ b/docs/_relink_references.py @@ -22,6 +22,7 @@ "a set-like object providing a view on D's keys": "typing.KeysView", "an object providing a view on D's values": "typing.ValuesView", "typing_extensions.Protocol": "typing.Protocol", + "typing_extensions.TypeAlias": "typing.TypeAlias", } __REF_TYPE_SUBSTITUTIONS = { "None": "obj", @@ -57,6 +58,7 @@ "qrules.topology.NodeType": "obj", "qrules.topology.ValueType": "obj", "qrules.transition.StateTransition": "obj", + "typing.TypeAlias": "obj", } diff --git a/src/qrules/__init__.py b/src/qrules/__init__.py index 973a586b..22fd7dc0 100644 --- a/src/qrules/__init__.py +++ b/src/qrules/__init__.py @@ -62,14 +62,8 @@ _halves_domain, _int_domain, ) -from .solving import ( - GraphSettings, - NodeSettings, - QNResult, - Rule, - validate_full_solution, -) -from .topology import create_n_body_topology +from .solving import NodeSettings, QNResult, Rule, validate_full_solution +from .topology import MutableTransition, create_n_body_topology from .transition import ( EdgeSettings, ProblemSet, @@ -139,7 +133,7 @@ def _check_violations( problem_set = ProblemSet( topology=topology, initial_facts=facts, - solving_settings=GraphSettings( + solving_settings=MutableTransition( facts.topology, interactions={ i: NodeSettings(conservation_rules=rules) diff --git a/src/qrules/_system_control.py b/src/qrules/_system_control.py index 26935d4d..470c40b1 100644 --- a/src/qrules/_system_control.py +++ b/src/qrules/_system_control.py @@ -198,11 +198,11 @@ def check( def remove_duplicate_solutions( solutions: List[ - MutableTransition[ParticleWithSpin, InteractionProperties] + "MutableTransition[ParticleWithSpin, InteractionProperties]" ], remove_qns_list: Optional[Set[Type[NodeQuantumNumber]]] = None, ignore_qns_list: Optional[Set[Type[NodeQuantumNumber]]] = None, -) -> List[MutableTransition[ParticleWithSpin, InteractionProperties]]: +) -> "List[MutableTransition[ParticleWithSpin, InteractionProperties]]": if remove_qns_list is None: remove_qns_list = set() if ignore_qns_list is None: @@ -232,9 +232,9 @@ def remove_duplicate_solutions( def _remove_qns_from_graph( # pylint: disable=too-many-branches - graph: MutableTransition[ParticleWithSpin, InteractionProperties], + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", qn_list: Set[Type[NodeQuantumNumber]], -) -> MutableTransition[ParticleWithSpin, InteractionProperties]: +) -> "MutableTransition[ParticleWithSpin, InteractionProperties]": new_interactions = {} for node_id in graph.topology.nodes: interactions = graph.interactions[node_id] @@ -328,9 +328,7 @@ def require_interaction_property( ingoing_particle_name: str, interaction_qn: Type[NodeQuantumNumber], allowed_values: List, -) -> Callable[ - [MutableTransition[ParticleWithSpin, InteractionProperties]], bool -]: +) -> "Callable[[MutableTransition[ParticleWithSpin, InteractionProperties]], bool]": """Filter function. Closure, which can be used as a filter function in :func:`.filter_graphs`. @@ -356,7 +354,7 @@ def require_interaction_property( """ def check( - graph: MutableTransition[ParticleWithSpin, InteractionProperties] + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", ) -> bool: node_ids = _find_node_ids_with_ingoing_particle_name( graph, ingoing_particle_name @@ -375,7 +373,7 @@ def check( def _find_node_ids_with_ingoing_particle_name( - graph: MutableTransition[ParticleWithSpin, InteractionProperties], + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", ingoing_particle_name: str, ) -> List[int]: topology = graph.topology diff --git a/src/qrules/combinatorics.py b/src/qrules/combinatorics.py index 16fc4945..36b298a3 100644 --- a/src/qrules/combinatorics.py +++ b/src/qrules/combinatorics.py @@ -5,6 +5,7 @@ extract these edge and node properties. """ +import sys from collections import OrderedDict from copy import deepcopy from itertools import permutations @@ -30,9 +31,16 @@ from .quantum_numbers import InteractionProperties, arange from .topology import MutableTransition, Topology, get_originating_node_list +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + StateWithSpins = Tuple[str, Sequence[float]] StateDefinition = Union[str, StateWithSpins] -InitialFacts = MutableTransition[ParticleWithSpin, InteractionProperties] +InitialFacts: TypeAlias = ( + "MutableTransition[ParticleWithSpin, InteractionProperties]" +) """A `.Transition` with only initial and final state information.""" @@ -254,13 +262,13 @@ def embed_in_list(some_list: List[Any]) -> List[List[Any]]: final_state=final_state, allowed_kinematic_groupings=allowed_kinematic_groupings, ) - edge_initial_facts = [] + edge_initial_facts: List[InitialFacts] = [] for kinematic_permutation in kinematic_permutation_graphs: spin_permutations = _generate_spin_permutations( kinematic_permutation, particle_db ) edge_initial_facts.extend( - [InitialFacts(topology, states=x) for x in spin_permutations] + [MutableTransition(topology, states=x) for x in spin_permutations] ) return edge_initial_facts @@ -398,19 +406,19 @@ def populate_edge_with_spin_projections( def __get_initial_state_edge_ids( - graph: MutableTransition[ParticleWithSpin, InteractionProperties], + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", ) -> Iterable[int]: return graph.topology.incoming_edge_ids def __get_final_state_edge_ids( - graph: MutableTransition[ParticleWithSpin, InteractionProperties], + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", ) -> Iterable[int]: return graph.topology.outgoing_edge_ids def match_external_edges( - graphs: List[MutableTransition[ParticleWithSpin, InteractionProperties]], + graphs: "List[MutableTransition[ParticleWithSpin, InteractionProperties]]", ) -> None: if not isinstance(graphs, list): raise TypeError("graphs argument is not of type list") @@ -424,11 +432,9 @@ def match_external_edges( def _match_external_edge_ids( # pylint: disable=too-many-locals - graphs: List[MutableTransition[ParticleWithSpin, InteractionProperties]], + graphs: "List[MutableTransition[ParticleWithSpin, InteractionProperties]]", ref_graph_id: int, - external_edge_getter_function: Callable[ - [MutableTransition], Iterable[int] - ], + external_edge_getter_function: "Callable[[MutableTransition], Iterable[int]]", ) -> None: ref_graph = graphs[ref_graph_id] # create external edge to particle mapping @@ -489,7 +495,7 @@ def perform_external_edge_identical_particle_combinatorics( def _external_edge_identical_particle_combinatorics( - graph: MutableTransition[ParticleWithSpin, InteractionProperties], + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", external_edge_getter_function: Callable[ [MutableTransition], Iterable[int] ], @@ -549,7 +555,7 @@ def _calculate_swappings(id_mapping: Dict[int, int]) -> OrderedDict: def _create_edge_id_particle_mapping( - graph: MutableTransition[ParticleWithSpin, InteractionProperties], + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", edge_ids: Iterable[int], ) -> Dict[int, str]: return {i: graph.states[i][0].name for i in edge_ids} diff --git a/src/qrules/io/_dict.py b/src/qrules/io/_dict.py index 60a2373b..47821a76 100644 --- a/src/qrules/io/_dict.py +++ b/src/qrules/io/_dict.py @@ -96,7 +96,7 @@ def build_reaction_info(definition: dict) -> ReactionInfo: def build_transition( definition: dict, -) -> FrozenTransition[State, InteractionProperties]: +) -> "FrozenTransition[State, InteractionProperties]": topology = build_topology(definition["topology"]) states_def: Dict[int, dict] = definition["states"] states: Dict[int, State] = {} diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index c5a0131e..0c507a02 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -436,7 +436,7 @@ def __extract_priority(description: str) -> int: def _get_particle_graphs( graphs: Iterable[Transition[ParticleWithSpin, InteractionProperties]], -) -> List[FrozenTransition[Particle, None]]: +) -> "List[FrozenTransition[Particle, None]]": """Strip `list` of `.Transition` s of the spin projections. Extract a `list` of `.Transition` instances with only `.Particle` instances @@ -470,10 +470,10 @@ def _get_particle_graphs( def _strip_projections( graph: Transition[Any, InteractionProperties], -) -> FrozenTransition[Particle, InteractionProperties]: +) -> "FrozenTransition[Particle, InteractionProperties]": if isinstance(graph, MutableTransition): transition = graph.freeze() - transition = cast(FrozenTransition[Any, InteractionProperties], graph) + transition = cast("FrozenTransition[Any, InteractionProperties]", graph) return transition.convert( state_converter=__to_particle, interaction_converter=lambda i: attrs.evolve( @@ -494,9 +494,9 @@ def __to_particle(state: Any) -> Particle: def _collapse_graphs( graphs: Iterable[Transition[ParticleWithSpin, InteractionProperties]], -) -> Tuple[FrozenTransition[Tuple[Particle, ...], None], ...]: - transition_groups = { - g.topology: MutableTransition[Set[Particle], None]( +) -> "Tuple[FrozenTransition[Tuple[Particle, ...], None], ...]": + transition_groups: "Dict[Topology, MutableTransition[Set[Particle], None]]" = { + g.topology: MutableTransition( g.topology, states={i: set() for i in g.topology.edges}, interactions={i: None for i in g.topology.nodes}, @@ -512,11 +512,11 @@ def _collapse_graphs( else: particle, _ = state group.states[state_id].add(particle) - particle_collection_graphs = [] + collected_graphs: "List[FrozenTransition[Tuple[Particle, ...], None]]" = [] for topology in sorted(transition_groups): group = transition_groups[topology] - particle_collection_graphs.append( - FrozenTransition[Tuple[Particle, ...], None]( + collected_graphs.append( + FrozenTransition( topology, states={ i: tuple(sorted(particles, key=lambda p: p.name)) @@ -525,4 +525,4 @@ def _collapse_graphs( interactions=group.interactions, ) ) - return tuple(particle_collection_graphs) + return tuple(collected_graphs) diff --git a/src/qrules/solving.py b/src/qrules/solving.py index 35338264..35d045bb 100644 --- a/src/qrules/solving.py +++ b/src/qrules/solving.py @@ -11,6 +11,7 @@ import inspect import logging +import sys from abc import ABC, abstractmethod from collections import defaultdict from copy import copy @@ -57,6 +58,11 @@ ) from .topology import MutableTransition, Topology +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + @implement_pretty_repr @define @@ -89,11 +95,11 @@ class NodeSettings: interaction_strength: float = 1.0 -GraphSettings = MutableTransition[EdgeSettings, NodeSettings] +GraphSettings: TypeAlias = "MutableTransition[EdgeSettings, NodeSettings]" """(Mutable) mapping of settings on a `.Topology`.""" -GraphElementProperties = MutableTransition[ - GraphEdgePropertyMap, GraphNodePropertyMap -] +GraphElementProperties: TypeAlias = ( + "MutableTransition[GraphEdgePropertyMap, GraphNodePropertyMap]" +) """(Mutable) mapping of edge and node properties on a `.Topology`.""" @@ -116,9 +122,9 @@ def topology(self) -> Topology: return self.initial_facts.topology -QuantumNumberSolution = MutableTransition[ - GraphEdgePropertyMap, GraphNodePropertyMap -] +QuantumNumberSolution: TypeAlias = ( + "MutableTransition[GraphEdgePropertyMap, GraphNodePropertyMap]" +) def _convert_violated_rules_to_names( @@ -433,7 +439,7 @@ def _create_variable_containers( ) return QNResult( [ - QuantumNumberSolution( + MutableTransition( topology=problem_set.topology, states=problem_set.initial_facts.states, interactions=problem_set.initial_facts.interactions, @@ -562,10 +568,10 @@ def find_solutions(self, problem_set: QNProblemSet) -> QNResult: result.extend( validate_full_solution( QNProblemSet( - initial_facts=GraphElementProperties( + initial_facts=MutableTransition( topology, states, interactions ), - solving_settings=GraphSettings( + solving_settings=MutableTransition( topology, interactions={ i: NodeSettings(conservation_rules=rules) @@ -833,7 +839,7 @@ def __convert_solution_keys( else: interactions[ele_id].update({qn_type: value}) # type: ignore[dict-item] converted_solutions.append( - QuantumNumberSolution(topology, states, interactions) + MutableTransition(topology, states, interactions) ) return converted_solutions diff --git a/src/qrules/topology.py b/src/qrules/topology.py index c850cab1..2b474888 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -812,7 +812,7 @@ def swap_edges(self, edge_id1: int, edge_id2: int) -> None: if value2 is not None: self.states[edge_id1] = value2 - def freeze(self) -> FrozenTransition[EdgeType, NodeType]: + def freeze(self) -> "FrozenTransition[EdgeType, NodeType]": return FrozenTransition(self.topology, self.states, self.interactions) diff --git a/src/qrules/transition.py b/src/qrules/transition.py index de35e9e1..64c5f458 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -1,11 +1,13 @@ """Find allowed transitions between an initial and final state.""" import logging +import sys from collections import defaultdict from copy import copy, deepcopy from enum import Enum, auto from multiprocessing import Pool from typing import ( + TYPE_CHECKING, Dict, Iterable, List, @@ -64,7 +66,6 @@ CSPSolver, EdgeSettings, GraphEdgePropertyMap, - GraphElementProperties, GraphSettings, NodeSettings, QNProblemSet, @@ -72,13 +73,19 @@ ) from .topology import ( FrozenDict, - FrozenTransition, MutableTransition, Topology, create_isobar_topologies, create_n_body_topology, ) +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias +if TYPE_CHECKING: + from .topology import FrozenTransition # noqa: F401 + class SolvingMode(Enum): """Types of modes for solving.""" @@ -137,9 +144,9 @@ def clear(self) -> None: class _SolutionContainer: """Defines a result of a `.ProblemSet`.""" - solutions: List[ - MutableTransition[ParticleWithSpin, InteractionProperties] - ] = field(factory=list) + solutions: "List[MutableTransition[ParticleWithSpin, InteractionProperties]]" = field( + factory=list + ) execution_info: ExecutionInfo = field(default=ExecutionInfo()) def __attrs_post_init__(self) -> None: @@ -166,6 +173,10 @@ def extend( ) +if sys.version_info >= (3, 7): + attrs.resolve_types(_SolutionContainer, globals(), locals()) + + @implement_pretty_repr @define class ProblemSet: @@ -188,13 +199,16 @@ def to_qn_problem_set(self) -> QNProblemSet: for k, v in self.initial_facts.states.items() } return QNProblemSet( - initial_facts=GraphElementProperties( + initial_facts=MutableTransition( self.topology, states, interactions ), solving_settings=self.solving_settings, ) +attrs.resolve_types(ProblemSet, globals(), locals()) + + def _group_by_strength( problem_sets: List[ProblemSet], ) -> Dict[float, List[ProblemSet]]: @@ -457,7 +471,7 @@ def create_edge_settings(edge_id: int) -> EdgeSettings: initial_state_edges = topology.incoming_edge_ids graph_settings: List[GraphSettings] = [ - GraphSettings( + MutableTransition( topology, states={ edge_id: create_edge_settings(edge_id) @@ -651,7 +665,7 @@ def __convert_result( """ solutions = [] for solution in qn_result.solutions: - graph = MutableTransition[ParticleWithSpin, InteractionProperties]( + graph = MutableTransition( topology=topology, interactions={ i: create_interaction_properties(x) @@ -689,9 +703,9 @@ def _safe_wrap_list( def _match_final_state_ids( - graph: MutableTransition[ParticleWithSpin, InteractionProperties], + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", state_definition: Sequence[StateDefinition], -) -> MutableTransition[ParticleWithSpin, InteractionProperties]: +) -> "MutableTransition[ParticleWithSpin, InteractionProperties]": """Temporary fix to https://github.com/ComPWA/qrules/issues/143.""" particle_names = _strip_spin(state_definition) name_to_id = {name: i for i, name in enumerate(particle_names)} @@ -727,7 +741,7 @@ class State: spin_projection: float = field(converter=_to_float) -StateTransition = FrozenTransition[State, InteractionProperties] +StateTransition: TypeAlias = "FrozenTransition[State, InteractionProperties]" """Transition of some initial `.State` to a final `.State`.""" diff --git a/tests/unit/test_system_control.py b/tests/unit/test_system_control.py index 8bab192c..eb03f612 100644 --- a/tests/unit/test_system_control.py +++ b/tests/unit/test_system_control.py @@ -341,7 +341,7 @@ def test_filter_graphs_for_interaction_qns( def _create_graph( problem_set: ProblemSet, -) -> MutableTransition[ParticleWithSpin, InteractionProperties]: +) -> "MutableTransition[ParticleWithSpin, InteractionProperties]": return MutableTransition( topology=problem_set.topology, interactions=problem_set.initial_facts.interactions, From 071ddbeb648dc0ef32144ed60bb83e6146d14512 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 15:14:36 +0100 Subject: [PATCH 26/34] fix: ignore logo.svg in sphinx-autobuild --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index ace8b21d..f0367e36 100644 --- a/tox.ini +++ b/tox.ini @@ -62,6 +62,7 @@ commands = --re-ignore docs/.*\.yaml \ --re-ignore docs/.*\.yml \ --re-ignore docs/_build/.* \ + --re-ignore docs/_static/logo\..* \ --re-ignore docs/api/.* \ --open-browser \ docs/ docs/_build/html From eae98c9089c2e2af111bdff63c12af572bf704f9 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 15:33:22 +0100 Subject: [PATCH 27/34] refactor: hide TypeVars from topology API --- docs/_relink_references.py | 6 ------ docs/conf.py | 18 ++++++++---------- src/qrules/topology.py | 30 ++++++++++++------------------ 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/docs/_relink_references.py b/docs/_relink_references.py index f541664f..4c693c68 100644 --- a/docs/_relink_references.py +++ b/docs/_relink_references.py @@ -16,8 +16,6 @@ from sphinx.environment import BuildEnvironment __TARGET_SUBSTITUTIONS = { - "EdgeType": "qrules.topology.EdgeType", - "NodeType": "qrules.topology.NodeType", "a set-like object providing a view on D's items": "typing.ItemsView", "a set-like object providing a view on D's keys": "typing.KeysView", "an object providing a view on D's values": "typing.ValuesView", @@ -53,10 +51,6 @@ "qrules.quantum_numbers.NodeQuantumNumbers.s_projection": "obj", "qrules.solving.GraphElementProperties": "obj", "qrules.solving.GraphSettings": "obj", - "qrules.topology.EdgeType": "obj", - "qrules.topology.KeyType": "obj", - "qrules.topology.NodeType": "obj", - "qrules.topology.ValueType": "obj", "qrules.transition.StateTransition": "obj", "typing.TypeAlias": "obj", } diff --git a/docs/conf.py b/docs/conf.py index c5f196de..326b67ff 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -226,16 +226,14 @@ def fetch_logo(url: str, output_path: str) -> None: default_role = "py:obj" primary_domain = "py" nitpicky = True # warn if cross-references are missing -nitpick_ignore = [ - ("py:class", "NewEdgeType"), - ("py:class", "NewNodeType"), - ("py:class", "NoneType"), - ("py:class", "json.encoder.JSONEncoder"), - ("py:class", "qrules.topology.NewEdgeType"), - ("py:class", "qrules.topology.NewNodeType"), - ("py:class", "typing_extensions.Protocol"), - ("py:obj", "qrules.topology._K"), - ("py:obj", "qrules.topology._V"), +nitpick_ignore_regex = [ + (r"py:(class|obj)", "json.encoder.JSONEncoder"), + (r"py:(class|obj)", r"(qrules\.topology\.)?EdgeType"), + (r"py:(class|obj)", r"(qrules\.topology\.)?KT"), + (r"py:(class|obj)", r"(qrules\.topology\.)?NewEdgeType"), + (r"py:(class|obj)", r"(qrules\.topology\.)?NewNodeType"), + (r"py:(class|obj)", r"(qrules\.topology\.)?NodeType"), + (r"py:(class|obj)", r"(qrules\.topology\.)?VT"), ] diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 2b474888..19e81215 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -55,24 +55,22 @@ PrettyPrinter = Any -class Comparable(Protocol): +class _Comparable(Protocol): @abstractmethod def __lt__(self, other: Any) -> bool: ... -KeyType = TypeVar("KeyType", bound=Comparable) -"""Type the keys of the `~typing.Mapping`, see `~typing.KeysView`.""" -ValueType = TypeVar("ValueType") -"""Type the value of the `~typing.Mapping`, see `~typing.ValuesView`.""" +KT = TypeVar("KT", bound=_Comparable) +VT = TypeVar("VT") @total_ordering class FrozenDict( # pylint: disable=too-many-ancestors - abc.Hashable, abc.Mapping, Generic[KeyType, ValueType] + abc.Hashable, abc.Mapping, Generic[KT, VT] ): def __init__(self, mapping: Optional[Mapping] = None): - self.__mapping: Dict[KeyType, ValueType] = {} + self.__mapping: Dict[KT, VT] = {} if mapping is not None: self.__mapping = dict(mapping) self.__hash = hash(None) @@ -98,13 +96,13 @@ def _repr_pretty_(self, p: "PrettyPrinter", cycle: bool) -> None: p.breakable() p.text("})") - def __iter__(self) -> Iterator[KeyType]: + def __iter__(self) -> Iterator[KT]: return iter(self.__mapping) def __len__(self) -> int: return len(self.__mapping) - def __getitem__(self, key: KeyType) -> ValueType: + def __getitem__(self, key: KT) -> VT: return self.__mapping[key] def __gt__(self, other: Any) -> bool: @@ -121,19 +119,19 @@ def __gt__(self, other: Any) -> bool: def __hash__(self) -> int: return self.__hash - def keys(self) -> KeysView[KeyType]: + def keys(self) -> KeysView[KT]: return self.__mapping.keys() - def items(self) -> ItemsView[KeyType, ValueType]: + def items(self) -> ItemsView[KT, VT]: return self.__mapping.items() - def values(self) -> ValuesView[ValueType]: + def values(self) -> ValuesView[VT]: return self.__mapping.values() def _convert_mapping_to_sorted_tuple( - mapping: Mapping[KeyType, ValueType], -) -> Tuple[Tuple[KeyType, ValueType], ...]: + mapping: Mapping[KT, VT], +) -> Tuple[Tuple[KT, VT], ...]: return tuple((key, mapping[key]) for key in sorted(mapping.keys())) @@ -651,11 +649,7 @@ def _attach_node_to_edges( EdgeType = TypeVar("EdgeType") -"""A `~typing.TypeVar` representing the type of edge properties.""" NodeType = TypeVar("NodeType") -"""A `~typing.TypeVar` representing the type of node properties.""" - - NewEdgeType = TypeVar("NewEdgeType") NewNodeType = TypeVar("NewNodeType") From cc412aab5e83c132e3465f98517200dc9914a23a Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 16:05:36 +0100 Subject: [PATCH 28/34] docs: hide dict methods from API --- docs/conf.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 326b67ff..9ad8260b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -164,13 +164,19 @@ def fetch_logo(url: str, output_path: str) -> None: # General sphinx settings add_module_names = False autodoc_default_options = { + "exclude-members": ", ".join( + [ + "items", + "keys", + "values", + ] + ), "members": True, "undoc-members": True, "show-inheritance": True, "special-members": ", ".join( [ "__call__", - "__getitem__", ] ), } From c5fe2c7c6507a8728b0b956176e378fc3e056344 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 16:14:45 +0100 Subject: [PATCH 29/34] fix: improve argument type hints Topology --- src/qrules/io/_dict.py | 5 +---- src/qrules/topology.py | 24 +++++++++++++++++------- tests/unit/test_topology.py | 4 ++-- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/qrules/io/_dict.py b/src/qrules/io/_dict.py index 47821a76..d81b6bd1 100644 --- a/src/qrules/io/_dict.py +++ b/src/qrules/io/_dict.py @@ -126,10 +126,7 @@ def build_topology(definition: dict) -> Topology: nodes = definition["nodes"] edges_def: Dict[int, dict] = definition["edges"] edges = {int(i): Edge(**edge_def) for i, edge_def in edges_def.items()} - return Topology( - edges=edges, - nodes=nodes, - ) + return Topology(nodes, edges) def validate_particle_collection(instance: dict) -> None: diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 19e81215..16f7c6cb 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -39,7 +39,7 @@ import attrs from attrs import define, field, frozen -from attrs.validators import instance_of +from attrs.validators import deep_iterable, deep_mapping, instance_of from qrules._implementers import implement_pretty_repr @@ -158,10 +158,12 @@ def get_connected_nodes(self) -> Set[int]: return connected_nodes # type: ignore[return-value] -def _to_frozenset(iterable: Iterable[int]) -> FrozenSet[int]: - if not all(map(lambda i: isinstance(i, int), iterable)): - raise TypeError(f"Not all items in {iterable} are of type int") - return frozenset(iterable) +def _to_topology_nodes(inst: Iterable[int]) -> FrozenSet[int]: + return frozenset(inst) + + +def _to_topology_edges(inst: Mapping[int, Edge]) -> FrozenDict[int, Edge]: + return FrozenDict(inst) @implement_pretty_repr @@ -176,8 +178,16 @@ class Topology: like a Feynman-diagram. """ - nodes: FrozenSet[int] = field(converter=_to_frozenset) - edges: FrozenDict[int, Edge] = field(converter=FrozenDict) + nodes: FrozenSet[int] = field( + converter=_to_topology_nodes, + validator=deep_iterable(member_validator=instance_of(int)), + ) + edges: FrozenDict[int, Edge] = field( + converter=_to_topology_edges, + validator=deep_mapping( + key_validator=instance_of(int), value_validator=instance_of(Edge) + ), + ) incoming_edge_ids: FrozenSet[int] = field(init=False, repr=False) outgoing_edge_ids: FrozenSet[int] = field(init=False, repr=False) diff --git a/tests/unit/test_topology.py b/tests/unit/test_topology.py index bc2d8306..12ed84dd 100644 --- a/tests/unit/test_topology.py +++ b/tests/unit/test_topology.py @@ -177,7 +177,7 @@ class TestTopology: ], ) def test_constructor(self, nodes, edges): - topology = Topology(nodes=nodes, edges=edges) + topology = Topology(nodes, edges) if nodes is None: nodes = set() if edges is None: @@ -202,7 +202,7 @@ def test_constructor_exceptions(self, nodes, edges): r"(not connected to any other node|has non-existing node IDs)" ), ): - assert Topology(nodes=nodes, edges=edges) + assert Topology(nodes, edges) @pytest.mark.parametrize("repr_method", [repr, pretty]) def test_repr_and_eq(self, repr_method, two_to_three_decay: Topology): From cf9efab14a8cf83d8eb709da5520c8b18025c772 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 16:14:45 +0100 Subject: [PATCH 30/34] fix: improve argument type hints MutableTopology --- src/qrules/topology.py | 26 +++++++++++++++++++++----- tests/unit/test_topology.py | 4 ++-- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 16f7c6cb..a01633ec 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -382,10 +382,28 @@ def __get_originating_node(edge_id: int) -> Optional[int]: ] +def _to_mutable_topology_nodes(inst: Iterable[int]) -> Set[int]: + return set(inst) + + +def _to_mutable_topology_edges(inst: Mapping[int, Edge]) -> Dict[int, Edge]: + return dict(inst) + + @define class MutableTopology: - edges: Dict[int, Edge] = field(factory=dict, converter=dict) - nodes: Set[int] = field(factory=set, converter=set) + nodes: Set[int] = field( + converter=_to_mutable_topology_nodes, + factory=set, + on_setattr=deep_iterable(member_validator=instance_of(int)), + ) + edges: Dict[int, Edge] = field( + converter=_to_mutable_topology_edges, + factory=dict, + on_setattr=deep_mapping( + key_validator=instance_of(int), value_validator=instance_of(Edge) + ), + ) def freeze(self) -> Topology: return Topology( @@ -517,9 +535,7 @@ def build( seed_graph = MutableTopology() current_open_end_edges = list(range(number_of_initial_edges)) seed_graph.add_edges(current_open_end_edges) - extendable_graph_list: List[Tuple[MutableTopology, List[int]]] = [ - (seed_graph, current_open_end_edges) - ] + extendable_graph_list = [(seed_graph, current_open_end_edges)] while extendable_graph_list: active_graph_list = extendable_graph_list diff --git a/tests/unit/test_topology.py b/tests/unit/test_topology.py index 12ed84dd..ade35763 100644 --- a/tests/unit/test_topology.py +++ b/tests/unit/test_topology.py @@ -102,7 +102,7 @@ class TestMutableTopology: def test_add_and_attach(self, two_to_three_decay: Topology): topology = MutableTopology( edges=two_to_three_decay.edges, - nodes=two_to_three_decay.nodes, # type: ignore[arg-type] + nodes=two_to_three_decay.nodes, ) topology.add_node(3) topology.add_edges([7, 8]) @@ -118,7 +118,7 @@ def test_add_and_attach(self, two_to_three_decay: Topology): def test_add_exceptions(self, two_to_three_decay: Topology): topology = MutableTopology( edges=two_to_three_decay.edges, - nodes=two_to_three_decay.nodes, # type: ignore[arg-type] + nodes=two_to_three_decay.nodes, ) with pytest.raises(ValueError, match=r"Node nr. 0 already exists"): topology.add_node(0) From c6ccd6fdaa4da410162a640eac277928bddb4871 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 20:59:07 +0100 Subject: [PATCH 31/34] refactor: move organize_edge_ids to MutableTopology --- src/qrules/topology.py | 125 +++++++++++++++++------------- tests/unit/io/test_dot.py | 6 +- tests/unit/test_system_control.py | 10 +-- tests/unit/test_topology.py | 113 +++++++++++++++------------ 4 files changed, 143 insertions(+), 111 deletions(-) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index a01633ec..b110fbf9 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -195,31 +195,34 @@ class Topology: def __attrs_post_init__(self) -> None: self.__verify() - object.__setattr__( - self, - "incoming_edge_ids", - frozenset( - edge_id - for edge_id, edge in self.edges.items() - if edge.originating_node_id is None - ), - ) - object.__setattr__( - self, - "outgoing_edge_ids", - frozenset( - edge_id - for edge_id, edge in self.edges.items() - if edge.ending_node_id is None - ), + incoming = sorted( + edge_id + for edge_id, edge in self.edges.items() + if edge.originating_node_id is None ) - object.__setattr__( - self, - "intermediate_edge_ids", - frozenset(self.edges) - ^ self.incoming_edge_ids - ^ self.outgoing_edge_ids, + outgoing = sorted( + edge_id + for edge_id, edge in self.edges.items() + if edge.ending_node_id is None ) + inter = sorted(set(self.edges) - set(incoming) - set(outgoing)) + expected = list(range(-len(incoming), 0)) + if sorted(incoming) != expected: + raise ValueError( + f"Incoming edge IDs should be {expected}, not {incoming}." + ) + n_out = len(outgoing) + expected = list(range(0, n_out)) + if sorted(outgoing) != expected: + raise ValueError( + f"Outgoing edge IDs should be {expected}, not {outgoing}." + ) + expected = list(range(n_out, n_out + len(inter))) + if sorted(inter) != expected: + raise ValueError(f"Intermediate edge IDs should be {expected}.") + object.__setattr__(self, "incoming_edge_ids", frozenset(incoming)) + object.__setattr__(self, "outgoing_edge_ids", frozenset(outgoing)) + object.__setattr__(self, "intermediate_edge_ids", frozenset(inter)) def __verify(self) -> None: """Verify if there are no dangling edges or nodes.""" @@ -314,29 +317,7 @@ def get_originating_initial_state_edge_ids(self, node_id: int) -> Set[int]: temp_edge_list = new_temp_edge_list return edge_ids - def organize_edge_ids(self) -> "Topology": - """Create a new topology with edge IDs in range :code:`[-m, n+i]`. - - where :code:`m` is the number of `.incoming_edge_ids`, :code:`n` is the - number of `.outgoing_edge_ids`, and :code:`i` is the number of - `.intermediate_edge_ids`. - - In other words, relabel the edges so that: - - - `.incoming_edge_ids` lies in the range :code:`[-1, -2, ...]` - - `.outgoing_edge_ids` lies in the range :code:`[0, 1, ..., n]` - - `.intermediate_edge_ids` lies in the range :code:`[n+1, n+2, ...]` - """ - new_to_old_id = enumerate( - tuple(self.incoming_edge_ids) - + tuple(self.outgoing_edge_ids) - + tuple(self.intermediate_edge_ids), - start=-len(self.incoming_edge_ids), - ) - old_to_new_id = {j: i for i, j in new_to_old_id} - return self.relabel_edges(old_to_new_id) - - def relabel_edges(self, old_to_new_id: Mapping[int, int]) -> "Topology": + def relabel_edges(self, old_to_new: Mapping[int, int]) -> "Topology": """Create a new `Topology` with new edge IDs. This method is particularly useful when creating permutations of a @@ -354,8 +335,10 @@ def relabel_edges(self, old_to_new_id: Mapping[int, int]) -> "Topology": >>> len(permuted_topologies) 3 """ + new_to_old = {j: i for i, j in old_to_new.items()} new_edges = { - old_to_new_id.get(i, i): edge for i, edge in self.edges.items() + old_to_new.get(i, new_to_old.get(i, i)): edge + for i, edge in self.edges.items() } return attrs.evolve(self, edges=new_edges) @@ -405,12 +388,6 @@ class MutableTopology: ), ) - def freeze(self) -> Topology: - return Topology( - edges=self.edges, - nodes=self.nodes, - ) - def add_node(self, node_id: int) -> None: """Adds a node nr. node_id. @@ -482,6 +459,43 @@ def attach_edges_to_node_outgoing( originating_node_id=node_id, ) + def organize_edge_ids(self) -> "MutableTopology": + """Organize edge IDS so that they lie in range :code:`[-m, n+i]`. + + Here, :code:`m` is the number of `.incoming_edge_ids`, :code:`n` is the + number of `.outgoing_edge_ids`, and :code:`i` is the number of + `.intermediate_edge_ids`. + + In other words, relabel the edges so that: + + - incoming edge IDs lie in the range :code:`[-1, -2, ...]`, + - outgoing edge IDs lie in the range :code:`[0, 1, ..., n]`, + - intermediate edge IDs lie in the range :code:`[n+1, n+2, ...]`. + """ + incoming = { + i + for i, edge in self.edges.items() + if edge.originating_node_id is None + } + outgoing = { + edge_id + for edge_id, edge in self.edges.items() + if edge.ending_node_id is None + } + intermediate = set(self.edges) - incoming - outgoing + new_to_old_id = enumerate( + list(incoming) + list(outgoing) + list(intermediate), + start=-len(incoming), + ) + old_to_new_id = {j: i for i, j in new_to_old_id} + new_edges = { + old_to_new_id.get(i, i): edge for i, edge in self.edges.items() + } + return attrs.evolve(self, edges=new_edges) + + def freeze(self) -> Topology: + return Topology(self.nodes, self.edges) + @define class InteractionNode: @@ -546,7 +560,6 @@ def build( len(active_graph[1]) == number_of_final_edges and len(active_graph[0].nodes) > 0 ): - active_graph[0].freeze() # verify graph_tuple_list.append(active_graph) continue @@ -556,9 +569,9 @@ def build( # strip the current open end edges list from the result graph tuples topologies = [] for graph_tuple in graph_tuple_list: - topology = graph_tuple[0].freeze() + topology = graph_tuple[0] topology = topology.organize_edge_ids() - topologies.append(topology) + topologies.append(topology.freeze()) return tuple(topologies) def _extend_graph( diff --git a/tests/unit/io/test_dot.py b/tests/unit/io/test_dot.py index 040824e9..4d620fc2 100644 --- a/tests/unit/io/test_dot.py +++ b/tests/unit/io/test_dot.py @@ -163,7 +163,11 @@ def test_write_topology(self, output_dir): output_file = output_dir + "two_body_decay_topology.gv" topology = Topology( nodes={0}, - edges={0: Edge(0, None), 1: Edge(None, 0), 2: Edge(None, 0)}, + edges={ + -1: Edge(None, 0), + 0: Edge(0, None), + 1: Edge(0, None), + }, ) io.write( instance=topology, diff --git a/tests/unit/test_system_control.py b/tests/unit/test_system_control.py index eb03f612..741f69f2 100644 --- a/tests/unit/test_system_control.py +++ b/tests/unit/test_system_control.py @@ -210,7 +210,7 @@ def make_ls_test_graph( ): topology = Topology( nodes={0}, - edges={0: Edge(None, 0)}, + edges={-1: Edge(None, 0)}, ) interactions = { 0: InteractionProperties( @@ -218,7 +218,7 @@ def make_ls_test_graph( l_magnitude=angular_momentum_magnitude, ) } - states: Dict[int, ParticleWithSpin] = {0: (particle, 0)} + states: Dict[int, ParticleWithSpin] = {-1: (particle, 0)} graph = MutableTransition(topology, states, interactions) return graph @@ -228,7 +228,7 @@ def make_ls_test_graph_scrambled( ): topology = Topology( nodes={0}, - edges={0: Edge(None, 0)}, + edges={-1: Edge(None, 0)}, ) interactions = { 0: InteractionProperties( @@ -236,7 +236,7 @@ def make_ls_test_graph_scrambled( s_magnitude=coupled_spin_magnitude, ) } - states: Dict[int, ParticleWithSpin] = {0: (particle, 0)} + states: Dict[int, ParticleWithSpin] = {-1: (particle, 0)} graph = MutableTransition(topology, states, interactions) return graph @@ -326,7 +326,7 @@ def test_filter_graphs_for_interaction_qns( tempgraph = attrs.evolve( tempgraph, states={ - 0: ( + -1: ( Particle(name=value[0], pid=0, mass=1.0, spin=1.0), 0.0, ) diff --git a/tests/unit/test_topology.py b/tests/unit/test_topology.py index ade35763..c11e3bf1 100644 --- a/tests/unit/test_topology.py +++ b/tests/unit/test_topology.py @@ -27,20 +27,20 @@ def two_to_three_decay() -> Topology: .. code-block:: - e0 -- (N0) -- e2 -- (N1) -- e3 -- (N2) -- e6 + e-1 -- (N0) -- e3 -- (N1) -- e4 -- (N2) -- e2 / \ \ - e1 e4 e5 + e-2 e0 e1 """ topology = Topology( nodes={0, 1, 2}, edges={ - 0: Edge(None, 0), - 1: Edge(None, 0), - 2: Edge(0, 1), - 3: Edge(1, 2), - 4: Edge(1, None), - 5: Edge(2, None), - 6: Edge(2, None), + -2: Edge(None, 0), + -1: Edge(None, 0), + 0: Edge(1, None), + 1: Edge(2, None), + 2: Edge(2, None), + 3: Edge(0, 1), + 4: Edge(1, 2), }, ) return topology @@ -105,15 +105,15 @@ def test_add_and_attach(self, two_to_three_decay: Topology): nodes=two_to_three_decay.nodes, ) topology.add_node(3) - topology.add_edges([7, 8]) - topology.attach_edges_to_node_outgoing([7, 8], 3) + topology.add_edges([5, 6]) + topology.attach_edges_to_node_outgoing([5, 6], 3) with pytest.raises( ValueError, match=r"Node 3 is not connected to any other node", ): topology.freeze() - topology.attach_edges_to_node_ingoing([6], 3) - assert isinstance(topology.freeze(), Topology) + topology.attach_edges_to_node_ingoing([2], 3) + assert isinstance(topology.organize_edge_ids().freeze(), Topology) def test_add_exceptions(self, two_to_three_decay: Topology): topology = MutableTopology( @@ -125,17 +125,34 @@ def test_add_exceptions(self, two_to_three_decay: Topology): with pytest.raises(ValueError, match=r"Edge nr. 0 already exists"): topology.add_edges([0]) with pytest.raises( - ValueError, match=r"Edge nr. 0 is already ingoing to node 0" + ValueError, match=r"Edge nr. -2 is already ingoing to node 0" ): - topology.attach_edges_to_node_ingoing([0], 0) + topology.attach_edges_to_node_ingoing([-2], 0) with pytest.raises(ValueError, match=r"Edge nr. 7 does not exist"): topology.attach_edges_to_node_ingoing([7], 2) with pytest.raises( - ValueError, match=r"Edge nr. 6 is already outgoing from node 2" + ValueError, match=r"Edge nr. 2 is already outgoing from node 2" ): - topology.attach_edges_to_node_outgoing([6], 2) - with pytest.raises(ValueError, match=r"Edge nr. 7 does not exist"): - topology.attach_edges_to_node_outgoing([7], 2) + topology.attach_edges_to_node_outgoing([2], 2) + with pytest.raises(ValueError, match=r"Edge nr. 5 does not exist"): + topology.attach_edges_to_node_outgoing([5], 2) + + def test_organize_edge_ids(self): + topology = MutableTopology( + nodes={0, 1, 2}, + edges={ + 0: Edge(None, 0), + 1: Edge(None, 0), + 2: Edge(1, None), + 3: Edge(2, None), + 4: Edge(2, None), + 5: Edge(0, 1), + 6: Edge(1, 2), + }, + ) + assert sorted(topology.edges) == [0, 1, 2, 3, 4, 5, 6] + topology = topology.organize_edge_ids() + assert sorted(topology.edges) == [-2, -1, 0, 1, 2, 3, 4] class TestSimpleStateTransitionTopologyBuilder: @@ -156,22 +173,22 @@ class TestTopology: ( {0, 1}, { - 0: Edge(None, 0), - 1: Edge(0, 1), - 2: Edge(1, None), - 3: Edge(1, None), + -1: Edge(None, 0), + 0: Edge(1, None), + 1: Edge(1, None), + 2: Edge(0, 1), }, ), ( {0, 1, 2}, { - 0: Edge(None, 0), - 1: Edge(0, 1), - 2: Edge(0, 2), - 3: Edge(1, None), - 4: Edge(1, None), - 5: Edge(2, None), - 6: Edge(2, None), + -1: Edge(None, 0), + 0: Edge(1, None), + 1: Edge(1, None), + 2: Edge(2, None), + 3: Edge(2, None), + 4: Edge(0, 1), + 5: Edge(0, 2), }, ), ], @@ -212,11 +229,11 @@ def test_repr_and_eq(self, repr_method, two_to_three_decay: Topology): def test_getters(self, two_to_three_decay: Topology): topology = two_to_three_decay # shorter name - assert get_originating_node_list(topology, edge_ids=[0]) == [] - assert get_originating_node_list(topology, edge_ids=[5, 6]) == [2, 2] - assert topology.incoming_edge_ids == {0, 1} - assert topology.outgoing_edge_ids == {4, 5, 6} - assert topology.intermediate_edge_ids == {2, 3} + assert topology.incoming_edge_ids == {-2, -1} + assert topology.outgoing_edge_ids == {0, 1, 2} + assert topology.intermediate_edge_ids == {3, 4} + assert get_originating_node_list(topology, edge_ids=[-1]) == [] + assert get_originating_node_list(topology, edge_ids=[1, 2]) == [2, 2] @typing.no_type_check def test_immutability(self, two_to_three_decay: Topology): @@ -234,26 +251,24 @@ def test_immutability(self, two_to_three_decay: Topology): node += 666 assert two_to_three_decay.nodes == {0, 1, 2} - def test_organize_edge_ids(self, two_to_three_decay: Topology): - topology = two_to_three_decay.organize_edge_ids() - assert topology.incoming_edge_ids == frozenset({-1, -2}) - assert topology.outgoing_edge_ids == frozenset({0, 1, 2}) - assert topology.intermediate_edge_ids == frozenset({3, 4}) - def test_relabel_edges(self, two_to_three_decay: Topology): - assert set(two_to_three_decay.edges) == {0, 1, 2, 3, 4, 5, 6} - relabeled_topology = two_to_three_decay.relabel_edges({0: -2, 1: -1}) - assert set(relabeled_topology.edges) == {-2, -1, 2, 3, 4, 5, 6} - relabeled_topology = relabeled_topology.relabel_edges({2: 0, 3: 1}) - assert set(relabeled_topology.edges) == {-2, -1, 0, 1, 4, 5, 6} + edge_ids = set(two_to_three_decay.edges) + relabeled_topology = two_to_three_decay.relabel_edges({0: 1, 1: 0}) + assert set(relabeled_topology.edges) == edge_ids + relabeled_topology = relabeled_topology.relabel_edges( + {3: 4, 4: 3, 1: 2, 2: 1} + ) + assert set(relabeled_topology.edges) == edge_ids + relabeled_topology = relabeled_topology.relabel_edges({3: 4}) + assert set(relabeled_topology.edges) == edge_ids def test_swap_edges(self, two_to_three_decay: Topology): original_topology = two_to_three_decay - topology = original_topology.swap_edges(0, 1) + topology = original_topology.swap_edges(-2, -1) assert topology == original_topology - topology = topology.swap_edges(5, 6) + topology = topology.swap_edges(1, 2) assert topology == original_topology - topology = topology.swap_edges(4, 6) + topology = topology.swap_edges(0, 1) assert topology != original_topology @pytest.mark.parametrize( From 38d7750466c2126f7a23298d6f0aacdc20acd6d7 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 22:45:37 +0100 Subject: [PATCH 32/34] refactor: sort output of create_isobar_topologies --- src/qrules/topology.py | 2 +- tests/unit/test_topology.py | 17 +++-------------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/src/qrules/topology.py b/src/qrules/topology.py index b110fbf9..457d33c1 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -624,7 +624,7 @@ def create_isobar_topologies( number_of_initial_edges=1, number_of_final_edges=number_of_final_states, ) - return tuple(topologies) + return tuple(sorted(topologies)) def create_n_body_topology( diff --git a/tests/unit/test_topology.py b/tests/unit/test_topology.py index c11e3bf1..0acaac5f 100644 --- a/tests/unit/test_topology.py +++ b/tests/unit/test_topology.py @@ -271,21 +271,10 @@ def test_swap_edges(self, two_to_three_decay: Topology): topology = topology.swap_edges(0, 1) assert topology != original_topology - @pytest.mark.parametrize( - ("n_final_states", "expected_order"), - [ - (2, [0]), - (3, [0]), - (4, [1, 0]), - (5, [3, 2, 4, 1, 0]), - (6, [13, 9, 7, 10, 6, 5, 14, 11, 8, 15, 3, 2, 12, 4, 1, 0]), - ], - ) - def test_unique_ordering(self, n_final_states, expected_order): + @pytest.mark.parametrize("n_final_states", [2, 3, 4, 5, 6]) + def test_unique_ordering(self, n_final_states): topologies = create_isobar_topologies(n_final_states) - sorted_topologies = sorted(topologies) - order = [sorted_topologies.index(t) for t in topologies] - assert order == expected_order + assert sorted(topologies) == list(topologies) @pytest.mark.parametrize( From faecd1cb2bda3c4e434f3a1801c9266951f5d5d1 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 21:46:20 +0100 Subject: [PATCH 33/34] ci: ignore tmp files in sphinx-autobuild --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index f0367e36..79138b9b 100644 --- a/tox.ini +++ b/tox.ini @@ -54,6 +54,7 @@ commands = --watch src \ --re-ignore .*/.ipynb_checkpoints/.* \ --re-ignore .*/__pycache__/.* \ + --re-ignore .*\.tmp \ --re-ignore docs/.*\.csv \ --re-ignore docs/.*\.gv \ --re-ignore docs/.*\.inv \ From 9d8e0b2296948adfc8006c593d5a2f24135564b4 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 21 Feb 2022 15:27:34 +0100 Subject: [PATCH 34/34] docs: improve docstrings of topology module --- .flake8 | 2 + docs/.gitignore | 1 + docs/_extend_docstrings.py | 130 ++++++++++++++++++++++++ docs/conf.py | 3 + src/qrules/topology.py | 201 +++++++++++++++++++++++++++++++++---- tox.ini | 1 + 6 files changed, 318 insertions(+), 20 deletions(-) create mode 100644 docs/_extend_docstrings.py diff --git a/.flake8 b/.flake8 index 051d7edf..68479759 100644 --- a/.flake8 +++ b/.flake8 @@ -47,6 +47,8 @@ rst-roles = mod ref rst-directives = + autolink-preface + automethod deprecated envvar exception diff --git a/docs/.gitignore b/docs/.gitignore index 7bfee55a..0cd13342 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -1,6 +1,7 @@ *.doctree *.inv *build/ +_images/* api/ !_static/* diff --git a/docs/_extend_docstrings.py b/docs/_extend_docstrings.py new file mode 100644 index 00000000..bbcc9bb4 --- /dev/null +++ b/docs/_extend_docstrings.py @@ -0,0 +1,130 @@ +# flake8: noqa +# pylint: disable=import-error,import-outside-toplevel,invalid-name,protected-access +# pyright: reportMissingImports=false +"""Extend docstrings of the API. + +This small script is used by ``conf.py`` to dynamically modify docstrings. +""" + +import inspect +import logging +import textwrap +from typing import Callable, Dict, Optional, Type, Union + +import qrules + +logging.getLogger().setLevel(logging.ERROR) + + +def extend_docstrings() -> None: + script_name = __file__.rsplit("/", maxsplit=1)[-1] + script_name = ".".join(script_name.split(".")[:-1]) + definitions = dict(globals()) + for name, definition in definitions.items(): + module = inspect.getmodule(definition) + if module is None: + continue + if module.__name__ not in {"__main__", script_name}: + continue + if not inspect.isfunction(definition): + continue + if not name.startswith("extend_"): + continue + if name == "extend_docstrings": + continue + function_arguments = inspect.signature(definition).parameters + if len(function_arguments): + raise ValueError( + f"Local function {name} should not have a signature" + ) + definition() + + +def extend_create_isobar_topologies() -> None: + from qrules.topology import create_isobar_topologies + + topologies = qrules.topology.create_isobar_topologies(4) + dot_renderings = map( + lambda t: qrules.io.asdot(t, render_resonance_id=True), + topologies, + ) + images = [_graphviz_to_image(dot, indent=6) for dot in dot_renderings] + _append_to_docstring( + create_isobar_topologies, + f""" + + .. panels:: + :body: text-center + {images[0]} + + --- + {images[1]} + """, + ) + + +def extend_create_n_body_topology() -> None: + from qrules.topology import create_n_body_topology + + topology = create_n_body_topology( + number_of_initial_states=2, + number_of_final_states=5, + ) + dot = qrules.io.asdot(topology, render_initial_state_id=True) + _append_to_docstring( + create_n_body_topology, + _graphviz_to_image(dot, indent=4), + ) + + +def extend_Topology() -> None: + from qrules.topology import Topology, create_isobar_topologies + + topologies = create_isobar_topologies(number_of_final_states=3) + dot = qrules.io.asdot( + topologies[0], + render_initial_state_id=True, + render_resonance_id=True, + ) + _append_to_docstring( + Topology, + _graphviz_to_image(dot, indent=4), + ) + + +def _append_to_docstring( + class_type: Union[Callable, Type], appended_text: str +) -> None: + assert class_type.__doc__ is not None + class_type.__doc__ += appended_text + + +_GRAPHVIZ_COUNTER = 0 +_IMAGE_DIR = "_images" + + +def _graphviz_to_image( # pylint: disable=too-many-arguments + dot: str, + options: Optional[Dict[str, str]] = None, + format: str = "svg", + indent: int = 0, + caption: str = "", + label: str = "", +) -> str: + import graphviz # type: ignore[import] + + if options is None: + options = {} + global _GRAPHVIZ_COUNTER # pylint: disable=global-statement + output_file = f"graphviz_{_GRAPHVIZ_COUNTER}" + _GRAPHVIZ_COUNTER += 1 + graphviz.Source(dot).render(f"{_IMAGE_DIR}/{output_file}", format=format) + restructuredtext = "\n" + if label: + restructuredtext += f".. _{label}:\n" + restructuredtext += f".. figure:: /{_IMAGE_DIR}/{output_file}.{format}\n" + for option, value in options.items(): + restructuredtext += f" :{option}: {value}\n" + if caption: + restructuredtext += f"\n {caption}\n" + return textwrap.indent(restructuredtext, indent * " ") diff --git a/docs/conf.py b/docs/conf.py index 9ad8260b..39b3a95f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -73,9 +73,12 @@ def fetch_logo(url: str, output_path: str) -> None: # -- Generate API ------------------------------------------------------------ sys.path.insert(0, os.path.abspath(".")) +from _extend_docstrings import extend_docstrings # noqa: E402 from _relink_references import relink_references # noqa: E402 +extend_docstrings() relink_references() + shutil.rmtree("api", ignore_errors=True) subprocess.call( " ".join( diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 457d33c1..18fe4c65 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -1,11 +1,19 @@ -"""All modules related to topology building. +# pylint: disable=too-many-lines +"""Functionality for `Topology` and `Transition` instances. -Responsible for building all possible topologies bases on basic user input: +.. rubric:: Main interfaces -- number of initial state particles -- number of final state particles +- `Topology` and its builder functions :func:`create_isobar_topologies` and + :func:`create_n_body_topology`. +- `Transition` and its two implementations `MutableTransition` and + `FrozenTransition`. -The main interface is the `.MutableTransition`. +.. autolink-preface:: + + from qrules.topology import ( + create_isobar_topologies, + create_n_body_topology, + ) """ import copy @@ -69,6 +77,19 @@ def __lt__(self, other: Any) -> bool: class FrozenDict( # pylint: disable=too-many-ancestors abc.Hashable, abc.Mapping, Generic[KT, VT] ): + """An **immutable** and **hashable** version of a `dict`. + + `FrozenDict` makes it possible to make classes hashable if they are + decorated with :func:`attr.frozen` and contain `~typing.Mapping`-like + attributes. If these attributes were to be implemented with a normal + `dict`, the instance is strictly speaking still mutable (even if those + attributes are a `property`) and the class is therefore not safely + hashable. + + .. warning:: The keys have to be comparable, that is, they need to have a + :meth:`~object.__lt__` method. + """ + def __init__(self, mapping: Optional[Mapping] = None): self.__mapping: Dict[KT, VT] = {} if mapping is not None: @@ -143,16 +164,27 @@ def _to_optional_int(optional_int: Optional[int]) -> Optional[int]: @frozen(order=True) class Edge: - """Struct-like definition of an edge, used in `Topology`.""" + """Struct-like definition of an edge, used in `Topology.edges`.""" originating_node_id: Optional[int] = field( default=None, converter=_to_optional_int ) + """Node ID where the `Edge` **starts**. + + An `Edge` is **incoming to** a `Topology` if its `originating_node_id` is + `None` (see `~Topology.incoming_edge_ids`). + """ ending_node_id: Optional[int] = field( default=None, converter=_to_optional_int ) + """Node ID where the `Edge` **ends**. + + An `Edge` is **outgoing from** a `Topology` if its `ending_node_id` is + `None` (see `~Topology.outgoing_edge_ids`). + """ def get_connected_nodes(self) -> Set[int]: + """Get all node IDs to which the `Edge` is connected.""" connected_nodes = {self.ending_node_id, self.originating_node_id} connected_nodes.discard(None) return connected_nodes # type: ignore[return-value] @@ -169,29 +201,65 @@ def _to_topology_edges(inst: Mapping[int, Edge]) -> FrozenDict[int, Edge]: @implement_pretty_repr @frozen(order=True) class Topology: + # noqa: D416 """Directed Feynman-like graph without edge or node properties. - Forms the underlying topology of `MutableTransition`. The graphs are - directed, meaning the edges are ingoing and outgoing to specific nodes - (since feynman graphs also have a time axis). Note that a `Topology` is not - strictly speaking a graph from graph theory, because it allows open edges, - like a Feynman-diagram. + A `Topology` is **directed** in the sense that its edges are ingoing and + outgoing to specific nodes. This is to mimic Feynman graphs, which have a + time axis. Note that a `Topology` is not strictly speaking a graph from + graph theory, because it allows open edges, like a Feynman-diagram. + + The edges and nodes can be provided with properties with a `Transition`, + which contains a `~Transition.topology`. + + As opposed to a `MutableTopology`, a `Topology` is frozen, hashable, and + ordered, so that it can be used as a kind of fingerprint for a + `Transition`. In addition, the IDs of `edges` are guaranteed to be + sequential integers and follow a specific pattern: + + - `incoming_edge_ids` (`~Transition.initial_states`) are always negative. + - `outgoing_edge_ids` (`~Transition.final_states`) lie in the range + :code:`0...n-1` with :code:`n` the number of final states. + - `intermediate_edge_ids` continue counting from :code:`n`. + + See also :meth:`MutableTopology.organize_edge_ids`. + + Example + ------- + **Isobar decay** topologies can best be created as follows: + + >>> topologies = create_isobar_topologies(number_of_final_states=3) + >>> len(topologies) + 1 + >>> topologies[0] + Topology(nodes=..., edges=...) """ nodes: FrozenSet[int] = field( converter=_to_topology_nodes, validator=deep_iterable(member_validator=instance_of(int)), ) + """A node is a point where different `edges` connect.""" edges: FrozenDict[int, Edge] = field( converter=_to_topology_edges, validator=deep_mapping( key_validator=instance_of(int), value_validator=instance_of(Edge) ), ) + """Mapping of edge IDs to their corresponding `Edge` definition.""" incoming_edge_ids: FrozenSet[int] = field(init=False, repr=False) + """Edge IDs of edges that have no `~Edge.originating_node_id`. + + `Transition.initial_states` provide properties for these edges. + """ outgoing_edge_ids: FrozenSet[int] = field(init=False, repr=False) + """Edge IDs of edges that have no `~Edge.ending_node_id`. + + `Transition.final_states` provide properties for these edges. + """ intermediate_edge_ids: FrozenSet[int] = field(init=False, repr=False) + """Edge IDs of edges that connect two `nodes`.""" def __attrs_post_init__(self) -> None: self.__verify() @@ -264,6 +332,8 @@ def is_isomorphic(self, other: "Topology") -> bool: Returns `True` if the two graphs have a one-to-one mapping of the node IDs and edge IDs. + + .. warning:: Not yet implemented. """ raise NotImplementedError @@ -375,11 +445,19 @@ def _to_mutable_topology_edges(inst: Mapping[int, Edge]) -> Dict[int, Edge]: @define class MutableTopology: + """Mutable version of a `Topology`. + + A `MutableTopology` can be used to conveniently build up a `Topology` (see + e.g. `SimpleStateTransitionTopologyBuilder`). It does not have restrictions + on the numbering of edge and node IDs. + """ + nodes: Set[int] = field( converter=_to_mutable_topology_nodes, factory=set, on_setattr=deep_iterable(member_validator=instance_of(int)), ) + """See `Topology.nodes`.""" edges: Dict[int, Edge] = field( converter=_to_mutable_topology_edges, factory=dict, @@ -387,19 +465,24 @@ class MutableTopology: key_validator=instance_of(int), value_validator=instance_of(Edge) ), ) + """See `Topology.edges`.""" def add_node(self, node_id: int) -> None: - """Adds a node nr. node_id. + """Adds a node with number :code:`node_id`. Raises: - ValueError: if node_id already exists + ValueError: if :code:`node_id` already exists in `nodes`. """ if node_id in self.nodes: raise ValueError(f"Node nr. {node_id} already exists") self.nodes.add(node_id) - def add_edges(self, edge_ids: List[int]) -> None: - """Add edges with the ids in the edge_ids list.""" + def add_edges(self, edge_ids: Iterable[int]) -> None: + """Add edges with the ids in the :code:`edge_ids` list. + + Raises: + ValueError: if :code:`edge_ids` already exist in `edges`. + """ for edge_id in edge_ids: if edge_id in self.edges: raise ValueError(f"Edge nr. {edge_id} already exists") @@ -494,6 +577,10 @@ def organize_edge_ids(self) -> "MutableTopology": return attrs.evolve(self, edges=new_edges) def freeze(self) -> Topology: + """Create an immutable `Topology` from this `MutableTopology`. + + You may need to call :meth:`organize_edge_ids` first. + """ return Topology(self.nodes, self.edges) @@ -615,6 +702,27 @@ def _extend_graph( def create_isobar_topologies( number_of_final_states: int, ) -> Tuple[Topology, ...]: + """Builder function to create a set of unique isobar decay topologies. + + Args: + number_of_final_states: The number of `~Topology.outgoing_edge_ids` + (`~.Transition.final_states`). + + Returns: + A sorted `tuple` of non-isomorphic `Topology` instances, all with the + same number of final states. + + Example: + >>> topologies = create_isobar_topologies(number_of_final_states=4) + >>> len(topologies) + 2 + >>> len(topologies[0].outgoing_edge_ids) + 4 + >>> len(set(topologies)) # hashable + 2 + >>> list(topologies) == sorted(topologies) # ordered + True + """ if number_of_final_states < 2: raise ValueError( "At least two final states required for an isobar decay" @@ -630,6 +738,31 @@ def create_isobar_topologies( def create_n_body_topology( number_of_initial_states: int, number_of_final_states: int ) -> Topology: + """Create a `Topology` that connects all edges through a single node. + + These types of ":math:`n`-body topologies" are particularly important for + :func:`.check_reaction_violations` and :mod:`.conservation_rules`. + + Args: + number_of_initial_states: The number of `~Topology.incoming_edge_ids` + (`~.Transition.initial_states`). + number_of_final_states: The number of `~Topology.outgoing_edge_ids` + (`~.Transition.final_states`). + + Example: + >>> topology = create_n_body_topology( + ... number_of_initial_states=2, + ... number_of_final_states=5, + ... ) + >>> topology + Topology(nodes=..., edges...) + >>> len(topology.nodes) + 1 + >>> len(topology.incoming_edge_ids) + 2 + >>> len(topology.outgoing_edge_ids) + 5 + """ n_in = number_of_initial_states n_out = number_of_final_states builder = SimpleStateTransitionTopologyBuilder( @@ -694,34 +827,62 @@ def _attach_node_to_edges( class Transition(ABC, Generic[EdgeType, NodeType]): + """Mapping of edge and node properties over a `.Topology`. + + This **interface** class describes a transition from an initial state to a + final state by providing a mapping of properties over the `~Topology.edges` + and `~Topology.nodes` of its `topology`. Since a `Topology` behaves like a + Feynman graph, **edges** are considered as "`states`" and **nodes** are + considered as `interactions` between those states. + + There are two implementation classes: + + - `FrozenTransition`: a complete, hashable and ordered mapping of + properties over the `~Topology.edges` and `~Topology.nodes` in its + `~FrozenTransition.topology`. + - `MutableTransition`: comparable to `MutableTopology` in that it is used + internally when finding solutions through the `.StateTransitionManager` + etc. + + These classes are also provided with **mixin** attributes `initial_states`, + `final_states`, `intermediate_states`, and :meth:`filter_states`. + """ + @property @abstractmethod def topology(self) -> Topology: + """`Topology` over which `states` and `interactions` are defined.""" ... @property @abstractmethod def states(self) -> Mapping[int, EdgeType]: + """Mapping of properties over its `topology` `~Topology.edges`.""" ... @property @abstractmethod def interactions(self) -> Mapping[int, NodeType]: + """Mapping of properties over its `topology` `~Topology.nodes`.""" ... @property def initial_states(self) -> Dict[int, EdgeType]: + """Properties for the `~Topology.incoming_edge_ids`.""" return self.filter_states(self.topology.incoming_edge_ids) @property def final_states(self) -> Dict[int, EdgeType]: + """Properties for the `~Topology.outgoing_edge_ids`.""" return self.filter_states(self.topology.outgoing_edge_ids) @property def intermediate_states(self) -> Dict[int, EdgeType]: + """Properties for the intermediate edges (connecting two nodes).""" return self.filter_states(self.topology.intermediate_edge_ids) def filter_states(self, edge_ids: Iterable[int]) -> Dict[int, EdgeType]: + """Filter `states` by a selection of :code:`edge_ids`.""" return {i: self.states[i] for i in edge_ids} @@ -739,6 +900,7 @@ def __attrs_post_init__(self) -> None: _assert_all_defined(self.topology.edges, self.states) def unfreeze(self) -> "MutableTransition[EdgeType, NodeType]": + """Convert into a `MutableTransition`.""" return MutableTransition(self.topology, self.states, self.interactions) @overload @@ -766,6 +928,7 @@ def convert( ... def convert(self, state_converter=None, interaction_converter=None): # type: ignore[no-untyped-def] + """Cast the edge and/or node properties to another type.""" # pylint: disable=unnecessary-lambda if state_converter is None: state_converter = lambda _: _ @@ -794,12 +957,9 @@ def _cast_interactions(obj: Mapping[int, NodeType]) -> Dict[int, NodeType]: @implement_pretty_repr @define class MutableTransition(Transition, Generic[EdgeType, NodeType]): - """Graph class that resembles a frozen `.Topology` with properties. + """Mutable implementation of a `Transition`. - This class should contain the full information of a state transition from a - initial state to a final state. This information can be attached to the - nodes and edges via properties. In case not all information is provided, - error can be raised on property retrieval. + Mainly used internally by the `.StateTransitionManager` to build solutions. """ topology: Topology = field(validator=instance_of(Topology)) @@ -846,6 +1006,7 @@ def swap_edges(self, edge_id1: int, edge_id2: int) -> None: self.states[edge_id1] = value2 def freeze(self) -> "FrozenTransition[EdgeType, NodeType]": + """Convert into a `FrozenTransition`.""" return FrozenTransition(self.topology, self.states, self.interactions) diff --git a/tox.ini b/tox.ini index 79138b9b..c9b53d02 100644 --- a/tox.ini +++ b/tox.ini @@ -63,6 +63,7 @@ commands = --re-ignore docs/.*\.yaml \ --re-ignore docs/.*\.yml \ --re-ignore docs/_build/.* \ + --re-ignore docs/_images/.* \ --re-ignore docs/_static/logo\..* \ --re-ignore docs/api/.* \ --open-browser \