From d4dc9b2f7fe946b66bdf76a460a560b6677d035c Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Wed, 23 Feb 2022 14:04:11 +0100 Subject: [PATCH] feat: extract DynamicsSelector class (#240) * build: install singledispatchmethod b * ci: deactivate fail-fast for unit testselow Python 3.8 * ci: test ampform with QRules v0.9.7 * docs: hide `__getitem__()` methods from API by default * docs: hide dict methods from API * feat: allow assigning dynamics by Particle instance * refactor: extract _HelicityModelIngredients class Co-authored-by: GitHub --- .constraints/py3.10.txt | 8 +- .constraints/py3.6.txt | 5 +- .constraints/py3.7.txt | 9 +- .constraints/py3.8.txt | 8 +- .constraints/py3.9.txt | 8 +- .github/workflows/ci-tests.yml | 1 + .pre-commit-config.yaml | 2 +- docs/conf.py | 4 +- setup.cfg | 1 + src/ampform/helicity/__init__.py | 184 +++++++++++++++++++++++-------- src/ampform/helicity/decay.py | 35 +++++- src/symplot/__init__.py | 2 + 12 files changed, 203 insertions(+), 64 deletions(-) diff --git a/.constraints/py3.10.txt b/.constraints/py3.10.txt index 4d39fd5ad..590364161 100644 --- a/.constraints/py3.10.txt +++ b/.constraints/py3.10.txt @@ -55,7 +55,7 @@ gprof2dot==2021.2.21 graphviz==0.19.1 greenlet==1.1.2 hepunits==2.2.0 -identify==2.4.10 +identify==2.4.11 idna==3.3 imagesize==1.3.0 importlib-metadata==4.11.1 @@ -74,7 +74,7 @@ jupyter-cache==0.4.3 jupyter-client==7.1.2 jupyter-core==4.9.2 jupyter-server==1.13.5 -jupyter-server-mathjax==0.2.4 +jupyter-server-mathjax==0.2.5 jupyter-sphinx==0.3.2 jupyterlab==3.2.9 jupyterlab-code-formatter==1.4.10 @@ -152,7 +152,7 @@ python-dateutil==2.8.2 pytz==2021.3 pyyaml==6.0 pyzmq==22.3.0 -qrules==0.9.6 +qrules==0.9.7 radon==5.1.0 requests==2.27.1 restructuredtext-lint==1.3.2 @@ -193,7 +193,7 @@ tqdm==4.62.3 traitlets==5.1.1 types-docutils==0.17.6 types-pkg-resources==0.1.3 -types-requests==2.27.10 +types-requests==2.27.11 types-setuptools==57.4.9 types-urllib3==1.26.9 typing-extensions==4.1.1 diff --git a/.constraints/py3.6.txt b/.constraints/py3.6.txt index b5427567c..ec8d640b1 100644 --- a/.constraints/py3.6.txt +++ b/.constraints/py3.6.txt @@ -148,12 +148,13 @@ python-dateutil==2.8.2 pytz==2021.3 pyyaml==6.0 pyzmq==22.3.0 -qrules==0.9.6 +qrules==0.9.7 radon==5.1.0 requests==2.27.1 restructuredtext-lint==1.3.2 rich==11.2.0 send2trash==1.8.0 +singledispatchmethod==1.0 ; python_version < "3.8.0" six==1.16.0 smmap==5.0.0 sniffio==1.2.0 @@ -189,7 +190,7 @@ traitlets==4.3.3 typed-ast==1.5.2 types-docutils==0.17.6 types-pkg-resources==0.1.3 -types-requests==2.27.10 +types-requests==2.27.11 types-setuptools==57.4.9 types-urllib3==1.26.9 typing-extensions==4.1.1 ; python_version < "3.8.0" diff --git a/.constraints/py3.7.txt b/.constraints/py3.7.txt index 62d86612f..2d0a4c460 100644 --- a/.constraints/py3.7.txt +++ b/.constraints/py3.7.txt @@ -52,7 +52,7 @@ gprof2dot==2021.2.21 graphviz==0.19.1 greenlet==1.1.2 hepunits==2.2.0 -identify==2.4.10 +identify==2.4.11 idna==3.3 imagesize==1.3.0 importlib-metadata==4.2.0 @@ -72,7 +72,7 @@ jupyter-cache==0.4.3 jupyter-client==7.1.2 jupyter-core==4.9.2 jupyter-server==1.13.5 -jupyter-server-mathjax==0.2.4 +jupyter-server-mathjax==0.2.5 jupyter-sphinx==0.3.2 jupyterlab==3.2.9 jupyterlab-code-formatter==1.4.10 @@ -149,12 +149,13 @@ python-dateutil==2.8.2 pytz==2021.3 pyyaml==6.0 pyzmq==22.3.0 -qrules==0.9.6 +qrules==0.9.7 radon==5.1.0 requests==2.27.1 restructuredtext-lint==1.3.2 rich==11.2.0 send2trash==1.8.0 +singledispatchmethod==1.0 ; python_version < "3.8.0" six==1.16.0 smmap==5.0.0 sniffio==1.2.0 @@ -190,7 +191,7 @@ traitlets==5.1.1 typed-ast==1.5.2 types-docutils==0.17.6 types-pkg-resources==0.1.3 -types-requests==2.27.10 +types-requests==2.27.11 types-setuptools==57.4.9 types-urllib3==1.26.9 typing-extensions==4.1.1 ; python_version < "3.8.0" diff --git a/.constraints/py3.8.txt b/.constraints/py3.8.txt index 7e814c32c..c0023f00a 100644 --- a/.constraints/py3.8.txt +++ b/.constraints/py3.8.txt @@ -55,7 +55,7 @@ gprof2dot==2021.2.21 graphviz==0.19.1 greenlet==1.1.2 hepunits==2.2.0 -identify==2.4.10 +identify==2.4.11 idna==3.3 imagesize==1.3.0 importlib-metadata==4.11.1 @@ -75,7 +75,7 @@ jupyter-cache==0.4.3 jupyter-client==7.1.2 jupyter-core==4.9.2 jupyter-server==1.13.5 -jupyter-server-mathjax==0.2.4 +jupyter-server-mathjax==0.2.5 jupyter-sphinx==0.3.2 jupyterlab==3.2.9 jupyterlab-code-formatter==1.4.10 @@ -153,7 +153,7 @@ python-dateutil==2.8.2 pytz==2021.3 pyyaml==6.0 pyzmq==22.3.0 -qrules==0.9.6 +qrules==0.9.7 radon==5.1.0 requests==2.27.1 restructuredtext-lint==1.3.2 @@ -194,7 +194,7 @@ tqdm==4.62.3 traitlets==5.1.1 types-docutils==0.17.6 types-pkg-resources==0.1.3 -types-requests==2.27.10 +types-requests==2.27.11 types-setuptools==57.4.9 types-urllib3==1.26.9 typing-extensions==4.1.1 diff --git a/.constraints/py3.9.txt b/.constraints/py3.9.txt index 2a1478585..01e6961c0 100644 --- a/.constraints/py3.9.txt +++ b/.constraints/py3.9.txt @@ -55,7 +55,7 @@ gprof2dot==2021.2.21 graphviz==0.19.1 greenlet==1.1.2 hepunits==2.2.0 -identify==2.4.10 +identify==2.4.11 idna==3.3 imagesize==1.3.0 importlib-metadata==4.11.1 @@ -74,7 +74,7 @@ jupyter-cache==0.4.3 jupyter-client==7.1.2 jupyter-core==4.9.2 jupyter-server==1.13.5 -jupyter-server-mathjax==0.2.4 +jupyter-server-mathjax==0.2.5 jupyter-sphinx==0.3.2 jupyterlab==3.2.9 jupyterlab-code-formatter==1.4.10 @@ -152,7 +152,7 @@ python-dateutil==2.8.2 pytz==2021.3 pyyaml==6.0 pyzmq==22.3.0 -qrules==0.9.6 +qrules==0.9.7 radon==5.1.0 requests==2.27.1 restructuredtext-lint==1.3.2 @@ -193,7 +193,7 @@ tqdm==4.62.3 traitlets==5.1.1 types-docutils==0.17.6 types-pkg-resources==0.1.3 -types-requests==2.27.10 +types-requests==2.27.11 types-setuptools==57.4.9 types-urllib3==1.26.9 typing-extensions==4.1.1 diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 39d3eac60..f3e0ff09b 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -48,6 +48,7 @@ jobs: name: Unit tests runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: - macos-11 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 66e9dfe62..a7c531f0b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -146,7 +146,7 @@ repos: - id: pydocstyle - repo: https://github.com/ComPWA/mirrors-pyright - rev: v1.1.222 + rev: v1.1.223 hooks: - id: pyright diff --git a/docs/conf.py b/docs/conf.py index 047716d24..8a87a2208 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -145,7 +145,10 @@ def fetch_logo(url: str, output_path: str) -> None: "evaluate", "is_commutative", "is_extended_real", + "items", + "keys", "precedence", + "values", ] ), "members": True, @@ -154,7 +157,6 @@ def fetch_logo(url: str, output_path: str) -> None: "special-members": ", ".join( [ "__call__", - "__getitem__", ] ), } diff --git a/setup.cfg b/setup.cfg index a64a06846..e63d2ee9f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,6 +45,7 @@ setup_requires = install_requires = attrs >=20.1.0 # on_setattr and https://www.attrs.org/en/stable/api.html#next-gen qrules ==0.9.* + singledispatchmethod; python_version <"3.8.0" sympy >=1.8 # module sympy.printing.numpy typing-extensions; python_version <"3.8.0" packages = find: diff --git a/src/ampform/helicity/__init__.py b/src/ampform/helicity/__init__.py index 9b4736c64..4e8c39efb 100644 --- a/src/ampform/helicity/__init__.py +++ b/src/ampform/helicity/__init__.py @@ -9,10 +9,12 @@ import collections import logging import operator +import sys from collections import OrderedDict, abc from difflib import get_close_matches from functools import reduce from typing import ( + Any, DefaultDict, Dict, ItemsView, @@ -30,16 +32,18 @@ ) import sympy as sp -from attrs import field, frozen +from attrs import define, field, frozen from attrs.validators import instance_of from qrules.combinatorics import ( perform_external_edge_identical_particle_combinatorics, ) +from qrules.particle import Particle from qrules.transition import ReactionInfo, StateTransition from ampform.dynamics.builder import ( ResonanceDynamicsBuilder, TwoBodyKinematicVariableSet, + create_non_dynamic, ) from ampform.kinematics import HelicityAdapter, get_invariant_mass_label @@ -52,6 +56,11 @@ natural_sorting, ) +if sys.version_info >= (3, 8): + from functools import singledispatchmethod +else: + from singledispatchmethod import singledispatchmethod + ParameterValue = Union[float, complex, int] """Allowed value types for parameters.""" @@ -97,6 +106,8 @@ class ParameterValues(abc.Mapping): >>> parameters[2] = 3.14 >>> parameters[c] 3.14 + + .. automethod:: __getitem__ """ def __init__(self, mapping: Mapping[sp.Symbol, ParameterValue]) -> None: @@ -237,6 +248,104 @@ def sum_components( # noqa: R701 ) +@define +class _HelicityModelIngredients: + parameter_defaults: Dict[sp.Symbol, ParameterValue] = field(factory=dict) + components: Dict[str, sp.Expr] = field(factory=dict) + kinematic_variables: Dict[sp.Symbol, sp.Expr] = field(factory=dict) + + def reset(self) -> None: + self.parameter_defaults = {} + self.components = {} + self.kinematic_variables = {} + + +class DynamicsSelector(abc.Mapping): + """Configure which `.ResonanceDynamicsBuilder` to use for each node.""" + + def __init__( + self, transitions: Union[ReactionInfo, Iterable[StateTransition]] + ) -> None: + if isinstance(transitions, ReactionInfo): + transitions = transitions.transitions + self.__choices: Dict[TwoBodyDecay, ResonanceDynamicsBuilder] = {} + for transition in transitions: + for node_id in transition.topology.nodes: + decay = TwoBodyDecay.from_transition(transition, node_id) + self.__choices[decay] = create_non_dynamic + + @singledispatchmethod + def assign( + self, selection: Any, builder: ResonanceDynamicsBuilder + ) -> None: + """Assign a `.ResonanceDynamicsBuilder` to a selection of nodes. + + Currently, the following types of selections are implements: + + - `str`: Select transition nodes by the name of the + `~.TwoBodyDecay.parent` `~qrules.particle.Particle`. + - `.TwoBodyDecay` or `tuple` of a `~qrules.transition.StateTransition` + with a node ID: set dynamics for one specific transition node. + """ + raise NotImplementedError( + "Cannot set dynamics builder for selection type" + f" {type(selection).__name__}" + ) + + @assign.register(TwoBodyDecay) + def _( + self, decay: TwoBodyDecay, builder: ResonanceDynamicsBuilder + ) -> None: + self.__choices[decay] = builder + + @assign.register(tuple) + def _( + self, + transition_node: Tuple[StateTransition, int], + builder: ResonanceDynamicsBuilder, + ) -> None: + decay = TwoBodyDecay.create(transition_node) + return self.assign(decay, builder) + + @assign.register(str) + def _(self, particle_name: str, builder: ResonanceDynamicsBuilder) -> None: + found_particle = False + for decay in self.__choices: + decaying_particle = decay.parent.particle + if decaying_particle.name == particle_name: + self.__choices[decay] = builder + found_particle = True + if not found_particle: + logging.warning( + f'Model contains no resonance with name "{particle_name}"' + ) + + @assign.register(Particle) + def _(self, particle: Particle, builder: ResonanceDynamicsBuilder) -> None: + return self.assign(particle.name, builder) + + def __getitem__( + self, __k: Union[TwoBodyDecay, Tuple[StateTransition, int]] + ) -> ResonanceDynamicsBuilder: + __k = TwoBodyDecay.create(__k) + return self.__choices[__k] + + def __len__(self) -> int: + return len(self.__choices) + + def __iter__(self) -> Iterator[TwoBodyDecay]: + return iter(self.__choices) + + def items(self) -> ItemsView[TwoBodyDecay, ResonanceDynamicsBuilder]: + return self.__choices.items() + + def keys(self) -> KeysView[TwoBodyDecay]: + return self.__choices.keys() + + def values(self) -> ValuesView[ResonanceDynamicsBuilder]: + return self.__choices.values() + + class HelicityAmplitudeBuilder: # pylint: disable=too-many-instance-attributes r"""Amplitude model generator for the helicity formalism. @@ -265,19 +374,15 @@ def __init__( stable_final_state_ids: Optional[Iterable[int]] = None, scalar_initial_state_mass: bool = False, ) -> None: - self._name_generator = HelicityAmplitudeNameGenerator() - self.__reaction = reaction - self.__parameter_defaults: Dict[sp.Symbol, ParameterValue] = {} - self.__components: Dict[str, sp.Expr] = {} - self.__dynamics_choices: Dict[ - TwoBodyDecay, ResonanceDynamicsBuilder - ] = {} - if len(reaction.transitions) < 1: raise ValueError( f"At least one {StateTransition.__name__} required to" " genenerate an amplitude model!" ) + self._name_generator = HelicityAmplitudeNameGenerator() + self.__reaction = reaction + self.__ingredients = _HelicityModelIngredients() + self.__dynamics_choices = DynamicsSelector(reaction) self.__adapter = HelicityAdapter(reaction) self.stable_final_state_ids = stable_final_state_ids # type: ignore[assignment] self.scalar_initial_state_mass = scalar_initial_state_mass # type: ignore[assignment] @@ -289,6 +394,10 @@ def adapter(self) -> HelicityAdapter: """Converter for computing kinematic variables from four-momenta.""" return self.__adapter + @property + def dynamics_choices(self) -> DynamicsSelector: + return self.__dynamics_choices + @property def stable_final_state_ids(self) -> Optional[Set[int]]: # noqa: D403 @@ -331,22 +440,10 @@ def scalar_initial_state_mass(self, value: bool) -> None: def set_dynamics( self, particle_name: str, dynamics_builder: ResonanceDynamicsBuilder ) -> None: - found_particle = False - for transition in self.__reaction.transitions: - for node_id in transition.topology.nodes: - decay = TwoBodyDecay.from_transition(transition, node_id) - decaying_particle = decay.parent.particle - if decaying_particle.name == particle_name: - self.__dynamics_choices[decay] = dynamics_builder - found_particle = True - if not found_particle: - logging.warning( - f'Model contains no resonance with name "{particle_name}"' - ) + self.__dynamics_choices.assign(particle_name, dynamics_builder) def formulate(self) -> HelicityModel: - self.__components = {} - self.__parameter_defaults = {} + self.__ingredients.reset() top_expression = self.__formulate_top_expression() kinematic_variables = { sp.Symbol(var_name, real=True): expr @@ -354,21 +451,21 @@ def formulate(self) -> HelicityModel: } if self.stable_final_state_ids is not None: for state_id in self.stable_final_state_ids: - mass_symbol = sp.Symbol(f"m_{state_id}", real=True) + symbol = sp.Symbol(f"m_{state_id}", real=True) particle = self.__reaction.final_state[state_id] - self.__parameter_defaults[mass_symbol] = particle.mass - del kinematic_variables[mass_symbol] + self.__ingredients.parameter_defaults[symbol] = particle.mass + del kinematic_variables[symbol] if self.scalar_initial_state_mass: subscript = "".join(map(str, sorted(self.__reaction.final_state))) - mass_symbol = sp.Symbol(f"m_{subscript}", real=True) + symbol = sp.Symbol(f"m_{subscript}", real=True) particle = self.__reaction.initial_state[-1] - self.__parameter_defaults[mass_symbol] = particle.mass - del kinematic_variables[mass_symbol] + self.__ingredients.parameter_defaults[symbol] = particle.mass + del kinematic_variables[symbol] return HelicityModel( expression=top_expression, - components=self.__components, - parameter_defaults=self.__parameter_defaults, + components=self.__ingredients.components, + parameter_defaults=self.__ingredients.parameter_defaults, kinematic_variables=kinematic_variables, reaction_info=self.__reaction, ) @@ -407,9 +504,10 @@ def __formulate_coherent_intensity( expression = self.__formulate_sequential_decay(transition) sequential_expressions.append(expression) amplitude_sum = sum(sequential_expressions) - coherent_intensity = abs(amplitude_sum) ** 2 - self.__components[Rf"I_{{{graph_group_label}}}"] = coherent_intensity - return coherent_intensity + expression = abs(amplitude_sum) ** 2 + component_name = f"I_{{{graph_group_label}}}" + self.__ingredients.components[component_name] = expression + return expression def __formulate_sequential_decay( self, transition: StateTransition @@ -425,9 +523,8 @@ def __formulate_sequential_decay( expression = coefficient * sequential_amplitudes if prefactor is not None: expression = prefactor * expression - self.__components[ - f"A_{{{self._name_generator.generate_amplitude_name(transition)}}}" - ] = expression + subscript = self._name_generator.generate_amplitude_name(transition) + self.__ingredients.components[f"A_{{{subscript}}}"] = expression return expression def _formulate_partial_decay( @@ -448,15 +545,15 @@ def __formulate_dynamics( variable_set = _generate_kinematic_variable_set(transition, node_id) expression, parameters = builder(decay.parent.particle, variable_set) for par, value in parameters.items(): - if par in self.__parameter_defaults: - previous_value = self.__parameter_defaults[par] + if par in self.__ingredients.parameter_defaults: + previous_value = self.__ingredients.parameter_defaults[par] if value != previous_value: logging.warning( f'New default value {value} for parameter "{par.name}"' " is inconsistent with existing value" f" {previous_value}" ) - self.__parameter_defaults[par] = value + self.__ingredients.parameter_defaults[par] = value return expression @@ -472,9 +569,10 @@ def __generate_amplitude_coefficient( suffix = self._name_generator.generate_sequential_amplitude_suffix( transition ) - coefficient_symbol = sp.Symbol(f"C_{{{suffix}}}") - self.__parameter_defaults[coefficient_symbol] = complex(1, 0) - return coefficient_symbol + symbol = sp.Symbol(f"C_{{{suffix}}}") + value = complex(1, 0) + self.__ingredients.parameter_defaults[symbol] = value + return symbol def __generate_amplitude_prefactor( self, transition: StateTransition diff --git a/src/ampform/helicity/decay.py b/src/ampform/helicity/decay.py index 106fac668..308b62417 100644 --- a/src/ampform/helicity/decay.py +++ b/src/ampform/helicity/decay.py @@ -1,6 +1,7 @@ """Extract two-body decay info from a `~qrules.transition.StateTransition`.""" -from typing import Iterable, List, Tuple +from functools import singledispatch +from typing import Any, Iterable, List, Tuple from attrs import frozen from qrules.quantum_numbers import InteractionProperties @@ -47,6 +48,16 @@ class TwoBodyDecay: children: Tuple[StateWithID, StateWithID] interaction: InteractionProperties + @staticmethod + def create(obj: Any) -> "TwoBodyDecay": + """Create a `TwoBodyDecay` instance from an arbitrary object. + + More implementations of :meth:`create` can be implemented with + :func:`@ampform.helicity.decay._create_two_body_decay.register(TYPE) + `. + """ + return _create_two_body_decay(obj) + @classmethod def from_transition( cls, transition: StateTransition, node_id: int @@ -80,6 +91,28 @@ def from_transition( ) +@singledispatch +def _create_two_body_decay(obj: Any) -> TwoBodyDecay: + raise NotImplementedError( + f"Cannot create a {TwoBodyDecay.__name__} from a {type(obj).__name__}" + ) + + +@_create_two_body_decay.register(TwoBodyDecay) +def _(obj: TwoBodyDecay) -> TwoBodyDecay: + return obj + + +@_create_two_body_decay.register(tuple) +def _(obj: tuple) -> TwoBodyDecay: + if len(obj) == 2: + if isinstance(obj[0], StateTransition) and isinstance(obj[1], int): + return TwoBodyDecay.from_transition(*obj) + raise NotImplementedError( + f"Cannot create a {TwoBodyDecay.__name__} from {obj}" + ) + + def get_helicity_info( transition: StateTransition, node_id: int ) -> Tuple[State, Tuple[State, State]]: diff --git a/src/symplot/__init__.py b/src/symplot/__init__.py index a1309d7b2..c98f81d17 100644 --- a/src/symplot/__init__.py +++ b/src/symplot/__init__.py @@ -63,6 +63,8 @@ class SliderKwargs(abc.Mapping): Sliders can be defined in :func:`~mpl_interactions.pyplot.interactive_plot` through :term:`kwargs `. This wrapper class can be used for that. + + .. automethod:: __getitem__ """ def __init__(