diff --git a/src/qrules/topology.py b/src/qrules/topology.py index 01ee9577..71a37e04 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -69,7 +69,7 @@ def __lt__(self, other: Any) -> bool: @total_ordering class FrozenDict( # pylint: disable=too-many-ancestors - Generic[KeyType, ValueType], abc.Hashable, abc.Mapping + abc.Hashable, abc.Mapping, Generic[KeyType, ValueType] ): def __init__(self, mapping: Optional[Mapping] = None): self.__mapping: Dict[KeyType, ValueType] = {} diff --git a/src/qrules/transition.py b/src/qrules/transition.py index c2540ebf..f334216d 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -643,6 +643,10 @@ def find_solutions( # pylint: disable=too-many-branches raise ValueError("No solutions were found") match_external_edges(final_solutions) + final_solutions = [ + _match_final_state_ids(graph, self.final_state) + for graph in final_solutions + ] return ReactionInfo.from_graphs(final_solutions, self.formalism) def _solve( @@ -699,6 +703,38 @@ def _safe_wrap_list( ) +def _match_final_state_ids( + graph: StateTransitionGraph[ParticleWithSpin], + state_definition: Sequence[StateDefinition], +) -> StateTransitionGraph[ParticleWithSpin]: + """Temporary fix to https://github.com/ComPWA/qrules/issues/143.""" + particle_names = _strip_spin(state_definition) + name_to_id = {name: i for i, name in enumerate(particle_names)} + id_remapping = { + name_to_id[graph.get_edge_props(i)[0].name]: i + for i in graph.topology.outgoing_edge_ids + } + new_topology = graph.topology.relabel_edges(id_remapping) + return StateTransitionGraph( + new_topology, + edge_props={ + i: graph.get_edge_props(id_remapping.get(i, i)) + for i in graph.topology.edges + }, + node_props={i: graph.get_node_props(i) for i in graph.topology.nodes}, + ) + + +def _strip_spin(state_definition: Sequence[StateDefinition]) -> List[str]: + particle_names = [] + for state in state_definition: + if isinstance(state, str): + particle_names.append(state) + else: + particle_names.append(state[0]) + return particle_names + + @implement_pretty_repr() @attr.frozen(order=True) class State: diff --git a/tests/unit/test_qrules.py b/tests/unit/test_qrules.py new file mode 100644 index 00000000..cdcb8af6 --- /dev/null +++ b/tests/unit/test_qrules.py @@ -0,0 +1,33 @@ +import pytest + +from qrules import generate_transitions + + +@pytest.mark.parametrize( + "resonance_names", + [ + ["Sigma(1660)~-"], + ["N(1650)+"], + ["K*(1680)~0"], + ["Sigma(1660)~-", "N(1650)+"], + ["Sigma(1660)~-", "K*(1680)~0"], + ["N(1650)+", "K*(1680)~0"], + ["Sigma(1660)~-", "N(1650)+", "K*(1680)~0"], + ], +) +def test_generate_transitions(resonance_names): + final_state_names = ["K0", "Sigma+", "p~"] + reaction = generate_transitions( + initial_state="J/psi(1S)", + final_state=final_state_names, + allowed_intermediate_particles=resonance_names, + allowed_interaction_types="strong", + ) + assert len(reaction.transition_groups) == len(resonance_names) + final_state = dict(enumerate(final_state_names)) + for transition in reaction.transitions: + this_final_state = { + i: state.particle.name + for i, state in transition.final_states.items() + } + assert final_state == this_final_state