diff --git a/docs/conf.py b/docs/conf.py index 13ac2193..6026ef84 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -47,6 +47,7 @@ def pick_newtype_attrs(some_type: type) -> list: add_module_names = False api_github_repo = f"{ORGANIZATION}/{REPO_NAME}" api_target_substitutions: dict[str, str | tuple[str, str]] = { + "EdgeQuantumNumberTypes": ("obj", "qrules.quantum_numbers.EdgeQuantumNumberTypes"), "EdgeType": "typing.TypeVar", "GraphEdgePropertyMap": ("obj", "qrules.argument_handling.GraphEdgePropertyMap"), "GraphElementProperties": ("obj", "qrules.solving.GraphElementProperties"), @@ -56,11 +57,13 @@ def pick_newtype_attrs(some_type: type) -> list: "NewEdgeType": "typing.TypeVar", "NewNodeType": "typing.TypeVar", "NodeQuantumNumber": ("obj", "qrules.quantum_numbers.NodeQuantumNumber"), + "NodeQuantumNumberTypes": ("obj", "qrules.quantum_numbers.NodeQuantumNumberTypes"), "NodeType": "typing.TypeVar", "ParticleWithSpin": ("obj", "qrules.particle.ParticleWithSpin"), "Path": "pathlib.Path", "qrules.topology.EdgeType": "typing.TypeVar", "qrules.topology.NodeType": "typing.TypeVar", + "Rule": ("obj", "qrules.argument_handling.Rule"), "SpinFormalism": ("obj", "qrules.transition.SpinFormalism"), "StateDefinition": ("obj", "qrules.combinatorics.StateDefinition"), "StateTransition": ("obj", "qrules.transition.StateTransition"), diff --git a/docs/usage/visualize.ipynb b/docs/usage/visualize.ipynb index 126b4645..d67a3255 100644 --- a/docs/usage/visualize.ipynb +++ b/docs/usage/visualize.ipynb @@ -50,6 +50,16 @@ "```" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ":::{warning}\n", + "Currently the main user-interface is the ```StateTransitionManager```. There is work in progress to remove it and split its functionality into several functions/classes to separate concerns\n", + "and to facilitate the modification of intermediate results like the filtering of ```QNProblemSet```s, setting allowed interaction types, etc. (see below)\n", + ":::" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -103,7 +113,18 @@ "from IPython.display import display\n", "\n", "import qrules\n", + "from qrules.conservation_rules import (\n", + " parity_conservation,\n", + " spin_magnitude_conservation,\n", + " spin_validity,\n", + ")\n", "from qrules.particle import Spin\n", + "from qrules.quantum_numbers import EdgeQuantumNumbers, NodeQuantumNumbers\n", + "from qrules.solving import (\n", + " CSPSolver,\n", + " dict_set_intersection,\n", + " filter_quantum_number_problem_set,\n", + ")\n", "from qrules.topology import create_isobar_topologies, create_n_body_topology\n", "from qrules.transition import State" ] @@ -315,6 +336,102 @@ "graphviz.Source(dot)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Filtering quantum number problem sets" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Sometimes, only a certain subset of quantum numbers and conservation rules are relevant, or the number of solutions the {class}`.StateTransitionManager` gives by default is too large for the follow-up analysis.\n", + "The {func}`.filter_quantum_number_problem_set` function can be used to produce a {class}`.QNProblemSet` where only the desired quantum numbers and conservation rules are considered when fed back to the solver." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "desired_edge_properties = {EdgeQuantumNumbers.spin_magnitude, EdgeQuantumNumbers.parity}\n", + "desired_node_properties = {\n", + " NodeQuantumNumbers.l_magnitude,\n", + " NodeQuantumNumbers.s_magnitude,\n", + "} # has to be reused in the CSPSolver-constructor\n", + "filtered_qn_problem_set = filter_quantum_number_problem_set(\n", + " qn_problem_set,\n", + " edge_rules={spin_validity},\n", + " node_rules={spin_magnitude_conservation, parity_conservation},\n", + " edge_properties=desired_edge_properties,\n", + " node_properties=desired_node_properties,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "dot = qrules.io.asdot(filtered_qn_problem_set, render_node=True)\n", + "graphviz.Source(dot)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ":::{warning}\n", + "The next cell will use some (currently) internal functionality. As statet at the top, a workflow similar to this will be used in future versions of ```qrules```. Manual setup of the {obj}`.CSPSolver` like in here will then also not be necessary.\n", + ":::" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "solver = CSPSolver([\n", + " dict_set_intersection(\n", + " qrules.system_control.create_edge_properties(part),\n", + " desired_edge_properties,\n", + " )\n", + " for part in qrules.particle.load_pdg()\n", + "])\n", + "\n", + "filtered_qn_solutions = solver.find_solutions(filtered_qn_problem_set)\n", + "filtered_qn_result = filtered_qn_solutions.solutions[6]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "dot = qrules.io.asdot(filtered_qn_result, render_node=True)\n", + "graphviz.Source(dot)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -672,7 +789,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.9.20" } }, "nbformat": 4, diff --git a/src/qrules/argument_handling.py b/src/qrules/argument_handling.py index 95a19176..471c3ab1 100644 --- a/src/qrules/argument_handling.py +++ b/src/qrules/argument_handling.py @@ -25,6 +25,7 @@ Scalar = Union[int, float] Rule = Union[GraphElementRule, EdgeQNConservationRule, ConservationRule] +"""Any type of rule""" _ElementType = TypeVar("_ElementType") diff --git a/src/qrules/quantum_numbers.py b/src/qrules/quantum_numbers.py index a64a65af..bb1311db 100644 --- a/src/qrules/quantum_numbers.py +++ b/src/qrules/quantum_numbers.py @@ -104,7 +104,6 @@ class EdgeQuantumNumbers: edge_qn_type.__module__ = __name__ -# for static typing EdgeQuantumNumber = Union[ EdgeQuantumNumbers.pid, EdgeQuantumNumbers.mass, @@ -126,8 +125,8 @@ class EdgeQuantumNumbers: EdgeQuantumNumbers.c_parity, EdgeQuantumNumbers.g_parity, ] +"""Type hint for quantum numbers of edges""" -# for accessing the keys of the dicts in EdgeSettings EdgeQuantumNumberTypes = Union[ type[EdgeQuantumNumbers.pid], type[EdgeQuantumNumbers.mass], @@ -149,6 +148,7 @@ class EdgeQuantumNumbers: type[EdgeQuantumNumbers.c_parity], type[EdgeQuantumNumbers.g_parity], ] +"""Type-Union for accessing the keys of the dicts in `.EdgeSettings`""" @frozen(init=False) @@ -186,6 +186,7 @@ class NodeQuantumNumbers: type[NodeQuantumNumbers.s_projection], type[NodeQuantumNumbers.parity_prefactor], ] +"""Type-Union for accessing the keys of the dicts in `.NodeSettings`""" def _to_optional_float(optional_float: float | None) -> float | None: diff --git a/src/qrules/solving.py b/src/qrules/solving.py index de0e85f6..6cb4d3a6 100644 --- a/src/qrules/solving.py +++ b/src/qrules/solving.py @@ -101,6 +101,86 @@ def topology(self) -> Topology: return self.initial_facts.topology +def filter_quantum_number_problem_set( + quantum_number_problem_set: QNProblemSet, + edge_rules: set[GraphElementRule], + node_rules: set[Rule], + edge_properties: Iterable[EdgeQuantumNumberTypes], + node_properties: Iterable[NodeQuantumNumberTypes], +) -> QNProblemSet: + """Filter `QNProblemSet` for desired conservation rules, settings and domains. + + Currently it is the responsibility of the user to provide fitting properties + and domains for the correspinding conservation rules. + + Args: + quantum_number_problem_set: `QNProblemSet` as generated by `CSPSolver`. + edge_rules: Conservation rules regarding the edges. + node_rules: Conservation rules regarding the nodes. + edge_properties: Edge settings, properties and domains. + node_properties: Node settings, properties and domains. + """ + old_edge_settings = quantum_number_problem_set.solving_settings.states + old_node_settings = quantum_number_problem_set.solving_settings.interactions + old_edge_properties = quantum_number_problem_set.initial_facts.states + old_node_properties = quantum_number_problem_set.initial_facts.interactions + new_edge_settings = { + edge_id: EdgeSettings( + conservation_rules=edge_rules, + rule_priorities=edge_setting.rule_priorities, + qn_domains=({ + key: val + for key, val in edge_setting.qn_domains.items() + if key in set(edge_properties) + }), + ) + for edge_id, edge_setting in old_edge_settings.items() + } + new_node_settings = { + node_id: NodeSettings( + conservation_rules=node_rules, + rule_priorities=node_setting.rule_priorities, + qn_domains=({ + key: val + for key, val in node_setting.qn_domains.items() + if key in set(node_properties) + }), + ) + for node_id, node_setting in old_node_settings.items() + } + new_combined_settings = MutableTransition( + topology=quantum_number_problem_set.solving_settings.topology, + states=new_edge_settings, + interactions=new_node_settings, + ) + new_edge_properties = { + edge_id: { + edge_quantum_number: scalar + for edge_quantum_number, scalar in graph_edge_property_map.items() + if edge_quantum_number in edge_properties + } + for edge_id, graph_edge_property_map in old_edge_properties.items() + } + new_node_properties = { + node_id: { + node_quantum_number: scalar + for node_quantum_number, scalar in graph_node_property_map.items() + if node_quantum_number in node_properties + } + for node_id, graph_node_property_map in old_node_properties.items() + } + new_combined_properties = MutableTransition( + topology=quantum_number_problem_set.initial_facts.topology, + states=new_edge_properties, + interactions=new_node_properties, + ) + return attrs.evolve( + quantum_number_problem_set, + solving_settings=new_combined_settings, + initial_facts=new_combined_properties, + ) + + QuantumNumberSolution = MutableTransition[GraphEdgePropertyMap, GraphNodePropertyMap] @@ -747,6 +827,13 @@ def __convert_solution_keys( return converted_solutions +def dict_set_intersection( + base_dict: dict[Any, Any], + set_of_keys: set[Any], +) -> dict[Any, Any]: + return {key: value for key, value in base_dict.items() if key in set_of_keys} + + class Scoresheet: def __init__(self) -> None: self.__rule_calls: dict[tuple[int, Rule], int] = {} diff --git a/tests/unit/test_solving.py b/tests/unit/test_solving.py new file mode 100644 index 00000000..56b294eb --- /dev/null +++ b/tests/unit/test_solving.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import qrules.particle +import qrules.quantum_numbers +import qrules.system_control +import qrules.transition +from qrules.conservation_rules import ( + c_parity_conservation, + parity_conservation, + spin_magnitude_conservation, + spin_validity, +) +from qrules.quantum_numbers import EdgeQuantumNumbers, NodeQuantumNumbers +from qrules.solving import CSPSolver, QNProblemSet, filter_quantum_number_problem_set + +if TYPE_CHECKING: + from qrules.argument_handling import GraphEdgePropertyMap + + +def test_solve( + all_particles: qrules.particle.ParticleCollection, + quantum_number_problem_set: QNProblemSet, +) -> None: + solver = CSPSolver(all_particles) + result = solver.find_solutions(quantum_number_problem_set) + assert len(result.solutions) == 19 + + +@pytest.mark.parametrize("with_spin_projection", [True, False]) +def test_solve_with_filtered_quantum_number_problem_set( + all_particles: list[GraphEdgePropertyMap], + quantum_number_problem_set: QNProblemSet, + with_spin_projection: bool, +) -> None: + solver = CSPSolver(all_particles) + parametrized_edge_properties_and_domains = { + EdgeQuantumNumbers.pid, # had to be added for c_parity_conservation to work + EdgeQuantumNumbers.spin_magnitude, + EdgeQuantumNumbers.parity, + EdgeQuantumNumbers.c_parity, + } + if with_spin_projection: + parametrized_edge_properties_and_domains.add(EdgeQuantumNumbers.spin_projection) + + new_quantum_number_problem_set = filter_quantum_number_problem_set( + quantum_number_problem_set, + edge_rules={spin_validity}, + node_rules={ + spin_magnitude_conservation, + parity_conservation, + c_parity_conservation, + }, + edge_properties=parametrized_edge_properties_and_domains, + node_properties=( + NodeQuantumNumbers.l_magnitude, + NodeQuantumNumbers.s_magnitude, + ), + ) + result = solver.find_solutions(new_quantum_number_problem_set) + + if with_spin_projection: + assert len(result.solutions) == 319 + else: + assert len(result.solutions) == 127 + + +@pytest.fixture(scope="session") +def all_particles(): + return [ + qrules.system_control.create_edge_properties(part) + for part in qrules.particle.load_pdg() + ] + + +@pytest.fixture(scope="session") +def quantum_number_problem_set() -> QNProblemSet: + stm = qrules.StateTransitionManager( + initial_state=["psi(2S)"], + final_state=["gamma", "eta", "eta"], + formalism="helicity", + ) + problem_sets = stm.create_problem_sets() + return next( + p.to_qn_problem_set() + for strength in sorted(problem_sets) + for p in problem_sets[strength] + )