From 9e23b7a63b46ee98d73aba9b5fd361d4f01043b1 Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Tue, 22 Feb 2022 00:27:28 +0100 Subject: [PATCH] refactor!: generalize STG to MutableTransition (#156) * chore: rewrite StateTransitionGraph with attrs **WARNING**: this removes the `__post_init__()` check, but this didn't seem to be doing anything anyway. Initially, I simply renamed the check `__attrs_post_init__()`, but this actually does perform the check and then the system crashes. This seems to be an issue that needs to be addressed separately. * chore: simplify (over)defined_assert functions * docs: hide TypeVars from topology API * docs: improve docstrings of topology module * ci: disable pylint line-too-long * ci: do not fast-fail test jobs * ci: ignore logo.svg in sphinx-autobuild * ci: ignore tmp files in sphinx-autobuild * docs: hide `dict.keys()` methods etc. from API * docs: improve API rendering with autodoc_type_aliases * feat: define FrozenTransition class (from StateTransition) * feat: define Transition interface class * feat: implement initial_states etc in FrozenTransition * fix: return NewNodeType in FrozenTransition convert() * refactor: change initial_states etc into mixin methods * refactor: convert StateTransition into type alias * refactor: make MutableTopology public * refactor: make NodeType of StateTransitionGraph generic * refactor: move organize_edge_ids to MutableTopology * refactor: remove GraphSettings, GraphElementProperties, QuantumNumberSolution, InitialFacts * refactor: remove ReactionInfo.from/to_graphs() * refactor: remove get_edge/node_props() * refactor: remove kw_only from MutableTopology * refactor: rename GraphSettings attrs to states/interactions * refactor: rename QuantumNumberSolution attrs to states/interactions * refactor: rename StateTransitionGraph to MutableTransition * refactor: rename edge/node_props to states/interactions * refactor: sort output of create_isobar_topologies --- .flake8 | 5 + .github/workflows/ci-tests.yml | 1 + .pylintrc | 2 + docs/.gitignore | 1 + docs/_extend_docstrings.py | 130 ++++ docs/_relink_references.py | 102 ++- docs/conf.py | 34 +- docs/index.md | 4 +- docs/usage/conservation.ipynb | 12 +- docs/usage/reaction.ipynb | 14 +- docs/usage/visualize.ipynb | 6 +- src/qrules/__init__.py | 25 +- src/qrules/_system_control.py | 117 ++-- src/qrules/argument_handling.py | 12 +- src/qrules/combinatorics.py | 73 +-- src/qrules/io/__init__.py | 26 +- src/qrules/io/_dict.py | 111 ++-- src/qrules/io/_dot.py | 287 ++++----- src/qrules/particle.py | 2 +- src/qrules/quantum_numbers.py | 2 +- src/qrules/solving.py | 182 +++--- src/qrules/topology.py | 643 +++++++++++++------ src/qrules/transition.py | 200 ++---- tests/channels/test_jpsi_to_gamma_pi0_pi0.py | 10 +- tests/unit/io/test_dot.py | 44 +- tests/unit/io/test_io.py | 8 +- tests/unit/test_combinatorics.py | 4 +- tests/unit/test_system_control.py | 88 +-- tests/unit/test_topology.py | 144 +++-- tests/unit/test_transition.py | 120 +--- tox.ini | 3 + 31 files changed, 1302 insertions(+), 1110 deletions(-) create mode 100644 docs/_extend_docstrings.py diff --git a/.flake8 b/.flake8 index b138b130..68479759 100644 --- a/.flake8 +++ b/.flake8 @@ -32,6 +32,9 @@ ignore = W503 extend-select = TI100 +per-file-ignores = + # casts with generics + src/qrules/topology.py:E731 rst-roles = attr cite @@ -44,6 +47,8 @@ rst-roles = mod ref rst-directives = + autolink-preface + automethod deprecated envvar exception diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index d1d61680..b3972c5e 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -46,6 +46,7 @@ jobs: name: Unit tests runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: - macos-11 diff --git a/.pylintrc b/.pylintrc index c32808b8..4d1ca265 100644 --- a/.pylintrc +++ b/.pylintrc @@ -16,11 +16,13 @@ ignore-patterns= disable= duplicate-code, # https://github.com/PyCQA/pylint/issues/214 invalid-unary-operand-type, # conflicts with attrs.field + line-too-long, # automatically fixed with black logging-fstring-interpolation, missing-class-docstring, # pydocstyle missing-function-docstring, # pydocstyle missing-module-docstring, # pydocstyle no-member, # conflicts with attrs.field + no-name-in-module, # already checked by mypy not-an-iterable, # conflicts with attrs.field not-callable, # conflicts with attrs.field redefined-builtin, # flake8-built diff --git a/docs/.gitignore b/docs/.gitignore index 7bfee55a..0cd13342 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -1,6 +1,7 @@ *.doctree *.inv *build/ +_images/* api/ !_static/* diff --git a/docs/_extend_docstrings.py b/docs/_extend_docstrings.py new file mode 100644 index 00000000..bbcc9bb4 --- /dev/null +++ b/docs/_extend_docstrings.py @@ -0,0 +1,130 @@ +# flake8: noqa +# pylint: disable=import-error,import-outside-toplevel,invalid-name,protected-access +# pyright: reportMissingImports=false +"""Extend docstrings of the API. + +This small script is used by ``conf.py`` to dynamically modify docstrings. +""" + +import inspect +import logging +import textwrap +from typing import Callable, Dict, Optional, Type, Union + +import qrules + +logging.getLogger().setLevel(logging.ERROR) + + +def extend_docstrings() -> None: + script_name = __file__.rsplit("/", maxsplit=1)[-1] + script_name = ".".join(script_name.split(".")[:-1]) + definitions = dict(globals()) + for name, definition in definitions.items(): + module = inspect.getmodule(definition) + if module is None: + continue + if module.__name__ not in {"__main__", script_name}: + continue + if not inspect.isfunction(definition): + continue + if not name.startswith("extend_"): + continue + if name == "extend_docstrings": + continue + function_arguments = inspect.signature(definition).parameters + if len(function_arguments): + raise ValueError( + f"Local function {name} should not have a signature" + ) + definition() + + +def extend_create_isobar_topologies() -> None: + from qrules.topology import create_isobar_topologies + + topologies = qrules.topology.create_isobar_topologies(4) + dot_renderings = map( + lambda t: qrules.io.asdot(t, render_resonance_id=True), + topologies, + ) + images = [_graphviz_to_image(dot, indent=6) for dot in dot_renderings] + _append_to_docstring( + create_isobar_topologies, + f""" + + .. panels:: + :body: text-center + {images[0]} + + --- + {images[1]} + """, + ) + + +def extend_create_n_body_topology() -> None: + from qrules.topology import create_n_body_topology + + topology = create_n_body_topology( + number_of_initial_states=2, + number_of_final_states=5, + ) + dot = qrules.io.asdot(topology, render_initial_state_id=True) + _append_to_docstring( + create_n_body_topology, + _graphviz_to_image(dot, indent=4), + ) + + +def extend_Topology() -> None: + from qrules.topology import Topology, create_isobar_topologies + + topologies = create_isobar_topologies(number_of_final_states=3) + dot = qrules.io.asdot( + topologies[0], + render_initial_state_id=True, + render_resonance_id=True, + ) + _append_to_docstring( + Topology, + _graphviz_to_image(dot, indent=4), + ) + + +def _append_to_docstring( + class_type: Union[Callable, Type], appended_text: str +) -> None: + assert class_type.__doc__ is not None + class_type.__doc__ += appended_text + + +_GRAPHVIZ_COUNTER = 0 +_IMAGE_DIR = "_images" + + +def _graphviz_to_image( # pylint: disable=too-many-arguments + dot: str, + options: Optional[Dict[str, str]] = None, + format: str = "svg", + indent: int = 0, + caption: str = "", + label: str = "", +) -> str: + import graphviz # type: ignore[import] + + if options is None: + options = {} + global _GRAPHVIZ_COUNTER # pylint: disable=global-statement + output_file = f"graphviz_{_GRAPHVIZ_COUNTER}" + _GRAPHVIZ_COUNTER += 1 + graphviz.Source(dot).render(f"{_IMAGE_DIR}/{output_file}", format=format) + restructuredtext = "\n" + if label: + restructuredtext += f".. _{label}:\n" + restructuredtext += f".. figure:: /{_IMAGE_DIR}/{output_file}.{format}\n" + for option, value in options.items(): + restructuredtext += f" :{option}: {value}\n" + if caption: + restructuredtext += f"\n {caption}\n" + return textwrap.indent(restructuredtext, indent * " ") diff --git a/docs/_relink_references.py b/docs/_relink_references.py index 65cb1ac0..4c693c68 100644 --- a/docs/_relink_references.py +++ b/docs/_relink_references.py @@ -8,21 +8,51 @@ See also https://github.com/sphinx-doc/sphinx/issues/5868. """ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List if TYPE_CHECKING: + from docutils import nodes from sphinx.addnodes import pending_xref from sphinx.environment import BuildEnvironment - __TARGET_SUBSTITUTIONS = { "a set-like object providing a view on D's items": "typing.ItemsView", "a set-like object providing a view on D's keys": "typing.KeysView", "an object providing a view on D's values": "typing.ValuesView", "typing_extensions.Protocol": "typing.Protocol", + "typing_extensions.TypeAlias": "typing.TypeAlias", } __REF_TYPE_SUBSTITUTIONS = { "None": "obj", + "qrules.combinatorics.InitialFacts": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.baryon_number": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.bottomness": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.c_parity": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.charge": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.charmness": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.electron_lepton_number": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.g_parity": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.isospin_magnitude": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.isospin_projection": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.mass": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.muon_lepton_number": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.parity": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.pid": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.spin_magnitude": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.spin_projection": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.strangeness": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.tau_lepton_number": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.topness": "obj", + "qrules.quantum_numbers.EdgeQuantumNumbers.width": "obj", + "qrules.quantum_numbers.NodeQuantumNumbers.l_magnitude": "obj", + "qrules.quantum_numbers.NodeQuantumNumbers.l_projection": "obj", + "qrules.quantum_numbers.NodeQuantumNumbers.parity_prefactor": "obj", + "qrules.quantum_numbers.NodeQuantumNumbers.s_magnitude": "obj", + "qrules.quantum_numbers.NodeQuantumNumbers.s_projection": "obj", + "qrules.solving.GraphElementProperties": "obj", + "qrules.solving.GraphSettings": "obj", + "qrules.transition.StateTransition": "obj", + "typing.TypeAlias": "obj", } @@ -31,34 +61,68 @@ def _new_type_to_xref( env: "BuildEnvironment" = None, suppress_prefix: bool = False, ) -> "pending_xref": - if env: - kwargs = { - "py:module": env.ref_context.get("py:module"), - "py:class": env.ref_context.get("py:class"), - } - else: - kwargs = {} + import sphinx + from sphinx.addnodes import pending_xref - target = __TARGET_SUBSTITUTIONS.get(target, target) - reftype = __REF_TYPE_SUBSTITUTIONS.get(target, "class") - if suppress_prefix: - short_text = target.split(".")[-1] - else: - short_text = target + if sphinx.version_info >= (4, 4): + # https://github.com/sphinx-doc/sphinx/blob/v4.4.0/sphinx/domains/python.py#L110-L133 + from sphinx.domains.python import ( # type: ignore[attr-defined] + parse_reftarget, + ) - from docutils.nodes import Text - from sphinx.addnodes import pending_xref + reftype, target, title, refspecific = parse_reftarget( + target, suppress_prefix + ) + target = __TARGET_SUBSTITUTIONS.get(target, target) + reftype = __REF_TYPE_SUBSTITUTIONS.get(target, reftype) + assert env is not None + return pending_xref( + "", + *__create_nodes(env, title), + refdomain="py", + reftype=reftype, + reftarget=target, + refspecific=refspecific, + **__get_env_kwargs(env), + ) + # Sphinx <4.4.0 + # https://github.com/sphinx-doc/sphinx/blob/v4.3.2/sphinx/domains/python.py#L83-L107 + target = __TARGET_SUBSTITUTIONS.get(target, target) + reftype = __REF_TYPE_SUBSTITUTIONS.get(target, "class") + assert env is not None return pending_xref( "", - Text(short_text), + *__create_nodes(env, target), refdomain="py", reftype=reftype, reftarget=target, - **kwargs, + **__get_env_kwargs(env), ) +def __get_env_kwargs(env: "BuildEnvironment") -> dict: + if env: + return { + "py:module": env.ref_context.get("py:module"), + "py:class": env.ref_context.get("py:class"), + } + return {} + + +def __create_nodes(env: "BuildEnvironment", title: str) -> "List[nodes.Node]": + from docutils import nodes + from sphinx.addnodes import pending_xref_condition + + short_name = title.split(".")[-1] + if env.config.python_use_unqualified_type_names: + return [ + pending_xref_condition("", short_name, condition="resolved"), + pending_xref_condition("", title, condition="*"), + ] + return [nodes.Text(short_name)] + + def relink_references() -> None: import sphinx.domains.python diff --git a/docs/conf.py b/docs/conf.py index 59faa962..39b3a95f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -73,9 +73,12 @@ def fetch_logo(url: str, output_path: str) -> None: # -- Generate API ------------------------------------------------------------ sys.path.insert(0, os.path.abspath(".")) +from _extend_docstrings import extend_docstrings # noqa: E402 from _relink_references import relink_references # noqa: E402 +extend_docstrings() relink_references() + shutil.rmtree("api", ignore_errors=True) subprocess.call( " ".join( @@ -164,17 +167,29 @@ def fetch_logo(url: str, output_path: str) -> None: # General sphinx settings add_module_names = False autodoc_default_options = { + "exclude-members": ", ".join( + [ + "items", + "keys", + "values", + ] + ), "members": True, "undoc-members": True, "show-inheritance": True, "special-members": ", ".join( [ "__call__", - "__getitem__", ] ), } autodoc_member_order = "bysource" +autodoc_type_aliases = { + "GraphElementProperties": "qrules.solving.GraphElementProperties", + "GraphSettings": "qrules.solving.GraphSettings", + "InitialFacts": "qrules.combinatorics.InitialFacts", + "StateTransition": "qrules.transition.StateTransition", +} autodoc_typehints_format = "short" codeautolink_concat_default = True AUTODOC_INSERT_SIGNATURE_LINEBREAKS = True @@ -220,15 +235,14 @@ def fetch_logo(url: str, output_path: str) -> None: default_role = "py:obj" primary_domain = "py" nitpicky = True # warn if cross-references are missing -nitpick_ignore = [ - ("py:class", "EdgeType"), - ("py:class", "NoneType"), - ("py:class", "StateTransitionGraph"), - ("py:class", "ValueType"), - ("py:class", "json.encoder.JSONEncoder"), - ("py:class", "typing_extensions.Protocol"), - ("py:obj", "qrules.topology._K"), - ("py:obj", "qrules.topology._V"), +nitpick_ignore_regex = [ + (r"py:(class|obj)", "json.encoder.JSONEncoder"), + (r"py:(class|obj)", r"(qrules\.topology\.)?EdgeType"), + (r"py:(class|obj)", r"(qrules\.topology\.)?KT"), + (r"py:(class|obj)", r"(qrules\.topology\.)?NewEdgeType"), + (r"py:(class|obj)", r"(qrules\.topology\.)?NewNodeType"), + (r"py:(class|obj)", r"(qrules\.topology\.)?NodeType"), + (r"py:(class|obj)", r"(qrules\.topology\.)?VT"), ] diff --git a/docs/index.md b/docs/index.md index 49bee108..e50e88b1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -47,7 +47,7 @@ QRules consists of three major components: 1. **State transition graphs** - A {class}`.StateTransitionGraph` is a + A {class}`.MutableTransition` is a [directed graph](https://en.wikipedia.org/wiki/Directed_graph) that consists of **nodes** and **edges**. In a directed graph, each edge must be connected to at least one node (in correspondence to @@ -93,7 +93,7 @@ The main solver used by {mod}`qrules` is the 1. **Preparation** 1.1. Build all possible topologies. A **topology** is represented by a - {class}`.StateTransitionGraph`, in which the edges and nodes are empty (no + {class}`.MutableTransition`, in which the edges and nodes are empty (no particle information). 1.2. Fill the topology graphs with the user provided information. Typically diff --git a/docs/usage/conservation.ipynb b/docs/usage/conservation.ipynb index 6c6b64f7..d30d9898 100644 --- a/docs/usage/conservation.ipynb +++ b/docs/usage/conservation.ipynb @@ -97,7 +97,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "QRules generates {class}`.StateTransitionGraph`s, populates them with quantum numbers (edge properties representing states and nodes properties representing interactions), then checks whether the generated {class}`.StateTransitionGraph`s comply with the rules formulated in the {mod}`.conservation_rules` module.\n", + "QRules generates {class}`.MutableTransition`s, populates them with quantum numbers (edge properties representing states and nodes properties representing interactions), then checks whether the generated {class}`.MutableTransition`s comply with the rules formulated in the {mod}`.conservation_rules` module.\n", "\n", "The {mod}`.conservation_rules` module can also be used separately. In this notebook, we will illustrate this by checking spin and parity conservation." ] @@ -215,14 +215,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Example with a {class}`.StateTransition`" + "## Example with a {obj}`.StateTransition`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "First, generate some {class}`.StateTransition`s with {func}`.generate_transitions`, then select one of them:" + "First, generate some {obj}`.StateTransition`s with {func}`.generate_transitions`, then select one of them:" ] }, { @@ -361,14 +361,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Modifying {class}`.StateTransition`s" + "### Modifying {obj}`.StateTransition`s" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "When checking conservation rules, you may want to modify the properties on the {class}`.StateTransition`s. However, a {class}`.StateTransition` is frozen, so it is not possible to modify its {attr}`~.StateTransition.interactions` and {attr}`~.StateTransition.states`. The only way around this is to create a new instance with {func}`attrs.evolve`.\n", + "When checking conservation rules, you may want to modify the properties on the {obj}`.StateTransition`s. However, a {obj}`.StateTransition` is a {class}`.FrozenTransition`, so it is not possible to modify its {attr}`~.FrozenTransition.interactions` and {attr}`~.FrozenTransition.states`. The only way around this is to create a new instance with {func}`attrs.evolve`.\n", "\n", "First, we get the instance (in this case one of the {class}`.InteractionProperties`) and substitute its {attr}`.InteractionProperties.l_magnitude`:" ] @@ -387,7 +387,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We then again use {func}`attrs.evolve` to substitute the {attr}`.StateTransition.interactions` of the original {class}`.StateTransition`:" + "We then again use {func}`attrs.evolve` to substitute the {attr}`.Transition.interactions` of the original {obj}`.StateTransition`:" ] }, { diff --git a/docs/usage/reaction.ipynb b/docs/usage/reaction.ipynb index 41830dd5..7bf0d4fc 100644 --- a/docs/usage/reaction.ipynb +++ b/docs/usage/reaction.ipynb @@ -179,7 +179,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Create all {class}`.ProblemSet`'s using the boundary conditions of the {class}`.StateTransitionManager` instance. By default it uses the **isobar model** (tree of two-body decays) to build {class}`.Topology`'s. Various {class}`.InitialFacts` are created for each topology based on the initial and final state. Lastly some reasonable default settings for the solving process are chosen. Remember that each interaction node defines its own set of conservation laws." + "Create all {class}`.ProblemSet`'s using the boundary conditions of the {class}`.StateTransitionManager` instance. By default it uses the **isobar model** (tree of two-body decays) to build {class}`.Topology`'s. Various {obj}`.InitialFacts` are created for each topology based on the initial and final state. Lastly some reasonable default settings for the solving process are chosen. Remember that each interaction node defines its own set of conservation laws." ] }, { @@ -247,12 +247,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Each {class}`.ProblemSet` provides a mapping of {attr}`~.ProblemSet.initial_facts` that represent the initial and final states with spin projections. The nodes and edges in between these {attr}`~.ProblemSet.initial_facts` are still to be generated. This will be done from the provided {attr}`~.ProblemSet.solving_settings` ({class}`~.GraphSettings`). There are two mechanisms there:\n", + "Each {class}`.ProblemSet` provides a mapping of {attr}`~.ProblemSet.initial_facts` that represent the initial and final states with spin projections. The nodes and edges in between these {attr}`~.ProblemSet.initial_facts` are still to be generated. This will be done from the provided {attr}`~.ProblemSet.solving_settings`. There are two mechanisms there:\n", "\n", - "1. One the one hand, the {attr}`.EdgeSettings.qn_domains` and {attr}`.NodeSettings.qn_domains` contained in the {class}`~.GraphSettings` define the **domain** over which quantum number sets can be generated.\n", - "2. On the other, the {attr}`.EdgeSettings.rule_priorities` and {attr}`.NodeSettings.rule_priorities` in {class}`~.GraphSettings` define which **{mod}`.conservation_rules`** are used to determine which of the sets of generated quantum numbers are valid.\n", + "1. One the one hand, the {attr}`.EdgeSettings.qn_domains` and {attr}`.NodeSettings.qn_domains` contained in the {attr}`~.ProblemSet.solving_settings` define the **domain** over which quantum number sets can be generated.\n", + "2. On the other, the {attr}`.EdgeSettings.rule_priorities` and {attr}`.NodeSettings.rule_priorities` in {attr}`~.ProblemSet.solving_settings` define which **{mod}`.conservation_rules`** are used to determine which of the sets of generated quantum numbers are valid.\n", "\n", - "Together, these two constraints allow the {class}`.StateTransitionManager` to generate a number of {class}`.StateTransitionGraph`s that comply with the selected {mod}`.conservation_rules`." + "Together, these two constraints allow the {class}`.StateTransitionManager` to generate a number of {class}`.MutableTransition`s that comply with the selected {mod}`.conservation_rules`." ] }, { @@ -313,7 +313,7 @@ "class: dropdown\n", "----\n", "\n", - "The \"number of {attr}`~.ReactionInfo.transitions`\" is the total number of allowed {obj}`.StateTransitionGraph` instances that the {class}`.StateTransitionManager` has found. This also includes all allowed **spin projection combinations**. In this channel, we for example consider a $J/\\psi$ with spin projection $\\pm1$ that decays into a $\\gamma$ with spin projection $\\pm1$, which already gives us four possibilities.\n", + "The \"number of {attr}`~.ReactionInfo.transitions`\" is the total number of allowed {obj}`.MutableTransition` instances that the {class}`.StateTransitionManager` has found. This also includes all allowed **spin projection combinations**. In this channel, we for example consider a $J/\\psi$ with spin projection $\\pm1$ that decays into a $\\gamma$ with spin projection $\\pm1$, which already gives us four possibilities.\n", "\n", "On the other hand, the intermediate state names that was extracted with {meth}`.ReactionInfo.get_intermediate_particles`, is just a {obj}`set` of the state names on the intermediate edges of the list of {attr}`~.ReactionInfo.transitions`, regardless of spin projection.\n", "````" @@ -457,7 +457,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The {class}`.ReactionInfo`, {class}`.StateTransitionGraph`, and {class}`.Topology` can be serialized to and from a {obj}`dict` with {func}`.io.asdict` and {func}`.io.fromdict`:" + "The {class}`.ReactionInfo`, {class}`.MutableTransition`, and {class}`.Topology` can be serialized to and from a {obj}`dict` with {func}`.io.asdict` and {func}`.io.fromdict`:" ] }, { diff --git a/docs/usage/visualize.ipynb b/docs/usage/visualize.ipynb index eb663fb3..0dc90243 100644 --- a/docs/usage/visualize.ipynb +++ b/docs/usage/visualize.ipynb @@ -81,7 +81,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The {mod}`~qrules.io` module allows you to convert {class}`.StateTransitionGraph`, {class}`.Topology` instances, and {class}`.ProblemSet`s to [DOT language](https://graphviz.org/doc/info/lang.html) with {func}`.asdot`. You can visualize its output with third-party libraries, such as [Graphviz](https://graphviz.org). This is particularly useful after running {meth}`~.StateTransitionManager.find_solutions`, which produces a {class}`.ReactionInfo` object with a {class}`.list` of {class}`.StateTransitionGraph` instances (see {doc}`/usage/reaction`)." + "The {mod}`~qrules.io` module allows you to convert {class}`.MutableTransition`, {class}`.Topology` instances, and {class}`.ProblemSet`s to [DOT language](https://graphviz.org/doc/info/lang.html) with {func}`.asdot`. You can visualize its output with third-party libraries, such as [Graphviz](https://graphviz.org). This is particularly useful after running {meth}`~.StateTransitionManager.find_solutions`, which produces a {class}`.ReactionInfo` object with a {class}`.list` of {class}`.MutableTransition` instances (see {doc}`/usage/reaction`)." ] }, { @@ -255,7 +255,7 @@ "tags": [] }, "source": [ - "## {class}`.StateTransition`s" + "## {obj}`.StateTransition`s" ] }, { @@ -300,7 +300,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This {class}`str` of [DOT language](https://graphviz.org/doc/info/lang.html) for the list of {class}`.StateTransitionGraph` instances can then be visualized with a third-party library, for instance, with {class}`graphviz.Source`:" + "This {class}`str` of [DOT language](https://graphviz.org/doc/info/lang.html) for the list of {class}`.MutableTransition` instances can then be visualized with a third-party library, for instance, with {class}`graphviz.Source`:" ] }, { diff --git a/src/qrules/__init__.py b/src/qrules/__init__.py index 71c5c990..22fd7dc0 100644 --- a/src/qrules/__init__.py +++ b/src/qrules/__init__.py @@ -7,9 +7,9 @@ and allowed interactions. The core of `qrules` computes which transitions (represented by a -`.StateTransitionGraph`) are allowed between a certain initial and final state. +`.MutableTransition`) are allowed between a certain initial and final state. Internally, the system propagates the quantum numbers defined by the -`particle` module through the `.StateTransitionGraph`, while +`particle` module through the `.MutableTransition`, while satisfying the rules define by the :mod:`.conservation_rules` module. See :doc:`/usage/reaction` and :doc:`/usage/particle`. @@ -62,14 +62,8 @@ _halves_domain, _int_domain, ) -from .solving import ( - GraphSettings, - NodeSettings, - QNResult, - Rule, - validate_full_solution, -) -from .topology import create_n_body_topology +from .solving import NodeSettings, QNResult, Rule, validate_full_solution +from .topology import MutableTransition, create_n_body_topology from .transition import ( EdgeSettings, ProblemSet, @@ -132,19 +126,20 @@ def check_reaction_violations( # pylint: disable=too-many-arguments particle_db = load_pdg() def _check_violations( - facts: InitialFacts, + facts: "InitialFacts", node_rules: Dict[int, Set[Rule]], edge_rules: Dict[int, Set[GraphElementRule]], ) -> QNResult: problem_set = ProblemSet( topology=topology, initial_facts=facts, - solving_settings=GraphSettings( - node_settings={ + solving_settings=MutableTransition( + facts.topology, + interactions={ i: NodeSettings(conservation_rules=rules) for i, rules in node_rules.items() }, - edge_settings={ + states={ i: EdgeSettings(conservation_rules=rules) for i, rules in edge_rules.items() }, @@ -239,7 +234,7 @@ def check_edge_qn_conservation() -> Set[FrozenSet[str]]: for facts_combination in initial_facts: new_facts = attrs.evolve( facts_combination, - node_props={node_id: ls_combi}, + interactions={node_id: ls_combi}, ) initial_facts_list.append(new_facts) diff --git a/src/qrules/_system_control.py b/src/qrules/_system_control.py index 3b9a2c3e..470c40b1 100644 --- a/src/qrules/_system_control.py +++ b/src/qrules/_system_control.py @@ -17,12 +17,12 @@ ) from .settings import InteractionType from .solving import GraphEdgePropertyMap, GraphNodePropertyMap, GraphSettings -from .topology import StateTransitionGraph +from .topology import MutableTransition Strength = float GraphSettingsGroups = Dict[ - Strength, List[Tuple[StateTransitionGraph, GraphSettings]] + Strength, List[Tuple[MutableTransition, GraphSettings]] ] @@ -59,7 +59,7 @@ def create_edge_properties( def create_node_properties( - node_props: InteractionProperties, + interactions: InteractionProperties, ) -> GraphNodePropertyMap: node_qn_mapping: Dict[str, Type[NodeQuantumNumber]] = { qn_name: qn_type @@ -67,7 +67,7 @@ def create_node_properties( if not qn_name.startswith("__") } # Note using attrs.fields does not work here because init=False property_map: GraphNodePropertyMap = {} - for qn_name, value in attrs.asdict(node_props).items(): + for qn_name, value in attrs.asdict(interactions).items(): if value is None: continue if qn_name in node_qn_mapping: @@ -82,7 +82,7 @@ def create_node_properties( def create_particle( - edge_props: GraphEdgePropertyMap, particle_db: ParticleCollection + states: GraphEdgePropertyMap, particle_db: ParticleCollection ) -> ParticleWithSpin: """Create a Particle with spin projection from a qn dictionary. @@ -90,7 +90,7 @@ def create_particle( particle inside the `.ParticleCollection`. Args: - edge_props: The quantum number dictionary. + states: The quantum number dictionary. particle_db: A `.ParticleCollection` which is used to retrieve a reference `.particle` to lower the memory footprint. @@ -100,13 +100,13 @@ def create_particle( ValueError: If the edge properties do not contain spin projection info. """ - particle = particle_db.find(int(edge_props[EdgeQuantumNumbers.pid])) - if EdgeQuantumNumbers.spin_projection not in edge_props: + particle = particle_db.find(int(states[EdgeQuantumNumbers.pid])) + if EdgeQuantumNumbers.spin_projection not in states: raise ValueError( f"{GraphEdgePropertyMap.__name__} does not contain a spin" " projection" ) - spin_projection = edge_props[EdgeQuantumNumbers.spin_projection] + spin_projection = states[EdgeQuantumNumbers.spin_projection] return (particle, spin_projection) @@ -150,9 +150,9 @@ class InteractionDeterminator(ABC): @abstractmethod def check( self, - in_edge_props: List[ParticleWithSpin], - out_edge_props: List[ParticleWithSpin], - node_props: InteractionProperties, + in_states: List[ParticleWithSpin], + out_states: List[ParticleWithSpin], + interactions: InteractionProperties, ) -> List[InteractionType]: pass @@ -162,12 +162,12 @@ class GammaCheck(InteractionDeterminator): def check( self, - in_edge_props: List[ParticleWithSpin], - out_edge_props: List[ParticleWithSpin], - node_props: InteractionProperties, + in_states: List[ParticleWithSpin], + out_states: List[ParticleWithSpin], + interactions: InteractionProperties, ) -> List[InteractionType]: int_types = list(InteractionType) - for particle, _ in in_edge_props + out_edge_props: + for particle, _ in in_states + out_states: if "gamma" in particle.name: int_types = [InteractionType.EM] break @@ -179,12 +179,12 @@ class LeptonCheck(InteractionDeterminator): def check( self, - in_edge_props: List[ParticleWithSpin], - out_edge_props: List[ParticleWithSpin], - node_props: InteractionProperties, + in_states: List[ParticleWithSpin], + out_states: List[ParticleWithSpin], + interactions: InteractionProperties, ) -> List[InteractionType]: node_interaction_types = list(InteractionType) - for particle, _ in in_edge_props + out_edge_props: + for particle, _ in in_states + out_states: if particle.is_lepton(): if particle.name.startswith("nu("): node_interaction_types = [InteractionType.WEAK] @@ -197,10 +197,12 @@ def check( def remove_duplicate_solutions( - solutions: List[StateTransitionGraph[ParticleWithSpin]], + solutions: List[ + "MutableTransition[ParticleWithSpin, InteractionProperties]" + ], remove_qns_list: Optional[Set[Type[NodeQuantumNumber]]] = None, ignore_qns_list: Optional[Set[Type[NodeQuantumNumber]]] = None, -) -> List[StateTransitionGraph[ParticleWithSpin]]: +) -> "List[MutableTransition[ParticleWithSpin, InteractionProperties]]": if remove_qns_list is None: remove_qns_list = set() if ignore_qns_list is None: @@ -209,7 +211,9 @@ def remove_duplicate_solutions( logging.info(f"removing these qns from graphs: {remove_qns_list}") logging.info(f"ignoring qns in graph comparison: {ignore_qns_list}") - filtered_solutions: List[StateTransitionGraph[ParticleWithSpin]] = [] + filtered_solutions: List[ + MutableTransition[ParticleWithSpin, InteractionProperties] + ] = [] remove_counter = 0 for sol_graph in solutions: sol_graph = _remove_qns_from_graph(sol_graph, remove_qns_list) @@ -228,37 +232,35 @@ def remove_duplicate_solutions( def _remove_qns_from_graph( # pylint: disable=too-many-branches - graph: StateTransitionGraph[ParticleWithSpin], + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", qn_list: Set[Type[NodeQuantumNumber]], -) -> StateTransitionGraph[ParticleWithSpin]: - new_node_props = {} +) -> "MutableTransition[ParticleWithSpin, InteractionProperties]": + new_interactions = {} for node_id in graph.topology.nodes: - node_props = graph.get_node_props(node_id) - new_node_props[node_id] = attrs.evolve( - node_props, **{x.__name__: None for x in qn_list} + interactions = graph.interactions[node_id] + new_interactions[node_id] = attrs.evolve( + interactions, **{x.__name__: None for x in qn_list} ) - return graph.evolve(node_props=new_node_props) + return attrs.evolve(graph, interactions=new_interactions) def _check_equal_ignoring_qns( - ref_graph: StateTransitionGraph, - solutions: List[StateTransitionGraph], + ref_graph: MutableTransition, + solutions: List[MutableTransition], ignored_qn_list: Set[Type[NodeQuantumNumber]], -) -> Optional[StateTransitionGraph]: +) -> Optional[MutableTransition]: """Define equal operator for graphs, ignoring certain quantum numbers.""" - if not isinstance(ref_graph, StateTransitionGraph): - raise TypeError( - "Reference graph has to be of type StateTransitionGraph" - ) + if not isinstance(ref_graph, MutableTransition): + raise TypeError("Reference graph has to be of type MutableTransition") found_graph = None - node_comparator = NodePropertyComparator(ignored_qn_list) + interaction_comparator = NodePropertyComparator(ignored_qn_list) for graph in solutions: - if isinstance(graph, StateTransitionGraph): + if isinstance(graph, MutableTransition): if graph.compare( ref_graph, - edge_comparator=lambda e1, e2: e1 == e2, - node_comparator=node_comparator, + state_comparator=lambda e1, e2: e1 == e2, + interaction_comparator=interaction_comparator, ): found_graph = graph break @@ -276,26 +278,26 @@ def __init__( def __call__( self, - node_props1: InteractionProperties, - node_props2: InteractionProperties, + interactions1: InteractionProperties, + interactions2: InteractionProperties, ) -> bool: return attrs.evolve( - node_props1, + interactions1, **{x.__name__: None for x in self.__ignored_qn_list}, ) == attrs.evolve( - node_props2, + interactions2, **{x.__name__: None for x in self.__ignored_qn_list}, ) def filter_graphs( - graphs: List[StateTransitionGraph], - filters: Iterable[Callable[[StateTransitionGraph], bool]], -) -> List[StateTransitionGraph]: - r"""Implement filtering of a list of `.StateTransitionGraph` 's. + graphs: List[MutableTransition], + filters: Iterable[Callable[[MutableTransition], bool]], +) -> List[MutableTransition]: + r"""Implement filtering of a list of `.MutableTransition` 's. This function can be used to select a subset of - `.StateTransitionGraph` 's from a list. Only the graphs passing + `.MutableTransition` 's from a list. Only the graphs passing all supplied filters will be returned. Note: @@ -326,7 +328,7 @@ def require_interaction_property( ingoing_particle_name: str, interaction_qn: Type[NodeQuantumNumber], allowed_values: List, -) -> Callable[[StateTransitionGraph[ParticleWithSpin]], bool]: +) -> "Callable[[MutableTransition[ParticleWithSpin, InteractionProperties]], bool]": """Filter function. Closure, which can be used as a filter function in :func:`.filter_graphs`. @@ -351,7 +353,9 @@ def require_interaction_property( - *False* otherwise """ - def check(graph: StateTransitionGraph[ParticleWithSpin]) -> bool: + def check( + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", + ) -> bool: node_ids = _find_node_ids_with_ingoing_particle_name( graph, ingoing_particle_name ) @@ -359,7 +363,7 @@ def check(graph: StateTransitionGraph[ParticleWithSpin]) -> bool: return False for i in node_ids: if ( - getattr(graph.get_node_props(i), interaction_qn.__name__) + getattr(graph.interactions[i], interaction_qn.__name__) not in allowed_values ): return False @@ -369,14 +373,15 @@ def check(graph: StateTransitionGraph[ParticleWithSpin]) -> bool: def _find_node_ids_with_ingoing_particle_name( - graph: StateTransitionGraph[ParticleWithSpin], ingoing_particle_name: str + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", + ingoing_particle_name: str, ) -> List[int]: topology = graph.topology found_node_ids = [] for node_id in topology.nodes: for edge_id in topology.get_edge_ids_ingoing_to_node(node_id): - edge_props = graph.get_edge_props(edge_id) - edge_particle_name = edge_props[0].name + states = graph.states[edge_id] + edge_particle_name = states[0].name if str(ingoing_particle_name) in str(edge_particle_name): found_node_ids.append(node_id) break diff --git a/src/qrules/argument_handling.py b/src/qrules/argument_handling.py index 8c200f19..8b1a25f9 100644 --- a/src/qrules/argument_handling.py +++ b/src/qrules/argument_handling.py @@ -93,11 +93,11 @@ def wrapper(props: GraphElementPropertyMap) -> bool: def _sequence_input_check(func: Callable) -> Callable[[Sequence], bool]: - def wrapper(edge_props_list: Sequence[Any]) -> bool: - if not isinstance(edge_props_list, (list, tuple)): + def wrapper(states_list: Sequence[Any]) -> bool: + if not isinstance(states_list, (list, tuple)): raise TypeError("Rule evaluated with invalid argument type...") - return all(func(x) for x in edge_props_list) + return all(func(x) for x in states_list) return wrapper @@ -170,11 +170,11 @@ def __call__( def _sequence_arg_builder(func: Callable) -> Callable[[Sequence], List[Any]]: - def wrapper(edge_props_list: Sequence[Any]) -> List[Any]: - if not isinstance(edge_props_list, (list, tuple)): + def wrapper(states_list: Sequence[Any]) -> List[Any]: + if not isinstance(states_list, (list, tuple)): raise TypeError("Rule evaluated with invalid argument type...") - return [func(x) for x in edge_props_list if x] + return [func(x) for x in states_list if x] return wrapper diff --git a/src/qrules/combinatorics.py b/src/qrules/combinatorics.py index 372ca900..36b298a3 100644 --- a/src/qrules/combinatorics.py +++ b/src/qrules/combinatorics.py @@ -1,10 +1,11 @@ -"""Perform permutations on the edges of a `.StateTransitionGraph`. +"""Perform permutations on the edges of a `.MutableTransition`. -In a `.StateTransitionGraph`, the edges represent quantum states, while the -nodes represent interactions. This module provides tools to permutate, modify -or extract these edge and node properties. +In a `.MutableTransition`, the edges represent quantum states, while the nodes +represent interactions. This module provides tools to permutate, modify or +extract these edge and node properties. """ +import sys from collections import OrderedDict from copy import deepcopy from itertools import permutations @@ -24,24 +25,23 @@ Union, ) -from attrs import field, frozen - -from qrules._implementers import implement_pretty_repr from qrules.particle import Particle, ParticleCollection from .particle import ParticleWithSpin from .quantum_numbers import InteractionProperties, arange -from .topology import StateTransitionGraph, Topology, get_originating_node_list +from .topology import MutableTransition, Topology, get_originating_node_list + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias StateWithSpins = Tuple[str, Sequence[float]] StateDefinition = Union[str, StateWithSpins] - - -@implement_pretty_repr -@frozen -class InitialFacts: - edge_props: Dict[int, ParticleWithSpin] = field(factory=dict) - node_props: Dict[int, InteractionProperties] = field(factory=dict) +InitialFacts: TypeAlias = ( + "MutableTransition[ParticleWithSpin, InteractionProperties]" +) +"""A `.Transition` with only initial and final state information.""" class _KinematicRepresentation: @@ -168,7 +168,7 @@ def _get_kinematic_representation( r"""Group final or initial states by node, sorted by length of the group. The resulting sorted groups can be used to check whether two - `.StateTransitionGraph` instances are kinematically identical. For + `.MutableTransition` instances are kinematically identical. For instance, the following two graphs: .. code-block:: @@ -262,13 +262,13 @@ def embed_in_list(some_list: List[Any]) -> List[List[Any]]: final_state=final_state, allowed_kinematic_groupings=allowed_kinematic_groupings, ) - edge_initial_facts = [] + edge_initial_facts: List[InitialFacts] = [] for kinematic_permutation in kinematic_permutation_graphs: spin_permutations = _generate_spin_permutations( kinematic_permutation, particle_db ) edge_initial_facts.extend( - [InitialFacts(edge_props=x) for x in spin_permutations] + [MutableTransition(topology, states=x) for x in spin_permutations] ) return edge_initial_facts @@ -406,19 +406,19 @@ def populate_edge_with_spin_projections( def __get_initial_state_edge_ids( - graph: StateTransitionGraph[ParticleWithSpin], + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", ) -> Iterable[int]: return graph.topology.incoming_edge_ids def __get_final_state_edge_ids( - graph: StateTransitionGraph[ParticleWithSpin], + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", ) -> Iterable[int]: return graph.topology.outgoing_edge_ids def match_external_edges( - graphs: List[StateTransitionGraph[ParticleWithSpin]], + graphs: "List[MutableTransition[ParticleWithSpin, InteractionProperties]]", ) -> None: if not isinstance(graphs, list): raise TypeError("graphs argument is not of type list") @@ -432,11 +432,9 @@ def match_external_edges( def _match_external_edge_ids( # pylint: disable=too-many-locals - graphs: List[StateTransitionGraph[ParticleWithSpin]], + graphs: "List[MutableTransition[ParticleWithSpin, InteractionProperties]]", ref_graph_id: int, - external_edge_getter_function: Callable[ - [StateTransitionGraph], Iterable[int] - ], + external_edge_getter_function: "Callable[[MutableTransition], Iterable[int]]", ) -> None: ref_graph = graphs[ref_graph_id] # create external edge to particle mapping @@ -471,17 +469,17 @@ def _match_external_edge_ids( # pylint: disable=too-many-locals def perform_external_edge_identical_particle_combinatorics( - graph: StateTransitionGraph, -) -> List[StateTransitionGraph]: - """Create combinatorics clones of the `.StateTransitionGraph`. + graph: MutableTransition, +) -> List[MutableTransition]: + """Create combinatorics clones of the `.MutableTransition`. In case of identical particles in the initial or final state. Only identical particles, which do not enter or exit the same node allow for combinatorics! """ - if not isinstance(graph, StateTransitionGraph): + if not isinstance(graph, MutableTransition): raise TypeError( - f"graph argument is not of type {StateTransitionGraph.__class__}" + f"graph argument is not of type {MutableTransition.__class__}" ) temp_new_graphs = _external_edge_identical_particle_combinatorics( graph, __get_final_state_edge_ids @@ -497,11 +495,11 @@ def perform_external_edge_identical_particle_combinatorics( def _external_edge_identical_particle_combinatorics( - graph: StateTransitionGraph[ParticleWithSpin], + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", external_edge_getter_function: Callable[ - [StateTransitionGraph], Iterable[int] + [MutableTransition], Iterable[int] ], -) -> List[StateTransitionGraph]: +) -> List[MutableTransition]: # pylint: disable=too-many-locals new_graphs = [graph] edge_particle_mapping = _create_edge_id_particle_mapping( @@ -557,10 +555,7 @@ def _calculate_swappings(id_mapping: Dict[int, int]) -> OrderedDict: def _create_edge_id_particle_mapping( - graph: StateTransitionGraph[ParticleWithSpin], edge_ids: Iterable[int] + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", + edge_ids: Iterable[int], ) -> Dict[int, str]: - return { - i: graph.get_edge_props(i)[0].name - for i in edge_ids - if graph.get_edge_props(i) - } + return {i: graph.states[i][0].name for i in edge_ids} diff --git a/src/qrules/io/__init__.py b/src/qrules/io/__init__.py index 60b672f0..aa45d5a0 100644 --- a/src/qrules/io/__init__.py +++ b/src/qrules/io/__init__.py @@ -15,8 +15,8 @@ import yaml from qrules.particle import Particle, ParticleCollection -from qrules.topology import StateTransitionGraph, Topology -from qrules.transition import ProblemSet, ReactionInfo, State, StateTransition +from qrules.topology import Topology, Transition +from qrules.transition import ProblemSet, ReactionInfo, State from . import _dict, _dot @@ -27,15 +27,15 @@ def asdict(instance: object) -> dict: return _dict.from_particle(instance) if isinstance(instance, ParticleCollection): return _dict.from_particle_collection(instance) - if isinstance(instance, (ReactionInfo, State, StateTransition)): + if isinstance(instance, (ReactionInfo, State)): return attrs.asdict( instance, recurse=True, filter=lambda a, _: a.init, value_serializer=_dict._value_serializer, ) - if isinstance(instance, StateTransitionGraph): - return _dict.from_stg(instance) + if isinstance(instance, Transition): + return _dict.from_transition(instance) if isinstance(instance, Topology): return _dict.from_topology(instance) raise NotImplementedError( @@ -53,9 +53,7 @@ def fromdict(definition: dict) -> object: if keys == {"transitions", "formalism"}: return _dict.build_reaction_info(definition) if keys == {"topology", "states", "interactions"}: - return _dict.build_state_transition(definition) - if keys == {"topology", "edge_props", "node_props"}: - return _dict.build_stg(definition) + return _dict.build_transition(definition) if keys == __REQUIRED_TOPOLOGY_FIELDS: return _dict.build_topology(definition) raise NotImplementedError(f"Could not determine type from keys {keys}") @@ -87,13 +85,13 @@ def asdot( """Convert a `object` to a DOT language `str`. Only works for objects that can be represented as a graph, particularly a - `.StateTransitionGraph` or a `list` of `.StateTransitionGraph` instances. + `.MutableTransition` or a `list` of `.MutableTransition` instances. Args: instance: the input `object` that is to be rendered as DOT (graphviz) language. - strip_spin: Normally, each `.StateTransitionGraph` has a `.Particle` + strip_spin: Normally, each `.MutableTransition` has a `.Particle` with a spin projection on its edges. This option hides the projections, leaving only `.Particle` names on edges. @@ -102,7 +100,7 @@ def asdot( render_node: Whether or not to render node ID (in the case of a `.Topology`) and/or node properties (in the case of a - `.StateTransitionGraph`). Meaning of the labels: + `.MutableTransition`). Meaning of the labels: - :math:`P`: parity prefactor - :math:`s`: tuple of **coupled spin** magnitude and its @@ -129,9 +127,7 @@ def asdot( edge_style = {} if node_style is None: node_style = {} - if isinstance(instance, StateTransition): - instance = instance.to_graph() - if isinstance(instance, (ProblemSet, StateTransitionGraph, Topology)): + if isinstance(instance, (ProblemSet, Topology, Transition)): dot = _dot.graph_to_dot( instance, render_node=render_node, @@ -143,7 +139,7 @@ def asdot( ) return _dot.insert_graphviz_styling(dot, graphviz_attrs=figure_style) if isinstance(instance, ReactionInfo): - instance = instance.to_graphs() + instance = instance.transitions if isinstance(instance, abc.Iterable): dot = _dot.graph_list_to_dot( instance, diff --git a/src/qrules/io/_dict.py b/src/qrules/io/_dict.py index f1d81c22..d81b6bd1 100644 --- a/src/qrules/io/_dict.py +++ b/src/qrules/io/_dict.py @@ -8,16 +8,10 @@ import attrs -from qrules.particle import ( - Parity, - Particle, - ParticleCollection, - ParticleWithSpin, - Spin, -) +from qrules.particle import Parity, Particle, ParticleCollection, Spin from qrules.quantum_numbers import InteractionProperties -from qrules.topology import Edge, StateTransitionGraph, Topology -from qrules.transition import ReactionInfo, State, StateTransition +from qrules.topology import Edge, FrozenTransition, Topology, Transition +from qrules.transition import ReactionInfo, State def from_particle_collection(particles: ParticleCollection) -> dict: @@ -33,28 +27,13 @@ def from_particle(particle: Particle) -> dict: ) -def from_stg(graph: StateTransitionGraph[ParticleWithSpin]) -> dict: - topology = graph.topology - edge_props_def = {} - for i in topology.edges: - particle, spin_projection = graph.get_edge_props(i) - if isinstance(spin_projection, float) and spin_projection.is_integer(): - spin_projection = int(spin_projection) - edge_props_def[i] = { - "particle": from_particle(particle), - "spin_projection": spin_projection, - } - node_props_def = {} - for i in topology.nodes: - node_prop = graph.get_node_props(i) - node_props_def[i] = attrs.asdict( - node_prop, filter=lambda a, v: a.init and a.default != v - ) - return { - "topology": from_topology(topology), - "edge_props": edge_props_def, - "node_props": node_props_def, - } +def from_transition(graph: Transition) -> dict: + return attrs.asdict( + graph, + recurse=True, + value_serializer=_value_serializer, + filter=lambda attribute, value: attribute.default != value, + ) def from_topology(topology: Topology) -> dict: @@ -73,10 +52,9 @@ def _value_serializer( # pylint: disable=unused-argument if all(map(lambda p: isinstance(p, Particle), value.values())): return {k: v.name for k, v in value.items()} return dict(value) - if not isinstance( - inst, (ReactionInfo, State, StateTransition) - ) and isinstance(value, Particle): - return value.name + if not isinstance(inst, (ReactionInfo, State, FrozenTransition)): + if isinstance(value, Particle): + return value.name if isinstance(value, Parity): return {"value": value.value} if isinstance(value, Spin): @@ -110,62 +88,45 @@ def build_particle(definition: dict) -> Particle: def build_reaction_info(definition: dict) -> ReactionInfo: transitions = [ - build_state_transition(transition_def) + build_transition(transition_def) for transition_def in definition["transitions"] ] return ReactionInfo(transitions, formalism=definition["formalism"]) -def build_stg(definition: dict) -> StateTransitionGraph[ParticleWithSpin]: +def build_transition( + definition: dict, +) -> "FrozenTransition[State, InteractionProperties]": topology = build_topology(definition["topology"]) - edge_props_def: Dict[int, dict] = definition["edge_props"] - edge_props: Dict[int, ParticleWithSpin] = {} - for i, edge_def in edge_props_def.items(): - particle = build_particle(edge_def["particle"]) - spin_projection = float(edge_def["spin_projection"]) - if spin_projection.is_integer(): - spin_projection = int(spin_projection) - edge_props[int(i)] = (particle, spin_projection) - node_props_def: Dict[int, dict] = definition["node_props"] - node_props = { + states_def: Dict[int, dict] = definition["states"] + states: Dict[int, State] = {} + for i, edge_def in states_def.items(): + states[int(i)] = build_state(edge_def) + interactions_def: Dict[int, dict] = definition["interactions"] + interactions = { int(i): InteractionProperties(**node_def) - for i, node_def in node_props_def.items() + for i, node_def in interactions_def.items() } - return StateTransitionGraph( - topology=topology, - edge_props=edge_props, - node_props=node_props, - ) + return FrozenTransition(topology, states, interactions) -def build_state_transition(definition: dict) -> StateTransition: - topology = build_topology(definition["topology"]) - states = { - int(i): State( - particle=build_particle(state_def["particle"]), - spin_projection=float(state_def["spin_projection"]), - ) - for i, state_def in definition["states"].items() - } - interactions = { - int(i): InteractionProperties(**interaction_def) - for i, interaction_def in definition["interactions"].items() - } - return StateTransition( - topology=topology, - states=states, - interactions=interactions, - ) +def build_state(definition: Any) -> State: + if isinstance(definition, (list, tuple)) and len(definition) == 2: + particle = build_particle(definition[0]) + spin_projection = float(definition[1]) + return State(particle, spin_projection) + if isinstance(definition, dict): + particle = build_particle(definition["particle"]) + spin_projection = float(definition["spin_projection"]) + return State(particle, spin_projection) + raise NotImplementedError() def build_topology(definition: dict) -> Topology: nodes = definition["nodes"] edges_def: Dict[int, dict] = definition["edges"] edges = {int(i): Edge(**edge_def) for i, edge_def in edges_def.items()} - return Topology( - edges=edges, - nodes=nodes, - ) + return Topology(nodes, edges) def validate_particle_collection(instance: dict) -> None: diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index c3d57583..0c507a02 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -12,20 +12,25 @@ Dict, Iterable, List, - Mapping, Optional, + Set, Tuple, Union, + cast, ) import attrs -from qrules.combinatorics import InitialFacts -from qrules.particle import Particle, ParticleCollection, ParticleWithSpin +from qrules.particle import Particle, ParticleWithSpin from qrules.quantum_numbers import InteractionProperties, _to_fraction -from qrules.solving import EdgeSettings, GraphSettings, NodeSettings -from qrules.topology import StateTransitionGraph, Topology -from qrules.transition import ProblemSet, StateTransition +from qrules.solving import EdgeSettings, NodeSettings +from qrules.topology import ( + FrozenTransition, + MutableTransition, + Topology, + Transition, +) +from qrules.transition import ProblemSet, State, StateTransition _DOT_HEAD = """digraph { rankdir=LR; @@ -138,7 +143,7 @@ def __create_graphviz_assignments(graphviz_attrs: Dict[str, Any]) -> List[str]: @embed_dot def graph_list_to_dot( - graphs: Iterable[StateTransitionGraph], + graphs: Iterable[Transition], *, render_node: bool, render_final_state_id: bool, @@ -161,8 +166,10 @@ def graph_list_to_dot( if render_node: stripped_graphs = [] for graph in graphs: - if isinstance(graph, StateTransition): - graph = graph.to_graph() + if isinstance(graph, FrozenTransition): + graph = graph.convert( + lambda s: (s.particle, s.spin_projection) + ) stripped_graph = _strip_projections(graph) if stripped_graph not in stripped_graphs: stripped_graphs.append(stripped_graph) @@ -188,7 +195,7 @@ def graph_list_to_dot( @embed_dot def graph_to_dot( - graph: StateTransitionGraph, + graph: Transition, *, render_node: bool, render_final_state_id: bool, @@ -209,14 +216,7 @@ def graph_to_dot( def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals,too-many-statements - graph: Union[ - ProblemSet, - StateTransition, - StateTransitionGraph, - Topology, - Tuple[Topology, InitialFacts], - Tuple[Topology, GraphSettings], - ], + graph: Union[ProblemSet, StateTransition, Topology, Transition], prefix: str = "", *, render_node: bool, @@ -229,18 +229,8 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals dot = "" if isinstance(graph, tuple) and len(graph) == 2: topology: Topology = graph[0] - rendered_graph: Union[ - GraphSettings, - InitialFacts, - ProblemSet, - StateTransition, - StateTransitionGraph, - Topology, - ] = graph[1] - elif isinstance(graph, ProblemSet): - rendered_graph = graph - topology = graph.topology - elif isinstance(graph, (StateTransition, StateTransitionGraph)): + rendered_graph: Union[ProblemSet, Topology, Transition] = graph[1] + elif isinstance(graph, (ProblemSet, Transition)): rendered_graph = graph topology = graph.topology elif isinstance(graph, Topology): @@ -279,8 +269,8 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals from_node, to_node, label=label, graphviz_attrs=edge_style ) if isinstance(graph, ProblemSet): - node_props = graph.solving_settings.node_settings - for node_id, settings in node_props.items(): + node_settings = graph.solving_settings.interactions + for node_id, settings in node_settings.items(): node_label = "" if render_node: node_label = __node_label(settings) @@ -289,14 +279,8 @@ def __graph_to_dot_content( # pylint: disable=too-many-branches,too-many-locals label=node_label, graphviz_attrs=node_style, ) - if isinstance(graph, (StateTransition, StateTransitionGraph)): - if isinstance(graph, StateTransition): - interactions: Mapping[ - int, InteractionProperties - ] = graph.interactions - else: - interactions = {i: graph.get_node_props(i) for i in topology.nodes} - for node_id, node_prop in interactions.items(): + if isinstance(graph, Transition): + for node_id, node_prop in graph.interactions.items(): node_label = "" if render_node: node_label = __node_label(node_prop) @@ -332,50 +316,31 @@ def __rank_string(node_edge_ids: Iterable[int], prefix: str = "") -> str: def __get_edge_label( - graph: Union[ - GraphSettings, - InitialFacts, - ProblemSet, - StateTransition, - StateTransitionGraph, - Topology, - ], + graph: Union[ProblemSet, Topology, Transition], edge_id: int, render_edge_id: bool, ) -> str: - if isinstance(graph, GraphSettings): - edge_setting = graph.edge_settings.get(edge_id) - return ___render_edge_with_id(edge_id, edge_setting, render_edge_id) - if isinstance(graph, InitialFacts): - initial_fact = graph.edge_props.get(edge_id) - return ___render_edge_with_id(edge_id, initial_fact, render_edge_id) + if isinstance(graph, Topology): + if render_edge_id: + return str(edge_id) + return "" if isinstance(graph, ProblemSet): - edge_setting = graph.solving_settings.edge_settings.get(edge_id) - initial_fact = graph.initial_facts.edge_props.get(edge_id) + edge_setting = graph.solving_settings.states.get(edge_id) + initial_fact = graph.initial_facts.states.get(edge_id) edge_property: Optional[Union[EdgeSettings, ParticleWithSpin]] = None if edge_setting: edge_property = edge_setting if initial_fact: edge_property = initial_fact return ___render_edge_with_id(edge_id, edge_property, render_edge_id) - if isinstance(graph, StateTransition): - graph = graph.to_graph() - if isinstance(graph, StateTransitionGraph): - edge_prop = graph.get_edge_props(edge_id) - return ___render_edge_with_id(edge_id, edge_prop, render_edge_id) - if isinstance(graph, Topology): - if render_edge_id: - return str(edge_id) - return "" - raise NotImplementedError( - f"Cannot render {graph.__class__.__name__} as dot" - ) + edge_prop = graph.states.get(edge_id) + return ___render_edge_with_id(edge_id, edge_prop, render_edge_id) def ___render_edge_with_id( edge_id: int, edge_prop: Optional[ - Union[EdgeSettings, ParticleCollection, Particle, ParticleWithSpin] + Union[EdgeSettings, Iterable[Particle], Particle, ParticleWithSpin] ], render_edge_id: bool, ) -> str: @@ -391,19 +356,23 @@ def ___render_edge_with_id( def __render_edge_property( edge_prop: Optional[ - Union[EdgeSettings, ParticleCollection, Particle, ParticleWithSpin] + Union[EdgeSettings, Iterable[Particle], Particle, ParticleWithSpin] ] ) -> str: if isinstance(edge_prop, EdgeSettings): return __render_settings(edge_prop) if isinstance(edge_prop, Particle): return edge_prop.name - if isinstance(edge_prop, tuple): + if isinstance(edge_prop, abc.Iterable) and all( + map(lambda i: isinstance(i, Particle), edge_prop) + ): + return "\n".join(map(lambda p: p.name, edge_prop)) + if isinstance(edge_prop, State): + edge_prop = edge_prop.particle, edge_prop.spin_projection + if isinstance(edge_prop, tuple) and len(edge_prop) == 2: particle, spin_projection = edge_prop projection_label = _to_fraction(spin_projection, render_plus=True) return f"{particle.name}[{projection_label}]" - if isinstance(edge_prop, ParticleCollection): - return "\n".join(sorted(edge_prop.names)) raise NotImplementedError @@ -466,120 +435,94 @@ def __extract_priority(description: str) -> int: def _get_particle_graphs( - graphs: Iterable[StateTransitionGraph[ParticleWithSpin]], -) -> List[StateTransitionGraph[Particle]]: - """Strip `list` of `.StateTransitionGraph` s of the spin projections. + graphs: Iterable[Transition[ParticleWithSpin, InteractionProperties]], +) -> "List[FrozenTransition[Particle, None]]": + """Strip `list` of `.Transition` s of the spin projections. - Extract a `list` of `.StateTransitionGraph` instances with only - `.Particle` instances on the edges. + Extract a `list` of `.Transition` instances with only `.Particle` instances + on the edges. .. seealso:: :doc:`/usage/visualize` """ - inventory: List[StateTransitionGraph[Particle]] = [] + inventory = set() for transition in graphs: - if isinstance(transition, StateTransition): - transition = transition.to_graph() - if any( - transition.compare( - other, edge_comparator=lambda e1, e2: e1[0] == e2 + if isinstance(transition, FrozenTransition): + transition = transition.convert( + lambda s: (s.particle, s.spin_projection) ) - for other in inventory - ): - continue - stripped_graph = _strip_projections(transition) - inventory.append(stripped_graph) - inventory = sorted( + stripped_transition = _strip_projections(transition) + topology = stripped_transition.topology + particle_transition: FrozenTransition[ + Particle, None + ] = FrozenTransition( + stripped_transition.topology, + states=stripped_transition.states, + interactions={i: None for i in topology.nodes}, + ) + inventory.add(particle_transition) + return sorted( inventory, key=lambda g: [ - g.get_edge_props(i).mass for i in g.topology.intermediate_edge_ids + g.states[i].mass for i in g.topology.intermediate_edge_ids ], ) - return inventory def _strip_projections( - graph: StateTransitionGraph[ParticleWithSpin], -) -> StateTransitionGraph[Particle]: - if isinstance(graph, StateTransition): - graph = graph.to_graph() - new_edge_props = {} - for edge_id in graph.topology.edges: - edge_props = graph.get_edge_props(edge_id) - if edge_props: - new_edge_props[edge_id] = edge_props[0] - new_node_props = {} - for node_id in graph.topology.nodes: - node_props = graph.get_node_props(node_id) - if node_props: - new_node_props[node_id] = attrs.evolve( - node_props, l_projection=None, s_projection=None - ) - return StateTransitionGraph[Particle]( - topology=graph.topology, - node_props=new_node_props, - edge_props=new_edge_props, + graph: Transition[Any, InteractionProperties], +) -> "FrozenTransition[Particle, InteractionProperties]": + if isinstance(graph, MutableTransition): + transition = graph.freeze() + transition = cast("FrozenTransition[Any, InteractionProperties]", graph) + return transition.convert( + state_converter=__to_particle, + interaction_converter=lambda i: attrs.evolve( + i, l_projection=None, s_projection=None + ), + ) + + +def __to_particle(state: Any) -> Particle: + if isinstance(state, State): + return state.particle + if isinstance(state, tuple) and len(state) == 2: + return state[0] + raise NotImplementedError( + f"Cannot extract a particle from type {type(state).__name__}" ) def _collapse_graphs( - graphs: Iterable[StateTransitionGraph[ParticleWithSpin]], -) -> List[StateTransitionGraph[ParticleCollection]]: - def merge_into( - graph: StateTransitionGraph[Particle], - merged_graph: StateTransitionGraph[ParticleCollection], - ) -> None: - if ( - graph.topology.intermediate_edge_ids - != merged_graph.topology.intermediate_edge_ids - ): - raise ValueError( - "Cannot merge graphs that don't have the same edge IDs" - ) - for i in graph.topology.edges: - particle = graph.get_edge_props(i) - other_particles = merged_graph.get_edge_props(i) - if particle not in other_particles: - other_particles += particle - - def is_same_shape( - graph: StateTransitionGraph[Particle], - merged_graph: StateTransitionGraph[ParticleCollection], - ) -> bool: - if graph.topology.edges != merged_graph.topology.edges: - return False - for edge_id in ( - graph.topology.incoming_edge_ids | graph.topology.outgoing_edge_ids - ): - edge_prop = merged_graph.get_edge_props(edge_id) - if len(edge_prop) != 1: - return False - other_particle = next(iter(edge_prop)) - if other_particle != graph.get_edge_props(edge_id): - return False - return True - - particle_graphs = _get_particle_graphs(graphs) - inventory: List[StateTransitionGraph[ParticleCollection]] = [] - for graph in particle_graphs: - append_to_inventory = True - for merged_graph in inventory: - if is_same_shape(graph, merged_graph): - merge_into(graph, merged_graph) - append_to_inventory = False - break - if append_to_inventory: - new_edge_props = { - edge_id: ParticleCollection({graph.get_edge_props(edge_id)}) - for edge_id in graph.topology.edges - } - inventory.append( - StateTransitionGraph[ParticleCollection]( - topology=graph.topology, - node_props={ - i: graph.get_node_props(i) - for i in graph.topology.nodes - }, - edge_props=new_edge_props, - ) + graphs: Iterable[Transition[ParticleWithSpin, InteractionProperties]], +) -> "Tuple[FrozenTransition[Tuple[Particle, ...], None], ...]": + transition_groups: "Dict[Topology, MutableTransition[Set[Particle], None]]" = { + g.topology: MutableTransition( + g.topology, + states={i: set() for i in g.topology.edges}, + interactions={i: None for i in g.topology.nodes}, + ) + for g in graphs + } + for transition in graphs: + topology = transition.topology + group = transition_groups[topology] + for state_id, state in transition.states.items(): + if isinstance(state, State): + particle = state.particle + else: + particle, _ = state + group.states[state_id].add(particle) + collected_graphs: "List[FrozenTransition[Tuple[Particle, ...], None]]" = [] + for topology in sorted(transition_groups): + group = transition_groups[topology] + collected_graphs.append( + FrozenTransition( + topology, + states={ + i: tuple(sorted(particles, key=lambda p: p.name)) + for i, particles in group.states.items() + }, + interactions=group.interactions, ) - return inventory + ) + return tuple(collected_graphs) diff --git a/src/qrules/particle.py b/src/qrules/particle.py index 6f2c1ae8..f2140a26 100644 --- a/src/qrules/particle.py +++ b/src/qrules/particle.py @@ -8,7 +8,7 @@ :doc:`/usage/particle`). The `.transition` module uses the properties of `Particle` instances when it -computes which `.StateTransitionGraph` s are allowed between an initial state +computes which `.MutableTransition` s are allowed between an initial state and final state. """ diff --git a/src/qrules/quantum_numbers.py b/src/qrules/quantum_numbers.py index e76c48f5..4385fd2c 100644 --- a/src/qrules/quantum_numbers.py +++ b/src/qrules/quantum_numbers.py @@ -169,7 +169,7 @@ def _to_optional_int(optional_int: Optional[int]) -> Optional[int]: class InteractionProperties: """Immutable data structure containing interaction properties. - Interactions are represented by a node on a `.StateTransitionGraph`. This + Interactions are represented by a node on a `.MutableTransition`. This class represents the properties that are carried collectively by the edges that this node connects. diff --git a/src/qrules/solving.py b/src/qrules/solving.py index 9502bb51..35d045bb 100644 --- a/src/qrules/solving.py +++ b/src/qrules/solving.py @@ -2,15 +2,16 @@ """Functions to solve a particle reaction problem. This module is responsible for solving a particle reaction problem stated by a -`.StateTransitionGraph` and corresponding `.GraphSettings`. The `.Solver` -classes (e.g. :class:`.CSPSolver`) generate new quantum numbers (for example -belonging to an intermediate state) and validate the decay processes with the -rules formulated by the :mod:`.conservation_rules` module. +`.QNProblemSet`. The `.Solver` classes (e.g. :class:`.CSPSolver`) generate new +quantum numbers (for example belonging to an intermediate state) and validate +the decay processes with the rules formulated by the :mod:`.conservation_rules` +module. """ import inspect import logging +import sys from abc import ABC, abstractmethod from collections import defaultdict from copy import copy @@ -55,7 +56,12 @@ EdgeQuantumNumbers, NodeQuantumNumber, ) -from .topology import Topology +from .topology import MutableTransition, Topology + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias @implement_pretty_repr @@ -89,18 +95,12 @@ class NodeSettings: interaction_strength: float = 1.0 -@implement_pretty_repr -@define -class GraphSettings: - edge_settings: Dict[int, EdgeSettings] = field(factory=dict) - node_settings: Dict[int, NodeSettings] = field(factory=dict) - - -@implement_pretty_repr -@define -class GraphElementProperties: - edge_props: Dict[int, GraphEdgePropertyMap] = field(factory=dict) - node_props: Dict[int, GraphNodePropertyMap] = field(factory=dict) +GraphSettings: TypeAlias = "MutableTransition[EdgeSettings, NodeSettings]" +"""(Mutable) mapping of settings on a `.Topology`.""" +GraphElementProperties: TypeAlias = ( + "MutableTransition[GraphEdgePropertyMap, GraphNodePropertyMap]" +) +"""(Mutable) mapping of edge and node properties on a `.Topology`.""" @implement_pretty_repr @@ -109,25 +109,22 @@ class QNProblemSet: """Particle reaction problem set, defined as a graph like data structure. Args: - topology (`.Topology`): a topology that represent the structure of the - reaction - initial_facts (`.GraphElementProperties`): all of the known facts quantum - numbers of the problem - solving_settings (`.GraphSettings`): solving specific settings such as - the specific rules and variable domains for nodes and edges of the - topology + initial_facts: all of the known facts quantum numbers of the problem. + solving_settings: solving specific settings, such as the specific rules + and variable domains for nodes and edges of the :attr:`topology`. """ - topology: Topology - initial_facts: GraphElementProperties - solving_settings: GraphSettings + initial_facts: "GraphElementProperties" + solving_settings: "GraphSettings" + @property + def topology(self) -> Topology: + return self.initial_facts.topology -@implement_pretty_repr -@frozen -class QuantumNumberSolution: - node_quantum_numbers: Dict[int, GraphNodePropertyMap] - edge_quantum_numbers: Dict[int, GraphEdgePropertyMap] + +QuantumNumberSolution: TypeAlias = ( + "MutableTransition[GraphEdgePropertyMap, GraphNodePropertyMap]" +) def _convert_violated_rules_to_names( @@ -259,26 +256,26 @@ def _merge_particle_candidates_with_solutions( current_new_solutions = [solution] for int_edge_id in intermediate_edges: particle_edges = __get_particle_candidates_for_state( - solution.edge_quantum_numbers[int_edge_id], + solution.states[int_edge_id], allowed_particles, ) if len(particle_edges) == 0: logging.debug("Did not find any particle candidates for") logging.debug("edge id: %d", int_edge_id) logging.debug("edge properties:") - logging.debug(solution.edge_quantum_numbers[int_edge_id]) + logging.debug(solution.states[int_edge_id]) new_solutions_temp = [] for current_new_solution in current_new_solutions: for particle_edge in particle_edges: # a "shallow" copy of the nested dicts is needed new_edge_qns = { k: copy(v) - for k, v in current_new_solution.edge_quantum_numbers.items() + for k, v in current_new_solution.states.items() } new_edge_qns[int_edge_id].update(particle_edge) temp_solution = attrs.evolve( current_new_solution, - edge_quantum_numbers=new_edge_qns, + states=new_edge_qns, ) new_solutions_temp.append(temp_solution) current_new_solutions = new_solutions_temp @@ -326,12 +323,12 @@ def _create_node_variables( ) -> Dict[Type[NodeQuantumNumber], Scalar]: """Create variables for the quantum numbers of the specified node.""" variables = {} - if node_id in problem_set.initial_facts.node_props: - node_props = problem_set.initial_facts.node_props[node_id] - variables = node_props + if node_id in problem_set.initial_facts.interactions: + interactions = problem_set.initial_facts.interactions[node_id] + variables = interactions for qn_type in qn_list: - if qn_type in node_props: - variables[qn_type] = node_props[qn_type] + if qn_type in interactions: + variables[qn_type] = interactions[qn_type] return variables def _create_edge_variables( @@ -346,22 +343,21 @@ def _create_edge_variables( """ variables = [] for edge_id in edge_ids: - if edge_id in problem_set.initial_facts.edge_props: - edge_props = problem_set.initial_facts.edge_props[edge_id] + if edge_id in problem_set.initial_facts.states: + states = problem_set.initial_facts.states[edge_id] edge_vars = {} for qn_type in qn_list: - if qn_type in edge_props: - edge_vars[qn_type] = edge_props[qn_type] + if qn_type in states: + edge_vars[qn_type] = states[qn_type] variables.append(edge_vars) return variables def _create_variable_containers( node_id: int, cons_law: Rule ) -> Tuple[List[dict], List[dict], dict]: - in_edges = problem_set.topology.get_edge_ids_ingoing_to_node(node_id) - out_edges = problem_set.topology.get_edge_ids_outgoing_from_node( - node_id - ) + topology = problem_set.topology + in_edges = topology.get_edge_ids_ingoing_to_node(node_id) + out_edges = topology.get_edge_ids_outgoing_from_node(node_id) edge_qns, node_qns = get_required_qns(cons_law) in_edges_vars = _create_edge_variables(in_edges, edge_qns) @@ -380,7 +376,7 @@ def _create_variable_containers( for ( edge_id, edge_settings, - ) in problem_set.solving_settings.edge_settings.items(): + ) in problem_set.solving_settings.states.items(): edge_rules = edge_settings.conservation_rules for edge_rule in edge_rules: # get the needed qns for this conservation law @@ -407,7 +403,7 @@ def _create_variable_containers( for ( node_id, node_settings, - ) in problem_set.solving_settings.node_settings.items(): + ) in problem_set.solving_settings.interactions.items(): node_rules = node_settings.conservation_rules for rule in node_rules: # get the needed qns for this conservation law @@ -443,9 +439,10 @@ def _create_variable_containers( ) return QNResult( [ - QuantumNumberSolution( - edge_quantum_numbers=problem_set.initial_facts.edge_props, - node_quantum_numbers=problem_set.initial_facts.node_props, + MutableTransition( + topology=problem_set.topology, + states=problem_set.initial_facts.states, + interactions=problem_set.initial_facts.interactions, ) ], ) @@ -534,7 +531,9 @@ def find_solutions(self, problem_set: QNProblemSet) -> QNResult: elif self.__scoresheet.rule_passes[(edge_id, rule)] == 0: edge_not_satisfied_rules[edge_id].add(rule) - solutions = self.__convert_solution_keys(solutions) + solutions = self.__convert_solution_keys( + problem_set.topology, solutions + ) # insert particle instances if self.__node_rules or self.__edge_rules: @@ -548,8 +547,9 @@ def find_solutions(self, problem_set: QNProblemSet) -> QNResult: else: full_particle_solutions = [ QuantumNumberSolution( - node_quantum_numbers=problem_set.initial_facts.node_props, - edge_quantum_numbers=problem_set.initial_facts.edge_props, + topology=problem_set.topology, + interactions=problem_set.initial_facts.interactions, + states=problem_set.initial_facts.states, ) ] @@ -558,26 +558,26 @@ def find_solutions(self, problem_set: QNProblemSet) -> QNResult: ): # rerun solver on these graphs using not executed rules # and combine results + topology = problem_set.topology result = QNResult() for full_particle_solution in full_particle_solutions: - node_props = full_particle_solution.node_quantum_numbers - edge_props = full_particle_solution.edge_quantum_numbers - node_props.update(problem_set.initial_facts.node_props) - edge_props.update(problem_set.initial_facts.edge_props) + interactions = full_particle_solution.interactions + states = full_particle_solution.states + interactions.update(problem_set.initial_facts.interactions) + states.update(problem_set.initial_facts.states) result.extend( validate_full_solution( QNProblemSet( - topology=problem_set.topology, - initial_facts=GraphElementProperties( - node_props=node_props, - edge_props=edge_props, + initial_facts=MutableTransition( + topology, states, interactions ), - solving_settings=GraphSettings( - node_settings={ + solving_settings=MutableTransition( + topology, + interactions={ i: NodeSettings(conservation_rules=rules) for i, rules in node_not_executed_rules.items() }, - edge_settings={ + states={ i: EdgeSettings(conservation_rules=rules) for i, rules in edge_not_executed_rules.items() }, @@ -639,7 +639,7 @@ def get_rules_by_priority( arg_handler = RuleArgumentHandler() for edge_id in problem_set.topology.edges: - edge_settings = problem_set.solving_settings.edge_settings[edge_id] + edge_settings = problem_set.solving_settings.states[edge_id] for rule in get_rules_by_priority(edge_settings): variable_mapping = _VariableContainer() # from cons law and graph determine needed var lists @@ -673,7 +673,7 @@ def get_rules_by_priority( for node_id in problem_set.topology.nodes: for rule in get_rules_by_priority( - problem_set.solving_settings.node_settings[node_id] + problem_set.solving_settings.interactions[node_id] ): variable_mapping = _VariableContainer() # from cons law and graph determine needed var lists @@ -755,13 +755,13 @@ def __create_node_variables( {}, ) - if node_id in problem_set.initial_facts.node_props: - node_props = problem_set.initial_facts.node_props[node_id] + if node_id in problem_set.initial_facts.interactions: + interactions = problem_set.initial_facts.interactions[node_id] for qn_type in qn_list: - if qn_type in node_props: - variables[1].update({qn_type: node_props[qn_type]}) + if qn_type in interactions: + variables[1].update({qn_type: interactions[qn_type]}) else: - node_settings = problem_set.solving_settings.node_settings[node_id] + node_settings = problem_set.solving_settings.interactions[node_id] for qn_type in qn_list: var_info = (node_id, qn_type) if qn_type in node_settings.qn_domains: @@ -793,17 +793,15 @@ def __create_edge_variables( for edge_id in edge_ids: variables[1][edge_id] = {} - if edge_id in problem_set.initial_facts.edge_props: - edge_props = problem_set.initial_facts.edge_props[edge_id] + if edge_id in problem_set.initial_facts.states: + states = problem_set.initial_facts.states[edge_id] for qn_type in qn_list: - if qn_type in edge_props: + if qn_type in states: variables[1][edge_id].update( - {qn_type: edge_props[qn_type]} + {qn_type: states[qn_type]} ) else: - edge_settings = problem_set.solving_settings.edge_settings[ - edge_id - ] + edge_settings = problem_set.solving_settings.states[edge_id] for qn_type in qn_list: var_info = (edge_id, qn_type) if qn_type in edge_settings.qn_domains: @@ -824,33 +822,25 @@ def __add_variable( self.__problem.addVariable(var_string, domain) def __convert_solution_keys( - self, - solutions: List[Dict[str, Scalar]], + self, topology: Topology, solutions: List[Dict[str, Scalar]] ) -> List[QuantumNumberSolution]: """Convert keys of CSP solutions from `str` to quantum number types.""" converted_solutions = [] for solution in solutions: - edge_quantum_numbers: Dict[ - int, GraphEdgePropertyMap - ] = defaultdict(dict) - node_quantum_numbers: Dict[ - int, GraphNodePropertyMap - ] = defaultdict(dict) + states: Dict[int, GraphEdgePropertyMap] = defaultdict(dict) + interactions: Dict[int, GraphNodePropertyMap] = defaultdict(dict) for var_string, value in solution.items(): ele_id, qn_type = self.__var_string_to_data[var_string] if qn_type in getattr( # noqa: B009 EdgeQuantumNumber, "__args__" ): - edge_quantum_numbers[ele_id].update({qn_type: value}) # type: ignore[dict-item] + states[ele_id].update({qn_type: value}) # type: ignore[dict-item] else: - node_quantum_numbers[ele_id].update({qn_type: value}) # type: ignore[dict-item] + interactions[ele_id].update({qn_type: value}) # type: ignore[dict-item] converted_solutions.append( - QuantumNumberSolution( - node_quantum_numbers, edge_quantum_numbers - ) + MutableTransition(topology, states, interactions) ) - return converted_solutions diff --git a/src/qrules/topology.py b/src/qrules/topology.py index b8e6b769..18fe4c65 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -1,25 +1,32 @@ -"""All modules related to topology building. +# pylint: disable=too-many-lines +"""Functionality for `Topology` and `Transition` instances. -Responsible for building all possible topologies bases on basic user input: +.. rubric:: Main interfaces -- number of initial state particles -- number of final state particles +- `Topology` and its builder functions :func:`create_isobar_topologies` and + :func:`create_n_body_topology`. +- `Transition` and its two implementations `MutableTransition` and + `FrozenTransition`. -The main interface is the `.StateTransitionGraph`. +.. autolink-preface:: + + from qrules.topology import ( + create_isobar_topologies, + create_n_body_topology, + ) """ import copy import itertools import logging import sys -from abc import abstractmethod +from abc import ABC, abstractmethod from collections import abc from functools import total_ordering from typing import ( TYPE_CHECKING, Any, Callable, - Collection, Dict, FrozenSet, Generic, @@ -35,16 +42,15 @@ Tuple, TypeVar, ValuesView, + overload, ) import attrs from attrs import define, field, frozen -from attrs.validators import instance_of +from attrs.validators import deep_iterable, deep_mapping, instance_of from qrules._implementers import implement_pretty_repr -from .quantum_numbers import InteractionProperties - if sys.version_info >= (3, 8): from typing import Protocol else: @@ -57,24 +63,35 @@ PrettyPrinter = Any -class Comparable(Protocol): +class _Comparable(Protocol): @abstractmethod def __lt__(self, other: Any) -> bool: ... -KeyType = TypeVar("KeyType", bound=Comparable) -"""Type the keys of the `~typing.Mapping`, see `~typing.KeysView`.""" -ValueType = TypeVar("ValueType") -"""Type the value of the `~typing.Mapping`, see `~typing.ValuesView`.""" +KT = TypeVar("KT", bound=_Comparable) +VT = TypeVar("VT") @total_ordering class FrozenDict( # pylint: disable=too-many-ancestors - abc.Hashable, abc.Mapping, Generic[KeyType, ValueType] + abc.Hashable, abc.Mapping, Generic[KT, VT] ): + """An **immutable** and **hashable** version of a `dict`. + + `FrozenDict` makes it possible to make classes hashable if they are + decorated with :func:`attr.frozen` and contain `~typing.Mapping`-like + attributes. If these attributes were to be implemented with a normal + `dict`, the instance is strictly speaking still mutable (even if those + attributes are a `property`) and the class is therefore not safely + hashable. + + .. warning:: The keys have to be comparable, that is, they need to have a + :meth:`~object.__lt__` method. + """ + def __init__(self, mapping: Optional[Mapping] = None): - self.__mapping: Dict[KeyType, ValueType] = {} + self.__mapping: Dict[KT, VT] = {} if mapping is not None: self.__mapping = dict(mapping) self.__hash = hash(None) @@ -100,13 +117,13 @@ def _repr_pretty_(self, p: "PrettyPrinter", cycle: bool) -> None: p.breakable() p.text("})") - def __iter__(self) -> Iterator[KeyType]: + def __iter__(self) -> Iterator[KT]: return iter(self.__mapping) def __len__(self) -> int: return len(self.__mapping) - def __getitem__(self, key: KeyType) -> ValueType: + def __getitem__(self, key: KT) -> VT: return self.__mapping[key] def __gt__(self, other: Any) -> bool: @@ -123,19 +140,19 @@ def __gt__(self, other: Any) -> bool: def __hash__(self) -> int: return self.__hash - def keys(self) -> KeysView[KeyType]: + def keys(self) -> KeysView[KT]: return self.__mapping.keys() - def items(self) -> ItemsView[KeyType, ValueType]: + def items(self) -> ItemsView[KT, VT]: return self.__mapping.items() - def values(self) -> ValuesView[ValueType]: + def values(self) -> ValuesView[VT]: return self.__mapping.values() def _convert_mapping_to_sorted_tuple( - mapping: Mapping[KeyType, ValueType], -) -> Tuple[Tuple[KeyType, ValueType], ...]: + mapping: Mapping[KT, VT], +) -> Tuple[Tuple[KT, VT], ...]: return tuple((key, mapping[key]) for key in sorted(mapping.keys())) @@ -147,73 +164,133 @@ def _to_optional_int(optional_int: Optional[int]) -> Optional[int]: @frozen(order=True) class Edge: - """Struct-like definition of an edge, used in `Topology`.""" + """Struct-like definition of an edge, used in `Topology.edges`.""" originating_node_id: Optional[int] = field( default=None, converter=_to_optional_int ) + """Node ID where the `Edge` **starts**. + + An `Edge` is **incoming to** a `Topology` if its `originating_node_id` is + `None` (see `~Topology.incoming_edge_ids`). + """ ending_node_id: Optional[int] = field( default=None, converter=_to_optional_int ) + """Node ID where the `Edge` **ends**. + + An `Edge` is **outgoing from** a `Topology` if its `ending_node_id` is + `None` (see `~Topology.outgoing_edge_ids`). + """ def get_connected_nodes(self) -> Set[int]: + """Get all node IDs to which the `Edge` is connected.""" connected_nodes = {self.ending_node_id, self.originating_node_id} connected_nodes.discard(None) return connected_nodes # type: ignore[return-value] -def _to_frozenset(iterable: Iterable[int]) -> FrozenSet[int]: - if not all(map(lambda i: isinstance(i, int), iterable)): - raise TypeError(f"Not all items in {iterable} are of type int") - return frozenset(iterable) +def _to_topology_nodes(inst: Iterable[int]) -> FrozenSet[int]: + return frozenset(inst) + + +def _to_topology_edges(inst: Mapping[int, Edge]) -> FrozenDict[int, Edge]: + return FrozenDict(inst) @implement_pretty_repr @frozen(order=True) class Topology: + # noqa: D416 """Directed Feynman-like graph without edge or node properties. - Forms the underlying topology of `StateTransitionGraph`. The graphs are - directed, meaning the edges are ingoing and outgoing to specific nodes - (since feynman graphs also have a time axis). Note that a `Topology` is not - strictly speaking a graph from graph theory, because it allows open edges, - like a Feynman-diagram. + A `Topology` is **directed** in the sense that its edges are ingoing and + outgoing to specific nodes. This is to mimic Feynman graphs, which have a + time axis. Note that a `Topology` is not strictly speaking a graph from + graph theory, because it allows open edges, like a Feynman-diagram. + + The edges and nodes can be provided with properties with a `Transition`, + which contains a `~Transition.topology`. + + As opposed to a `MutableTopology`, a `Topology` is frozen, hashable, and + ordered, so that it can be used as a kind of fingerprint for a + `Transition`. In addition, the IDs of `edges` are guaranteed to be + sequential integers and follow a specific pattern: + + - `incoming_edge_ids` (`~Transition.initial_states`) are always negative. + - `outgoing_edge_ids` (`~Transition.final_states`) lie in the range + :code:`0...n-1` with :code:`n` the number of final states. + - `intermediate_edge_ids` continue counting from :code:`n`. + + See also :meth:`MutableTopology.organize_edge_ids`. + + Example + ------- + **Isobar decay** topologies can best be created as follows: + + >>> topologies = create_isobar_topologies(number_of_final_states=3) + >>> len(topologies) + 1 + >>> topologies[0] + Topology(nodes=..., edges=...) """ - nodes: FrozenSet[int] = field(converter=_to_frozenset) - edges: FrozenDict[int, Edge] = field(converter=FrozenDict) + nodes: FrozenSet[int] = field( + converter=_to_topology_nodes, + validator=deep_iterable(member_validator=instance_of(int)), + ) + """A node is a point where different `edges` connect.""" + edges: FrozenDict[int, Edge] = field( + converter=_to_topology_edges, + validator=deep_mapping( + key_validator=instance_of(int), value_validator=instance_of(Edge) + ), + ) + """Mapping of edge IDs to their corresponding `Edge` definition.""" incoming_edge_ids: FrozenSet[int] = field(init=False, repr=False) + """Edge IDs of edges that have no `~Edge.originating_node_id`. + + `Transition.initial_states` provide properties for these edges. + """ outgoing_edge_ids: FrozenSet[int] = field(init=False, repr=False) + """Edge IDs of edges that have no `~Edge.ending_node_id`. + + `Transition.final_states` provide properties for these edges. + """ intermediate_edge_ids: FrozenSet[int] = field(init=False, repr=False) + """Edge IDs of edges that connect two `nodes`.""" def __attrs_post_init__(self) -> None: self.__verify() - object.__setattr__( - self, - "incoming_edge_ids", - frozenset( - edge_id - for edge_id, edge in self.edges.items() - if edge.originating_node_id is None - ), - ) - object.__setattr__( - self, - "outgoing_edge_ids", - frozenset( - edge_id - for edge_id, edge in self.edges.items() - if edge.ending_node_id is None - ), + incoming = sorted( + edge_id + for edge_id, edge in self.edges.items() + if edge.originating_node_id is None ) - object.__setattr__( - self, - "intermediate_edge_ids", - frozenset(self.edges) - ^ self.incoming_edge_ids - ^ self.outgoing_edge_ids, + outgoing = sorted( + edge_id + for edge_id, edge in self.edges.items() + if edge.ending_node_id is None ) + inter = sorted(set(self.edges) - set(incoming) - set(outgoing)) + expected = list(range(-len(incoming), 0)) + if sorted(incoming) != expected: + raise ValueError( + f"Incoming edge IDs should be {expected}, not {incoming}." + ) + n_out = len(outgoing) + expected = list(range(0, n_out)) + if sorted(outgoing) != expected: + raise ValueError( + f"Outgoing edge IDs should be {expected}, not {outgoing}." + ) + expected = list(range(n_out, n_out + len(inter))) + if sorted(inter) != expected: + raise ValueError(f"Intermediate edge IDs should be {expected}.") + object.__setattr__(self, "incoming_edge_ids", frozenset(incoming)) + object.__setattr__(self, "outgoing_edge_ids", frozenset(outgoing)) + object.__setattr__(self, "intermediate_edge_ids", frozenset(inter)) def __verify(self) -> None: """Verify if there are no dangling edges or nodes.""" @@ -255,6 +332,8 @@ def is_isomorphic(self, other: "Topology") -> bool: Returns `True` if the two graphs have a one-to-one mapping of the node IDs and edge IDs. + + .. warning:: Not yet implemented. """ raise NotImplementedError @@ -308,29 +387,7 @@ def get_originating_initial_state_edge_ids(self, node_id: int) -> Set[int]: temp_edge_list = new_temp_edge_list return edge_ids - def organize_edge_ids(self) -> "Topology": - """Create a new topology with edge IDs in range :code:`[-m, n+i]`. - - where :code:`m` is the number of `.incoming_edge_ids`, :code:`n` is the - number of `.outgoing_edge_ids`, and :code:`i` is the number of - `.intermediate_edge_ids`. - - In other words, relabel the edges so that: - - - `.incoming_edge_ids` lies in the range :code:`[-1, -2, ...]` - - `.outgoing_edge_ids` lies in the range :code:`[0, 1, ..., n]` - - `.intermediate_edge_ids` lies in the range :code:`[n+1, n+2, ...]` - """ - new_to_old_id = enumerate( - tuple(self.incoming_edge_ids) - + tuple(self.outgoing_edge_ids) - + tuple(self.intermediate_edge_ids), - start=-len(self.incoming_edge_ids), - ) - old_to_new_id = {j: i for i, j in new_to_old_id} - return self.relabel_edges(old_to_new_id) - - def relabel_edges(self, old_to_new_id: Mapping[int, int]) -> "Topology": + def relabel_edges(self, old_to_new: Mapping[int, int]) -> "Topology": """Create a new `Topology` with new edge IDs. This method is particularly useful when creating permutations of a @@ -348,8 +405,10 @@ def relabel_edges(self, old_to_new_id: Mapping[int, int]) -> "Topology": >>> len(permuted_topologies) 3 """ + new_to_old = {j: i for i, j in old_to_new.items()} new_edges = { - old_to_new_id.get(i, i): edge for i, edge in self.edges.items() + old_to_new.get(i, new_to_old.get(i, i)): edge + for i, edge in self.edges.items() } return attrs.evolve(self, edges=new_edges) @@ -376,29 +435,54 @@ def __get_originating_node(edge_id: int) -> Optional[int]: ] -@define(kw_only=True) -class _MutableTopology: - edges: Dict[int, Edge] = field(factory=dict, converter=dict) - nodes: Set[int] = field(factory=set, converter=set) +def _to_mutable_topology_nodes(inst: Iterable[int]) -> Set[int]: + return set(inst) - def freeze(self) -> Topology: - return Topology( - edges=self.edges, - nodes=self.nodes, - ) + +def _to_mutable_topology_edges(inst: Mapping[int, Edge]) -> Dict[int, Edge]: + return dict(inst) + + +@define +class MutableTopology: + """Mutable version of a `Topology`. + + A `MutableTopology` can be used to conveniently build up a `Topology` (see + e.g. `SimpleStateTransitionTopologyBuilder`). It does not have restrictions + on the numbering of edge and node IDs. + """ + + nodes: Set[int] = field( + converter=_to_mutable_topology_nodes, + factory=set, + on_setattr=deep_iterable(member_validator=instance_of(int)), + ) + """See `Topology.nodes`.""" + edges: Dict[int, Edge] = field( + converter=_to_mutable_topology_edges, + factory=dict, + on_setattr=deep_mapping( + key_validator=instance_of(int), value_validator=instance_of(Edge) + ), + ) + """See `Topology.edges`.""" def add_node(self, node_id: int) -> None: - """Adds a node nr. node_id. + """Adds a node with number :code:`node_id`. Raises: - ValueError: if node_id already exists + ValueError: if :code:`node_id` already exists in `nodes`. """ if node_id in self.nodes: raise ValueError(f"Node nr. {node_id} already exists") self.nodes.add(node_id) - def add_edges(self, edge_ids: List[int]) -> None: - """Add edges with the ids in the edge_ids list.""" + def add_edges(self, edge_ids: Iterable[int]) -> None: + """Add edges with the ids in the :code:`edge_ids` list. + + Raises: + ValueError: if :code:`edge_ids` already exist in `edges`. + """ for edge_id in edge_ids: if edge_id in self.edges: raise ValueError(f"Edge nr. {edge_id} already exists") @@ -458,6 +542,47 @@ def attach_edges_to_node_outgoing( originating_node_id=node_id, ) + def organize_edge_ids(self) -> "MutableTopology": + """Organize edge IDS so that they lie in range :code:`[-m, n+i]`. + + Here, :code:`m` is the number of `.incoming_edge_ids`, :code:`n` is the + number of `.outgoing_edge_ids`, and :code:`i` is the number of + `.intermediate_edge_ids`. + + In other words, relabel the edges so that: + + - incoming edge IDs lie in the range :code:`[-1, -2, ...]`, + - outgoing edge IDs lie in the range :code:`[0, 1, ..., n]`, + - intermediate edge IDs lie in the range :code:`[n+1, n+2, ...]`. + """ + incoming = { + i + for i, edge in self.edges.items() + if edge.originating_node_id is None + } + outgoing = { + edge_id + for edge_id, edge in self.edges.items() + if edge.ending_node_id is None + } + intermediate = set(self.edges) - incoming - outgoing + new_to_old_id = enumerate( + list(incoming) + list(outgoing) + list(intermediate), + start=-len(incoming), + ) + old_to_new_id = {j: i for i, j in new_to_old_id} + new_edges = { + old_to_new_id.get(i, i): edge for i, edge in self.edges.items() + } + return attrs.evolve(self, edges=new_edges) + + def freeze(self) -> Topology: + """Create an immutable `Topology` from this `MutableTopology`. + + You may need to call :meth:`organize_edge_ids` first. + """ + return Topology(self.nodes, self.edges) + @define class InteractionNode: @@ -506,14 +631,12 @@ def build( logging.info("building topology graphs...") # result list - graph_tuple_list: List[Tuple[_MutableTopology, List[int]]] = [] + graph_tuple_list: List[Tuple[MutableTopology, List[int]]] = [] # create seed graph - seed_graph = _MutableTopology() + seed_graph = MutableTopology() current_open_end_edges = list(range(number_of_initial_edges)) seed_graph.add_edges(current_open_end_edges) - extendable_graph_list: List[Tuple[_MutableTopology, List[int]]] = [ - (seed_graph, current_open_end_edges) - ] + extendable_graph_list = [(seed_graph, current_open_end_edges)] while extendable_graph_list: active_graph_list = extendable_graph_list @@ -524,7 +647,6 @@ def build( len(active_graph[1]) == number_of_final_edges and len(active_graph[0].nodes) > 0 ): - active_graph[0].freeze() # verify graph_tuple_list.append(active_graph) continue @@ -534,15 +656,15 @@ def build( # strip the current open end edges list from the result graph tuples topologies = [] for graph_tuple in graph_tuple_list: - topology = graph_tuple[0].freeze() + topology = graph_tuple[0] topology = topology.organize_edge_ids() - topologies.append(topology) + topologies.append(topology.freeze()) return tuple(topologies) def _extend_graph( - self, pair: Tuple[_MutableTopology, Sequence[int]] - ) -> List[Tuple[_MutableTopology, List[int]]]: - extended_graph_list: List[Tuple[_MutableTopology, List[int]]] = [] + self, pair: Tuple[MutableTopology, Sequence[int]] + ) -> List[Tuple[MutableTopology, List[int]]]: + extended_graph_list: List[Tuple[MutableTopology, List[int]]] = [] topology, current_open_end_edges = pair @@ -580,6 +702,27 @@ def _extend_graph( def create_isobar_topologies( number_of_final_states: int, ) -> Tuple[Topology, ...]: + """Builder function to create a set of unique isobar decay topologies. + + Args: + number_of_final_states: The number of `~Topology.outgoing_edge_ids` + (`~.Transition.final_states`). + + Returns: + A sorted `tuple` of non-isomorphic `Topology` instances, all with the + same number of final states. + + Example: + >>> topologies = create_isobar_topologies(number_of_final_states=4) + >>> len(topologies) + 2 + >>> len(topologies[0].outgoing_edge_ids) + 4 + >>> len(set(topologies)) # hashable + 2 + >>> list(topologies) == sorted(topologies) # ordered + True + """ if number_of_final_states < 2: raise ValueError( "At least two final states required for an isobar decay" @@ -589,12 +732,37 @@ def create_isobar_topologies( number_of_initial_edges=1, number_of_final_edges=number_of_final_states, ) - return tuple(topologies) + return tuple(sorted(topologies)) def create_n_body_topology( number_of_initial_states: int, number_of_final_states: int ) -> Topology: + """Create a `Topology` that connects all edges through a single node. + + These types of ":math:`n`-body topologies" are particularly important for + :func:`.check_reaction_violations` and :mod:`.conservation_rules`. + + Args: + number_of_initial_states: The number of `~Topology.incoming_edge_ids` + (`~.Transition.initial_states`). + number_of_final_states: The number of `~Topology.outgoing_edge_ids` + (`~.Transition.final_states`). + + Example: + >>> topology = create_n_body_topology( + ... number_of_initial_states=2, + ... number_of_final_states=5, + ... ) + >>> topology + Topology(nodes=..., edges...) + >>> len(topology.nodes) + 1 + >>> len(topology.incoming_edge_ids) + 2 + >>> len(topology.outgoing_edge_ids) + 5 + """ n_in = number_of_initial_states n_out = number_of_final_states builder = SimpleStateTransitionTopologyBuilder( @@ -619,10 +787,10 @@ def create_n_body_topology( def _attach_node_to_edges( - graph: Tuple[_MutableTopology, Sequence[int]], + graph: Tuple[MutableTopology, Sequence[int]], interaction_node: InteractionNode, ingoing_edge_ids: Iterable[int], -) -> Tuple[_MutableTopology, List[int]]: +) -> Tuple[MutableTopology, List[int]]: temp_graph = copy.deepcopy(graph[0]) new_open_end_lines = list(copy.deepcopy(graph[1])) @@ -653,105 +821,173 @@ def _attach_node_to_edges( EdgeType = TypeVar("EdgeType") -"""A `~typing.TypeVar` representing the type of edge properties.""" +NodeType = TypeVar("NodeType") +NewEdgeType = TypeVar("NewEdgeType") +NewNodeType = TypeVar("NewNodeType") + +class Transition(ABC, Generic[EdgeType, NodeType]): + """Mapping of edge and node properties over a `.Topology`. -class StateTransitionGraph(Generic[EdgeType]): - """Graph class that resembles a frozen `.Topology` with properties. + This **interface** class describes a transition from an initial state to a + final state by providing a mapping of properties over the `~Topology.edges` + and `~Topology.nodes` of its `topology`. Since a `Topology` behaves like a + Feynman graph, **edges** are considered as "`states`" and **nodes** are + considered as `interactions` between those states. - This class should contain the full information of a state transition from a - initial state to a final state. This information can be attached to the - nodes and edges via properties. In case not all information is provided, - error can be raised on property retrieval. + There are two implementation classes: + + - `FrozenTransition`: a complete, hashable and ordered mapping of + properties over the `~Topology.edges` and `~Topology.nodes` in its + `~FrozenTransition.topology`. + - `MutableTransition`: comparable to `MutableTopology` in that it is used + internally when finding solutions through the `.StateTransitionManager` + etc. + + These classes are also provided with **mixin** attributes `initial_states`, + `final_states`, `intermediate_states`, and :meth:`filter_states`. """ - def __init__( - self, - topology: Topology, - node_props: Mapping[int, InteractionProperties], - edge_props: Mapping[int, EdgeType], - ): - self.__node_props = dict(node_props) - self.__edge_props = dict(edge_props) - if not isinstance(topology, Topology): - raise TypeError - self.topology = topology - - def __post_init__(self) -> None: - _assert_over_defined(self.topology.nodes, self.__node_props) - _assert_over_defined(self.topology.edges, self.__edge_props) - - def __eq__(self, other: object) -> bool: - """Check if two `.StateTransitionGraph` instances are **identical**.""" - if isinstance(other, StateTransitionGraph): - if self.topology != other.topology: - return False - for i in self.topology.edges: - if self.get_edge_props(i) != other.get_edge_props(i): - return False - for i in self.topology.nodes: - if self.get_node_props(i) != other.get_node_props(i): - return False - return True - raise NotImplementedError( - f"Cannot compare {self.__class__.__name__}" - f" with {other.__class__.__name__}" - ) + @property + @abstractmethod + def topology(self) -> Topology: + """`Topology` over which `states` and `interactions` are defined.""" + ... + + @property + @abstractmethod + def states(self) -> Mapping[int, EdgeType]: + """Mapping of properties over its `topology` `~Topology.edges`.""" + ... + + @property + @abstractmethod + def interactions(self) -> Mapping[int, NodeType]: + """Mapping of properties over its `topology` `~Topology.nodes`.""" + ... + + @property + def initial_states(self) -> Dict[int, EdgeType]: + """Properties for the `~Topology.incoming_edge_ids`.""" + return self.filter_states(self.topology.incoming_edge_ids) - def get_node_props(self, node_id: int) -> InteractionProperties: - return self.__node_props[node_id] + @property + def final_states(self) -> Dict[int, EdgeType]: + """Properties for the `~Topology.outgoing_edge_ids`.""" + return self.filter_states(self.topology.outgoing_edge_ids) - def get_edge_props(self, edge_id: int) -> EdgeType: - return self.__edge_props[edge_id] + @property + def intermediate_states(self) -> Dict[int, EdgeType]: + """Properties for the intermediate edges (connecting two nodes).""" + return self.filter_states(self.topology.intermediate_edge_ids) - def evolve( + def filter_states(self, edge_ids: Iterable[int]) -> Dict[int, EdgeType]: + """Filter `states` by a selection of :code:`edge_ids`.""" + return {i: self.states[i] for i in edge_ids} + + +@implement_pretty_repr +@frozen(order=True) +class FrozenTransition(Transition, Generic[EdgeType, NodeType]): + """Defines a frozen mapping of edge and node properties on a `Topology`.""" + + topology: Topology = field(validator=instance_of(Topology)) + states: FrozenDict[int, EdgeType] = field(converter=FrozenDict) + interactions: FrozenDict[int, NodeType] = field(converter=FrozenDict) + + def __attrs_post_init__(self) -> None: + _assert_all_defined(self.topology.nodes, self.interactions) + _assert_all_defined(self.topology.edges, self.states) + + def unfreeze(self) -> "MutableTransition[EdgeType, NodeType]": + """Convert into a `MutableTransition`.""" + return MutableTransition(self.topology, self.states, self.interactions) + + @overload + def convert(self) -> "FrozenTransition[EdgeType, NodeType]": + ... + + @overload + def convert( + self, state_converter: Callable[[EdgeType], NewEdgeType] + ) -> "FrozenTransition[NewEdgeType, NodeType]": + ... + + @overload + def convert( + self, *, interaction_converter: Callable[[NodeType], NewNodeType] + ) -> "FrozenTransition[EdgeType, NewNodeType]": + ... + + @overload + def convert( self, - node_props: Optional[Dict[int, InteractionProperties]] = None, - edge_props: Optional[Dict[int, EdgeType]] = None, - ) -> "StateTransitionGraph[EdgeType]": - """Changes the node and edge properties of a graph instance. + state_converter: Callable[[EdgeType], NewEdgeType], + interaction_converter: Callable[[NodeType], NewNodeType], + ) -> "FrozenTransition[NewEdgeType, NewNodeType]": + ... - Since a `.StateTransitionGraph` is frozen (cannot be modified), the - evolve function will also create a shallow copy the properties. - """ - new_node_props = copy.copy(self.__node_props) - if node_props: - _assert_over_defined(self.topology.nodes, node_props) - for node_id, node_prop in node_props.items(): - new_node_props[node_id] = node_prop - - new_edge_props = copy.copy(self.__edge_props) - if edge_props: - _assert_over_defined(self.topology.edges, edge_props) - for edge_id, edge_prop in edge_props.items(): - new_edge_props[edge_id] = edge_prop - - return StateTransitionGraph[EdgeType]( - topology=self.topology, - node_props=new_node_props, - edge_props=new_edge_props, + def convert(self, state_converter=None, interaction_converter=None): # type: ignore[no-untyped-def] + """Cast the edge and/or node properties to another type.""" + # pylint: disable=unnecessary-lambda + if state_converter is None: + state_converter = lambda _: _ + if interaction_converter is None: + interaction_converter = lambda _: _ + return FrozenTransition( + self.topology, + states={ + i: state_converter(state) for i, state in self.states.items() + }, + interactions={ + i: interaction_converter(interaction) + for i, interaction in self.interactions.items() + }, ) + +def _cast_states(obj: Mapping[int, EdgeType]) -> Dict[int, EdgeType]: + return dict(obj) + + +def _cast_interactions(obj: Mapping[int, NodeType]) -> Dict[int, NodeType]: + return dict(obj) + + +@implement_pretty_repr +@define +class MutableTransition(Transition, Generic[EdgeType, NodeType]): + """Mutable implementation of a `Transition`. + + Mainly used internally by the `.StateTransitionManager` to build solutions. + """ + + topology: Topology = field(validator=instance_of(Topology)) + states: Dict[int, EdgeType] = field(converter=_cast_states, factory=dict) + interactions: Dict[int, NodeType] = field( + converter=_cast_interactions, factory=dict + ) + def compare( self, - other: "StateTransitionGraph", - edge_comparator: Optional[Callable[[EdgeType, EdgeType], bool]] = None, - node_comparator: Optional[ - Callable[[InteractionProperties, InteractionProperties], bool] + other: "MutableTransition", + state_comparator: Optional[ + Callable[[EdgeType, EdgeType], bool] + ] = None, + interaction_comparator: Optional[ + Callable[[NodeType, NodeType], bool] ] = None, ) -> bool: if self.topology != other.topology: return False - if edge_comparator is not None: + if state_comparator is not None: for i in self.topology.edges: - if not edge_comparator( - self.get_edge_props(i), other.get_edge_props(i) - ): + if not state_comparator(self.states[i], other.states[i]): return False - if node_comparator is not None: + if interaction_comparator is not None: for i in self.topology.nodes: - if not node_comparator( - self.get_node_props(i), other.get_node_props(i) + if not interaction_comparator( + self.interactions[i], other.interactions[i] ): return False return True @@ -760,20 +996,35 @@ def swap_edges(self, edge_id1: int, edge_id2: int) -> None: self.topology = self.topology.swap_edges(edge_id1, edge_id2) value1: Optional[EdgeType] = None value2: Optional[EdgeType] = None - if edge_id1 in self.__edge_props: - value1 = self.__edge_props.pop(edge_id1) - if edge_id2 in self.__edge_props: - value2 = self.__edge_props.pop(edge_id2) + if edge_id1 in self.states: + value1 = self.states.pop(edge_id1) + if edge_id2 in self.states: + value2 = self.states.pop(edge_id2) if value1 is not None: - self.__edge_props[edge_id2] = value1 + self.states[edge_id2] = value1 if value2 is not None: - self.__edge_props[edge_id1] = value2 + self.states[edge_id1] = value2 + + def freeze(self) -> "FrozenTransition[EdgeType, NodeType]": + """Convert into a `FrozenTransition`.""" + return FrozenTransition(self.topology, self.states, self.interactions) -def _assert_over_defined(items: Collection, properties: Mapping) -> None: +def _assert_all_defined(items: Iterable, properties: Iterable) -> None: + existing = set(items) defined = set(properties) + if existing & defined != existing: + raise ValueError( + "Some items have no property assigned to them." + f" Available items: {existing}, items with property: {defined}" + ) + + +# pyright: reportUnusedFunction=false +def _assert_not_overdefined(items: Iterable, properties: Iterable) -> None: existing = set(items) - over_defined = existing & defined ^ defined + defined = set(properties) + over_defined = defined - existing if over_defined: raise ValueError( "Properties have been defined for items that don't exist." diff --git a/src/qrules/transition.py b/src/qrules/transition.py index ff671ab8..64c5f458 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -1,16 +1,16 @@ """Find allowed transitions between an initial and final state.""" import logging +import sys from collections import defaultdict from copy import copy, deepcopy from enum import Enum, auto from multiprocessing import Pool from typing import ( - Collection, + TYPE_CHECKING, Dict, Iterable, List, - Mapping, Optional, Sequence, Set, @@ -66,7 +66,6 @@ CSPSolver, EdgeSettings, GraphEdgePropertyMap, - GraphElementProperties, GraphSettings, NodeSettings, QNProblemSet, @@ -74,12 +73,19 @@ ) from .topology import ( FrozenDict, - StateTransitionGraph, + MutableTransition, Topology, create_isobar_topologies, create_n_body_topology, ) +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias +if TYPE_CHECKING: + from .topology import FrozenTransition # noqa: F401 + class SolvingMode(Enum): """Types of modes for solving.""" @@ -138,7 +144,7 @@ def clear(self) -> None: class _SolutionContainer: """Defines a result of a `.ProblemSet`.""" - solutions: List[StateTransitionGraph[ParticleWithSpin]] = field( + solutions: "List[MutableTransition[ParticleWithSpin, InteractionProperties]]" = field( factory=list ) execution_info: ExecutionInfo = field(default=ExecutionInfo()) @@ -167,41 +173,42 @@ def extend( ) +if sys.version_info >= (3, 7): + attrs.resolve_types(_SolutionContainer, globals(), locals()) + + @implement_pretty_repr @define class ProblemSet: - """Particle reaction problem set, defined as a graph like data structure. - - Args: - topology: `.Topology` that contains the structure of the reaction. - initial_facts: `~.InitialFacts` that contain the info of initial and - final state in connection with the topology. - solving_settings: Solving related settings such as the conservation - rules and the quantum number domains. - """ + """Particle reaction problem set as a graph-like data structure.""" topology: Topology - initial_facts: InitialFacts - solving_settings: GraphSettings + """`.Topology` over which the problem set is defined.""" + initial_facts: "InitialFacts" + """Information about the initial and final state.""" + solving_settings: "GraphSettings" + """Solving settings, such as conservation rules and QN-domains.""" def to_qn_problem_set(self) -> QNProblemSet: - node_props = { + interactions = { k: create_node_properties(v) - for k, v in self.initial_facts.node_props.items() + for k, v in self.initial_facts.interactions.items() } - edge_props = { + states = { k: create_edge_properties(v[0], v[1]) - for k, v in self.initial_facts.edge_props.items() + for k, v in self.initial_facts.states.items() } return QNProblemSet( - topology=self.topology, - initial_facts=GraphElementProperties( - node_props=node_props, edge_props=edge_props + initial_facts=MutableTransition( + self.topology, states, interactions ), solving_settings=self.solving_settings, ) +attrs.resolve_types(ProblemSet, globals(), locals()) + + def _group_by_strength( problem_sets: List[ProblemSet], ) -> Dict[float, List[ProblemSet]]: @@ -218,7 +225,7 @@ def calculate_strength( ) for problem_set in problem_sets: strength = calculate_strength( - problem_set.solving_settings.node_settings + problem_set.solving_settings.interactions ) strength_sorted_problem_sets[strength].append(problem_set) return strength_sorted_problem_sets @@ -414,7 +421,7 @@ def create_problem_sets(self) -> Dict[float, List[ProblemSet]]: return _group_by_strength(problem_sets) def __determine_graph_settings( - self, topology: Topology, initial_facts: InitialFacts + self, topology: Topology, initial_facts: "InitialFacts" ) -> List[GraphSettings]: # pylint: disable=too-many-locals def create_intermediate_edge_qn_domains() -> Dict: @@ -464,12 +471,12 @@ def create_edge_settings(edge_id: int) -> EdgeSettings: initial_state_edges = topology.incoming_edge_ids graph_settings: List[GraphSettings] = [ - GraphSettings( - edge_settings={ + MutableTransition( + topology, + states={ edge_id: create_edge_settings(edge_id) for edge_id in topology.edges }, - node_settings={}, ) ] @@ -477,24 +484,24 @@ def create_edge_settings(edge_id: int) -> EdgeSettings: interaction_types: List[InteractionType] = [] out_edge_ids = topology.get_edge_ids_outgoing_from_node(node_id) in_edge_ids = topology.get_edge_ids_outgoing_from_node(node_id) - in_edge_props = [ - initial_facts.edge_props[edge_id] + in_states = [ + initial_facts.states[edge_id] for edge_id in [ x for x in in_edge_ids if x in initial_state_edges ] ] - out_edge_props = [ - initial_facts.edge_props[edge_id] + out_states = [ + initial_facts.states[edge_id] for edge_id in [ x for x in out_edge_ids if x in final_state_edges ] ] - node_props = InteractionProperties() - if node_id in initial_facts.node_props: - node_props = initial_facts.node_props[node_id] + interactions = InteractionProperties() + if node_id in initial_facts.interactions: + interactions = initial_facts.interactions[node_id] for int_det in self.interaction_determinators: determined_interactions = int_det.check( - in_edge_props, out_edge_props, node_props + in_states, out_states, interactions ) if interaction_types: interaction_types = list( @@ -516,7 +523,7 @@ def create_edge_settings(edge_id: int) -> EdgeSettings: for temp_setting in temp_graph_settings: for int_type in interaction_types: updated_setting = deepcopy(temp_setting) - updated_setting.node_settings[node_id] = deepcopy( + updated_setting.interactions[node_id] = deepcopy( self.interaction_type_settings[int_type][1] ) graph_settings.append(updated_setting) @@ -551,7 +558,7 @@ def find_solutions( # pylint: disable=too-many-branches qn_problems = [x.to_qn_problem_set() for x in problems] # Because of pickling problems of Generic classes (in this case - # StateTransitionGraph), multithreaded code has to work with + # MutableTransition), multithreaded code has to work with # QNProblemSet's and QNResult's. So the appropriate conversions # have to be done before and after temp_qn_results: List[Tuple[QNProblemSet, QNResult]] = [] @@ -635,7 +642,11 @@ def find_solutions( # pylint: disable=too-many-branches _match_final_state_ids(graph, self.final_state) for graph in final_solutions ] - return ReactionInfo.from_graphs(final_solutions, self.formalism) + transitions = [ + graph.freeze().convert(lambda s: State(*s)) + for graph in final_solutions + ] + return ReactionInfo(transitions, self.formalism) def _solve( self, qn_problem_set: QNProblemSet @@ -654,15 +665,15 @@ def __convert_result( """ solutions = [] for solution in qn_result.solutions: - graph = StateTransitionGraph[ParticleWithSpin]( + graph = MutableTransition( topology=topology, - node_props={ + interactions={ i: create_interaction_properties(x) - for i, x in solution.node_quantum_numbers.items() + for i, x in solution.interactions.items() }, - edge_props={ + states={ i: create_particle(x, self.__particles) - for i, x in solution.edge_quantum_numbers.items() + for i, x in solution.states.items() }, ) solutions.append(graph) @@ -692,24 +703,24 @@ def _safe_wrap_list( def _match_final_state_ids( - graph: StateTransitionGraph[ParticleWithSpin], + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", state_definition: Sequence[StateDefinition], -) -> StateTransitionGraph[ParticleWithSpin]: +) -> "MutableTransition[ParticleWithSpin, InteractionProperties]": """Temporary fix to https://github.com/ComPWA/qrules/issues/143.""" particle_names = _strip_spin(state_definition) name_to_id = {name: i for i, name in enumerate(particle_names)} id_remapping = { - name_to_id[graph.get_edge_props(i)[0].name]: i + name_to_id[graph.states[i][0].name]: i for i in graph.topology.outgoing_edge_ids } new_topology = graph.topology.relabel_edges(id_remapping) - return StateTransitionGraph( + return MutableTransition( new_topology, - edge_props={ - i: graph.get_edge_props(id_remapping.get(i, i)) + states={ + i: graph.states[id_remapping.get(i, i)] for i in graph.topology.edges }, - node_props={i: graph.get_node_props(i) for i in graph.topology.nodes}, + interactions={i: graph.interactions[i] for i in graph.topology.nodes}, ) @@ -730,85 +741,13 @@ class State: spin_projection: float = field(converter=_to_float) -@implement_pretty_repr -@frozen(order=True) -class StateTransition: - """Frozen instance of a `.StateTransitionGraph` of a particle with spin.""" - - topology: Topology = field(validator=instance_of(Topology)) - states: FrozenDict[int, State] = field(converter=FrozenDict) - interactions: FrozenDict[int, InteractionProperties] = field( - converter=FrozenDict - ) - - def __attrs_post_init__(self) -> None: - _assert_defined(self.topology.edges, self.states) - _assert_defined(self.topology.nodes, self.interactions) - - @staticmethod - def from_graph( - graph: StateTransitionGraph[ParticleWithSpin], - ) -> "StateTransition": - return StateTransition( - topology=graph.topology, - states=FrozenDict( - { - i: State(*graph.get_edge_props(i)) - for i in graph.topology.edges - } - ), - interactions=FrozenDict( - {i: graph.get_node_props(i) for i in graph.topology.nodes} - ), - ) - - def to_graph(self) -> StateTransitionGraph[ParticleWithSpin]: - return StateTransitionGraph[ParticleWithSpin]( - topology=self.topology, - edge_props={ - i: (state.particle, state.spin_projection) - for i, state in self.states.items() - }, - node_props=self.interactions, - ) - - @property - def initial_states(self) -> Dict[int, State]: - return self.filter_states(self.topology.incoming_edge_ids) - - @property - def final_states(self) -> Dict[int, State]: - return self.filter_states(self.topology.outgoing_edge_ids) - - @property - def intermediate_states(self) -> Dict[int, State]: - return self.filter_states(self.topology.intermediate_edge_ids) - - def filter_states(self, edge_ids: Iterable[int]) -> Dict[int, State]: - return {i: self.states[i] for i in edge_ids} - - @property - def particles(self) -> Dict[int, Particle]: - return {i: edge_prop.particle for i, edge_prop in self.states.items()} - - -def _assert_defined(items: Collection, properties: Mapping) -> None: - existing = set(items) - defined = set(properties) - if existing & defined != existing: - raise ValueError( - "Some items have no property assigned to them." - f" Available items: {existing}, items with property: {defined}" - ) +StateTransition: TypeAlias = "FrozenTransition[State, InteractionProperties]" +"""Transition of some initial `.State` to a final `.State`.""" def _sort_tuple( iterable: Iterable[StateTransition], ) -> Tuple[StateTransition, ...]: - if not all(map(lambda t: isinstance(t, StateTransition), iterable)): - raise TypeError( - f"Not all instances are of type {StateTransition.__name__}" - ) return tuple(sorted(iterable)) @@ -843,17 +782,6 @@ def get_intermediate_particles(self) -> ParticleCollection: } return ParticleCollection(particles) - @staticmethod - def from_graphs( - graphs: Iterable[StateTransitionGraph[ParticleWithSpin]], - formalism: str, - ) -> "ReactionInfo": - transitions = [StateTransition.from_graph(g) for g in graphs] - return ReactionInfo(transitions, formalism) - - def to_graphs(self) -> List[StateTransitionGraph[ParticleWithSpin]]: - return [transition.to_graph() for transition in self.transitions] - def group_by_topology(self) -> Dict[Topology, List[StateTransition]]: groupings = defaultdict(list) for transition in self.transitions: diff --git a/tests/channels/test_jpsi_to_gamma_pi0_pi0.py b/tests/channels/test_jpsi_to_gamma_pi0_pi0.py index 7ec5dc7e..ef3726d9 100644 --- a/tests/channels/test_jpsi_to_gamma_pi0_pi0.py +++ b/tests/channels/test_jpsi_to_gamma_pi0_pi0.py @@ -2,8 +2,6 @@ import qrules from qrules.combinatorics import _create_edge_id_particle_mapping -from qrules.particle import ParticleWithSpin -from qrules.topology import StateTransitionGraph @pytest.mark.parametrize( @@ -62,7 +60,9 @@ def test_id_to_particle_mappings(particle_database): assert len(reaction.transitions) == 4 iter_transitions = iter(reaction.transitions) first_transition = next(iter_transitions) - graph: StateTransitionGraph[ParticleWithSpin] = first_transition.to_graph() + graph = first_transition.convert( + lambda s: (s.particle, s.spin_projection) + ).unfreeze() ref_mapping_fs = _create_edge_id_particle_mapping( graph, graph.topology.outgoing_edge_ids ) @@ -70,7 +70,9 @@ def test_id_to_particle_mappings(particle_database): graph, graph.topology.incoming_edge_ids ) for transition in iter_transitions: - graph = transition.to_graph() + graph = transition.convert( + lambda s: (s.particle, s.spin_projection) + ).unfreeze() assert ref_mapping_fs == _create_edge_id_particle_mapping( graph, graph.topology.outgoing_edge_ids ) diff --git a/tests/unit/io/test_dot.py b/tests/unit/io/test_dot.py index b2f49636..4d620fc2 100644 --- a/tests/unit/io/test_dot.py +++ b/tests/unit/io/test_dot.py @@ -9,7 +9,7 @@ _get_particle_graphs, _strip_projections, ) -from qrules.particle import ParticleCollection +from qrules.particle import Particle, ParticleCollection from qrules.topology import ( Edge, Topology, @@ -163,7 +163,11 @@ def test_write_topology(self, output_dir): output_file = output_dir + "two_body_decay_topology.gv" topology = Topology( nodes={0}, - edges={0: Edge(0, None), 1: Edge(None, 0), 2: Edge(None, 0)}, + edges={ + -1: Edge(None, 0), + 0: Edge(0, None), + 1: Edge(0, None), + }, ) io.write( instance=topology, @@ -208,15 +212,17 @@ def test_collapse_graphs( particle_database: ParticleCollection, ): pdg = particle_database - particle_graphs = _get_particle_graphs(reaction.to_graphs()) + particle_graphs = _get_particle_graphs(reaction.transitions) # type: ignore[arg-type] assert len(particle_graphs) == 2 - collapsed_graphs = _collapse_graphs(reaction.to_graphs()) + + collapsed_graphs = _collapse_graphs(reaction.transitions) # type: ignore[arg-type] assert len(collapsed_graphs) == 1 graph = next(iter(collapsed_graphs)) edge_id = next(iter(graph.topology.intermediate_edge_ids)) f_resonances = pdg.filter(lambda p: p.name in ["f(0)(980)", "f(0)(1500)"]) - intermediate_states = graph.get_edge_props(edge_id) - assert isinstance(intermediate_states, ParticleCollection) + intermediate_states = graph.states[edge_id] + assert isinstance(intermediate_states, tuple) + assert all(map(lambda i: isinstance(i, Particle), intermediate_states)) assert intermediate_states == f_resonances @@ -224,15 +230,13 @@ def test_get_particle_graphs( reaction: ReactionInfo, particle_database: ParticleCollection ): pdg = particle_database - particle_graphs = _get_particle_graphs(reaction.to_graphs()) - assert len(particle_graphs) == 2 - assert particle_graphs[0].get_edge_props(3) == pdg["f(0)(980)"] - assert particle_graphs[1].get_edge_props(3) == pdg["f(0)(1500)"] - assert len(particle_graphs[0].topology.edges) == 5 - for edge_id in range(-1, 3): - assert particle_graphs[0].get_edge_props(edge_id) is particle_graphs[ - 1 - ].get_edge_props(edge_id) + graphs = _get_particle_graphs(reaction.transitions) # type: ignore[arg-type] + assert len(graphs) == 2 + assert graphs[0].states[3] == pdg["f(0)(980)"] + assert graphs[1].states[3] == pdg["f(0)(1500)"] + assert len(graphs[0].topology.edges) == 5 + for i in range(-1, 3): + assert graphs[0].states[i] is graphs[1].states[i] def test_strip_projections(): @@ -256,8 +260,8 @@ def test_strip_projections(): assert transition.interactions[1].l_projection == 0 stripped_transition = _strip_projections(transition) # type: ignore[arg-type] - assert stripped_transition.get_edge_props(3).name == resonance - assert stripped_transition.get_node_props(0).s_projection is None - assert stripped_transition.get_node_props(0).l_projection is None - assert stripped_transition.get_node_props(1).s_projection is None - assert stripped_transition.get_node_props(1).l_projection is None + assert stripped_transition.states[3].name == resonance + assert stripped_transition.interactions[0].s_projection is None + assert stripped_transition.interactions[0].l_projection is None + assert stripped_transition.interactions[1].s_projection is None + assert stripped_transition.interactions[1].l_projection is None diff --git a/tests/unit/io/test_io.py b/tests/unit/io/test_io.py index be7b8638..8bc97ec5 100644 --- a/tests/unit/io/test_io.py +++ b/tests/unit/io/test_io.py @@ -5,7 +5,7 @@ from qrules import io from qrules.particle import Particle, ParticleCollection from qrules.topology import ( - StateTransitionGraph, + FrozenTransition, Topology, create_isobar_topologies, create_n_body_topology, @@ -44,10 +44,10 @@ def test_asdict_fromdict(particle_selection: ParticleCollection): def test_asdict_fromdict_reaction(reaction: ReactionInfo): - # StateTransitionGraph - for graph in reaction.to_graphs(): + # FrozenTransition + for graph in reaction.transitions: fromdict = through_dict(graph) - assert isinstance(fromdict, StateTransitionGraph) + assert isinstance(fromdict, FrozenTransition) assert graph == fromdict # ReactionInfo fromdict = through_dict(reaction) diff --git a/tests/unit/test_combinatorics.py b/tests/unit/test_combinatorics.py index a2339770..258eda37 100644 --- a/tests/unit/test_combinatorics.py +++ b/tests/unit/test_combinatorics.py @@ -98,14 +98,14 @@ def test_constructor(self): def test_from_topology(self, three_body_decay: Topology): pi0 = ("pi0", [0]) gamma = ("gamma", [-1, 1]) - edge_props = { + states = { -1: ("J/psi", [-1, +1]), 0: pi0, 1: pi0, 2: gamma, } kinematic_representation1 = _get_kinematic_representation( - three_body_decay, edge_props + three_body_decay, states ) assert kinematic_representation1.initial_state == [ ["J/psi"], diff --git a/tests/unit/test_system_control.py b/tests/unit/test_system_control.py index a537550b..741f69f2 100644 --- a/tests/unit/test_system_control.py +++ b/tests/unit/test_system_control.py @@ -1,7 +1,8 @@ # pylint: disable=protected-access from copy import deepcopy -from typing import List +from typing import Dict, List +import attrs import pytest from qrules import InteractionType, ProblemSet, StateTransitionManager @@ -23,7 +24,7 @@ InteractionProperties, NodeQuantumNumbers, ) -from qrules.topology import Edge, StateTransitionGraph, Topology +from qrules.topology import Edge, MutableTransition, Topology @pytest.mark.parametrize( @@ -207,38 +208,36 @@ def test_create_edge_properties( def make_ls_test_graph( angular_momentum_magnitude, coupled_spin_magnitude, particle ): - graph = StateTransitionGraph[ParticleWithSpin]( - topology=Topology( - nodes={0}, - edges={0: Edge(None, 0)}, - ), - node_props={ - 0: InteractionProperties( - s_magnitude=coupled_spin_magnitude, - l_magnitude=angular_momentum_magnitude, - ) - }, - edge_props={0: (particle, 0)}, + topology = Topology( + nodes={0}, + edges={-1: Edge(None, 0)}, ) + interactions = { + 0: InteractionProperties( + s_magnitude=coupled_spin_magnitude, + l_magnitude=angular_momentum_magnitude, + ) + } + states: Dict[int, ParticleWithSpin] = {-1: (particle, 0)} + graph = MutableTransition(topology, states, interactions) return graph def make_ls_test_graph_scrambled( angular_momentum_magnitude, coupled_spin_magnitude, particle ): - graph = StateTransitionGraph[ParticleWithSpin]( - topology=Topology( - nodes={0}, - edges={0: Edge(None, 0)}, - ), - node_props={ - 0: InteractionProperties( - l_magnitude=angular_momentum_magnitude, - s_magnitude=coupled_spin_magnitude, - ) - }, - edge_props={0: (particle, 0)}, + topology = Topology( + nodes={0}, + edges={-1: Edge(None, 0)}, ) + interactions = { + 0: InteractionProperties( + l_magnitude=angular_momentum_magnitude, + s_magnitude=coupled_spin_magnitude, + ) + } + states: Dict[int, ParticleWithSpin] = {-1: (particle, 0)} + graph = MutableTransition(topology, states, interactions) return graph @@ -324,13 +323,14 @@ def test_filter_graphs_for_interaction_qns( for value in input_values: tempgraph = make_ls_test_graph(value[1][0], value[1][1], pi0) - tempgraph = tempgraph.evolve( - edge_props={ - 0: ( + tempgraph = attrs.evolve( + tempgraph, + states={ + -1: ( Particle(name=value[0], pid=0, mass=1.0, spin=1.0), 0.0, ) - } + }, ) graphs.append(tempgraph) @@ -341,11 +341,11 @@ def test_filter_graphs_for_interaction_qns( def _create_graph( problem_set: ProblemSet, -) -> StateTransitionGraph[ParticleWithSpin]: - return StateTransitionGraph[ParticleWithSpin]( +) -> "MutableTransition[ParticleWithSpin, InteractionProperties]": + return MutableTransition( topology=problem_set.topology, - node_props=problem_set.initial_facts.node_props, - edge_props=problem_set.initial_facts.edge_props, + interactions=problem_set.initial_facts.interactions, + states=problem_set.initial_facts.states, ) @@ -369,7 +369,9 @@ def test_edge_swap(particle_database, initial_state, final_state): stm.set_allowed_interaction_types([InteractionType.STRONG]) problem_sets = stm.create_problem_sets() - init_graphs: List[StateTransitionGraph[ParticleWithSpin]] = [] + init_graphs: List[ + MutableTransition[ParticleWithSpin, InteractionProperties] + ] = [] for _, problem_set_list in problem_sets.items(): init_graphs.extend([_create_graph(x) for x in problem_set_list]) @@ -380,15 +382,15 @@ def test_edge_swap(particle_database, initial_state, final_state): edge_keys = list(ref_mapping.keys()) edge1 = edge_keys[0] edge1_val = graph.topology.edges[edge1] - edge1_props = deepcopy(graph.get_edge_props(edge1)) + edge1_props = deepcopy(graph.states[edge1]) edge2 = edge_keys[1] edge2_val = graph.topology.edges[edge2] - edge2_props = deepcopy(graph.get_edge_props(edge2)) + edge2_props = deepcopy(graph.states[edge2]) graph.swap_edges(edge1, edge2) assert graph.topology.edges[edge1] == edge2_val assert graph.topology.edges[edge2] == edge1_val - assert graph.get_edge_props(edge1) == edge2_props - assert graph.get_edge_props(edge2) == edge1_props + assert graph.states[edge1] == edge2_props + assert graph.states[edge2] == edge1_props @pytest.mark.parametrize( @@ -415,7 +417,9 @@ def test_match_external_edges(particle_database, initial_state, final_state): stm.set_allowed_interaction_types([InteractionType.STRONG]) problem_sets = stm.create_problem_sets() - init_graphs: List[StateTransitionGraph[ParticleWithSpin]] = [] + init_graphs: List[ + MutableTransition[ParticleWithSpin, InteractionProperties] + ] = [] for _, problem_set_list in problem_sets.items(): init_graphs.extend([_create_graph(x) for x in problem_set_list]) @@ -504,7 +508,9 @@ def test_external_edge_identical_particle_combinatorics( match_external_edges(init_graphs) - comb_graphs: List[StateTransitionGraph[ParticleWithSpin]] = [] + comb_graphs: List[ + MutableTransition[ParticleWithSpin, InteractionProperties] + ] = [] for group in init_graphs: comb_graphs.extend( perform_external_edge_identical_particle_combinatorics(group) diff --git a/tests/unit/test_topology.py b/tests/unit/test_topology.py index 688ab640..0acaac5f 100644 --- a/tests/unit/test_topology.py +++ b/tests/unit/test_topology.py @@ -10,9 +10,9 @@ Edge, FrozenDict, InteractionNode, + MutableTopology, SimpleStateTransitionTopologyBuilder, Topology, - _MutableTopology, create_isobar_topologies, create_n_body_topology, get_originating_node_list, @@ -27,20 +27,20 @@ def two_to_three_decay() -> Topology: .. code-block:: - e0 -- (N0) -- e2 -- (N1) -- e3 -- (N2) -- e6 + e-1 -- (N0) -- e3 -- (N1) -- e4 -- (N2) -- e2 / \ \ - e1 e4 e5 + e-2 e0 e1 """ topology = Topology( nodes={0, 1, 2}, edges={ - 0: Edge(None, 0), - 1: Edge(None, 0), - 2: Edge(0, 1), - 3: Edge(1, 2), - 4: Edge(1, None), - 5: Edge(2, None), - 6: Edge(2, None), + -2: Edge(None, 0), + -1: Edge(None, 0), + 0: Edge(1, None), + 1: Edge(2, None), + 2: Edge(2, None), + 3: Edge(0, 1), + 4: Edge(1, 2), }, ) return topology @@ -100,42 +100,59 @@ def test_constructor_exceptions(self): class TestMutableTopology: def test_add_and_attach(self, two_to_three_decay: Topology): - topology = _MutableTopology( + topology = MutableTopology( edges=two_to_three_decay.edges, - nodes=two_to_three_decay.nodes, # type: ignore[arg-type] + nodes=two_to_three_decay.nodes, ) topology.add_node(3) - topology.add_edges([7, 8]) - topology.attach_edges_to_node_outgoing([7, 8], 3) + topology.add_edges([5, 6]) + topology.attach_edges_to_node_outgoing([5, 6], 3) with pytest.raises( ValueError, match=r"Node 3 is not connected to any other node", ): topology.freeze() - topology.attach_edges_to_node_ingoing([6], 3) - assert isinstance(topology.freeze(), Topology) + topology.attach_edges_to_node_ingoing([2], 3) + assert isinstance(topology.organize_edge_ids().freeze(), Topology) def test_add_exceptions(self, two_to_three_decay: Topology): - topology = _MutableTopology( + topology = MutableTopology( edges=two_to_three_decay.edges, - nodes=two_to_three_decay.nodes, # type: ignore[arg-type] + nodes=two_to_three_decay.nodes, ) with pytest.raises(ValueError, match=r"Node nr. 0 already exists"): topology.add_node(0) with pytest.raises(ValueError, match=r"Edge nr. 0 already exists"): topology.add_edges([0]) with pytest.raises( - ValueError, match=r"Edge nr. 0 is already ingoing to node 0" + ValueError, match=r"Edge nr. -2 is already ingoing to node 0" ): - topology.attach_edges_to_node_ingoing([0], 0) + topology.attach_edges_to_node_ingoing([-2], 0) with pytest.raises(ValueError, match=r"Edge nr. 7 does not exist"): topology.attach_edges_to_node_ingoing([7], 2) with pytest.raises( - ValueError, match=r"Edge nr. 6 is already outgoing from node 2" + ValueError, match=r"Edge nr. 2 is already outgoing from node 2" ): - topology.attach_edges_to_node_outgoing([6], 2) - with pytest.raises(ValueError, match=r"Edge nr. 7 does not exist"): - topology.attach_edges_to_node_outgoing([7], 2) + topology.attach_edges_to_node_outgoing([2], 2) + with pytest.raises(ValueError, match=r"Edge nr. 5 does not exist"): + topology.attach_edges_to_node_outgoing([5], 2) + + def test_organize_edge_ids(self): + topology = MutableTopology( + nodes={0, 1, 2}, + edges={ + 0: Edge(None, 0), + 1: Edge(None, 0), + 2: Edge(1, None), + 3: Edge(2, None), + 4: Edge(2, None), + 5: Edge(0, 1), + 6: Edge(1, 2), + }, + ) + assert sorted(topology.edges) == [0, 1, 2, 3, 4, 5, 6] + topology = topology.organize_edge_ids() + assert sorted(topology.edges) == [-2, -1, 0, 1, 2, 3, 4] class TestSimpleStateTransitionTopologyBuilder: @@ -156,28 +173,28 @@ class TestTopology: ( {0, 1}, { - 0: Edge(None, 0), - 1: Edge(0, 1), - 2: Edge(1, None), - 3: Edge(1, None), + -1: Edge(None, 0), + 0: Edge(1, None), + 1: Edge(1, None), + 2: Edge(0, 1), }, ), ( {0, 1, 2}, { - 0: Edge(None, 0), - 1: Edge(0, 1), - 2: Edge(0, 2), - 3: Edge(1, None), - 4: Edge(1, None), - 5: Edge(2, None), - 6: Edge(2, None), + -1: Edge(None, 0), + 0: Edge(1, None), + 1: Edge(1, None), + 2: Edge(2, None), + 3: Edge(2, None), + 4: Edge(0, 1), + 5: Edge(0, 2), }, ), ], ) def test_constructor(self, nodes, edges): - topology = Topology(nodes=nodes, edges=edges) + topology = Topology(nodes, edges) if nodes is None: nodes = set() if edges is None: @@ -202,7 +219,7 @@ def test_constructor_exceptions(self, nodes, edges): r"(not connected to any other node|has non-existing node IDs)" ), ): - assert Topology(nodes=nodes, edges=edges) + assert Topology(nodes, edges) @pytest.mark.parametrize("repr_method", [repr, pretty]) def test_repr_and_eq(self, repr_method, two_to_three_decay: Topology): @@ -212,11 +229,11 @@ def test_repr_and_eq(self, repr_method, two_to_three_decay: Topology): def test_getters(self, two_to_three_decay: Topology): topology = two_to_three_decay # shorter name - assert get_originating_node_list(topology, edge_ids=[0]) == [] - assert get_originating_node_list(topology, edge_ids=[5, 6]) == [2, 2] - assert topology.incoming_edge_ids == {0, 1} - assert topology.outgoing_edge_ids == {4, 5, 6} - assert topology.intermediate_edge_ids == {2, 3} + assert topology.incoming_edge_ids == {-2, -1} + assert topology.outgoing_edge_ids == {0, 1, 2} + assert topology.intermediate_edge_ids == {3, 4} + assert get_originating_node_list(topology, edge_ids=[-1]) == [] + assert get_originating_node_list(topology, edge_ids=[1, 2]) == [2, 2] @typing.no_type_check def test_immutability(self, two_to_three_decay: Topology): @@ -234,43 +251,30 @@ def test_immutability(self, two_to_three_decay: Topology): node += 666 assert two_to_three_decay.nodes == {0, 1, 2} - def test_organize_edge_ids(self, two_to_three_decay: Topology): - topology = two_to_three_decay.organize_edge_ids() - assert topology.incoming_edge_ids == frozenset({-1, -2}) - assert topology.outgoing_edge_ids == frozenset({0, 1, 2}) - assert topology.intermediate_edge_ids == frozenset({3, 4}) - def test_relabel_edges(self, two_to_three_decay: Topology): - assert set(two_to_three_decay.edges) == {0, 1, 2, 3, 4, 5, 6} - relabeled_topology = two_to_three_decay.relabel_edges({0: -2, 1: -1}) - assert set(relabeled_topology.edges) == {-2, -1, 2, 3, 4, 5, 6} - relabeled_topology = relabeled_topology.relabel_edges({2: 0, 3: 1}) - assert set(relabeled_topology.edges) == {-2, -1, 0, 1, 4, 5, 6} + edge_ids = set(two_to_three_decay.edges) + relabeled_topology = two_to_three_decay.relabel_edges({0: 1, 1: 0}) + assert set(relabeled_topology.edges) == edge_ids + relabeled_topology = relabeled_topology.relabel_edges( + {3: 4, 4: 3, 1: 2, 2: 1} + ) + assert set(relabeled_topology.edges) == edge_ids + relabeled_topology = relabeled_topology.relabel_edges({3: 4}) + assert set(relabeled_topology.edges) == edge_ids def test_swap_edges(self, two_to_three_decay: Topology): original_topology = two_to_three_decay - topology = original_topology.swap_edges(0, 1) + topology = original_topology.swap_edges(-2, -1) assert topology == original_topology - topology = topology.swap_edges(5, 6) + topology = topology.swap_edges(1, 2) assert topology == original_topology - topology = topology.swap_edges(4, 6) + topology = topology.swap_edges(0, 1) assert topology != original_topology - @pytest.mark.parametrize( - ("n_final_states", "expected_order"), - [ - (2, [0]), - (3, [0]), - (4, [1, 0]), - (5, [3, 2, 4, 1, 0]), - (6, [13, 9, 7, 10, 6, 5, 14, 11, 8, 15, 3, 2, 12, 4, 1, 0]), - ], - ) - def test_unique_ordering(self, n_final_states, expected_order): + @pytest.mark.parametrize("n_final_states", [2, 3, 4, 5, 6]) + def test_unique_ordering(self, n_final_states): topologies = create_isobar_topologies(n_final_states) - sorted_topologies = sorted(topologies) - order = [sorted_topologies.index(t) for t in topologies] - assert order == expected_order + assert sorted(topologies) == list(topologies) @pytest.mark.parametrize( diff --git a/tests/unit/test_transition.py b/tests/unit/test_transition.py index cdc929af..435e4c4d 100644 --- a/tests/unit/test_transition.py +++ b/tests/unit/test_transition.py @@ -1,7 +1,6 @@ # pyright: reportUnusedImport=false # pylint: disable=eval-used, no-self-use -from operator import itemgetter -from typing import List +from copy import deepcopy import pytest from IPython.lib.pretty import pretty @@ -10,22 +9,17 @@ Parity, Particle, ParticleCollection, - ParticleWithSpin, Spin, ) from qrules.quantum_numbers import InteractionProperties # noqa: F401 from qrules.topology import ( # noqa: F401 Edge, FrozenDict, - StateTransitionGraph, + FrozenTransition, + MutableTransition, Topology, ) -from qrules.transition import State # noqa: F401 -from qrules.transition import ( - ReactionInfo, - StateTransition, - StateTransitionManager, -) +from qrules.transition import ReactionInfo, State, StateTransitionManager class TestReactionInfo: @@ -40,7 +34,7 @@ def test_properties(self, reaction: ReactionInfo): else: assert len(reaction.transitions) == 8 for transition in reaction.transitions: - assert isinstance(transition, StateTransition) + assert isinstance(transition, FrozenTransition) @pytest.mark.parametrize("repr_method", [repr, pretty]) def test_repr(self, repr_method, reaction: ReactionInfo): @@ -48,15 +42,8 @@ def test_repr(self, repr_method, reaction: ReactionInfo): from_repr = eval(repr_method(instance)) assert from_repr == instance - def test_from_to_graphs(self, reaction: ReactionInfo): - graphs = reaction.to_graphs() - from_graphs = ReactionInfo.from_graphs(graphs, reaction.formalism) - assert from_graphs == reaction - def test_hash(self, reaction: ReactionInfo): - graphs = reaction.to_graphs() - from_graphs = ReactionInfo.from_graphs(graphs, reaction.formalism) - assert hash(from_graphs) == hash(reaction) + assert hash(deepcopy(reaction)) == hash(reaction) class TestState: @@ -81,101 +68,6 @@ def create_state(state_def) -> State: assert state2 >= state1 -class TestStateTransition: - def test_ordering(self, reaction: ReactionInfo): - sorted_transitions: List[StateTransition] = sorted( - reaction.transitions - ) - if reaction.formalism.startswith("cano"): - first = sorted_transitions[0] - second = sorted_transitions[1] - assert first.interactions[0].l_magnitude == 0.0 - assert second.interactions[0].l_magnitude == 2.0 - assert first.interactions[1] == second.interactions[1] - transition_selection = sorted_transitions[::2] - else: - transition_selection = sorted_transitions - - simplified_rendering = [ - tuple( - ( - transition.states[state_id].particle.name, - int(transition.states[state_id].spin_projection), - ) - for state_id in sorted(transition.states) - ) - for transition in transition_selection - ] - - assert simplified_rendering[:3] == [ - ( - ("J/psi(1S)", -1), - ("gamma", -1), - ("pi0", 0), - ("pi0", 0), - ("f(0)(980)", 0), - ), - ( - ("J/psi(1S)", -1), - ("gamma", -1), - ("pi0", 0), - ("pi0", 0), - ("f(0)(1500)", 0), - ), - ( - ("J/psi(1S)", -1), - ("gamma", +1), - ("pi0", 0), - ("pi0", 0), - ("f(0)(980)", 0), - ), - ] - assert simplified_rendering[-1] == ( - ("J/psi(1S)", +1), - ("gamma", +1), - ("pi0", 0), - ("pi0", 0), - ("f(0)(1500)", 0), - ) - - # J/psi - first_half = slice(0, int(len(simplified_rendering) / 2)) - for item in simplified_rendering[first_half]: - assert item[0] == ("J/psi(1S)", -1) - second_half = slice(int(len(simplified_rendering) / 2), None) - for item in simplified_rendering[second_half]: - assert item[0] == ("J/psi(1S)", +1) - second_half = slice(int(len(simplified_rendering) / 2), None) - # gamma - for item in itemgetter(0, 1, 4, 5)(simplified_rendering): - assert item[1] == ("gamma", -1) - for item in itemgetter(2, 3, 6, 7)(simplified_rendering): - assert item[1] == ("gamma", +1) - # pi0 - for item in simplified_rendering: - assert item[2] == ("pi0", 0) - assert item[3] == ("pi0", 0) - # f0 - for item in simplified_rendering[::2]: - assert item[4] == ("f(0)(980)", 0) - for item in simplified_rendering[1::2]: - assert item[4] == ("f(0)(1500)", 0) - - @pytest.mark.parametrize("repr_method", [repr, pretty]) - def test_repr(self, repr_method, reaction: ReactionInfo): - for instance in reaction.transitions: - from_repr = eval(repr_method(instance)) - assert from_repr == instance - - def test_from_to_graph(self, reaction: ReactionInfo): - assert len(reaction.group_by_topology()) == 1 - assert len(reaction.transitions) in {8, 16} - for transition in reaction.transitions: - graph = transition.to_graph() - from_graph = StateTransition.from_graph(graph) - assert transition == from_graph - - class TestStateTransitionManager: def test_allowed_intermediate_particles(self): stm = StateTransitionManager( diff --git a/tox.ini b/tox.ini index ace8b21d..c9b53d02 100644 --- a/tox.ini +++ b/tox.ini @@ -54,6 +54,7 @@ commands = --watch src \ --re-ignore .*/.ipynb_checkpoints/.* \ --re-ignore .*/__pycache__/.* \ + --re-ignore .*\.tmp \ --re-ignore docs/.*\.csv \ --re-ignore docs/.*\.gv \ --re-ignore docs/.*\.inv \ @@ -62,6 +63,8 @@ commands = --re-ignore docs/.*\.yaml \ --re-ignore docs/.*\.yml \ --re-ignore docs/_build/.* \ + --re-ignore docs/_images/.* \ + --re-ignore docs/_static/logo\..* \ --re-ignore docs/api/.* \ --open-browser \ docs/ docs/_build/html